Understanding GenericFunc

%cd ..
import set_env
/media/pc/data/lxw/ai/tvm-book/doc/read
import tvm
from tvm.target.generic_func import GenericFunc
/*!
 * \brief Generate the strategy of operators. This function is a generic
 * function and can be re-defined for different targets.
 *
 * The function signature of generic function is:
 *   OpStrategy(const Attrs& attrs, const Array<Tensor>& inputs,
 *              const Type& out_type, const Target& target)
 */
using FTVMStrategy = GenericFunc;

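This code defines a type alias named FTVMStrategy for GenericFunc, so an FTVMStrategy is a generic function. The generic function generates the strategy for an operator. Its signature is: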

OpStrategy(const Attrs& attrs, const Array<Tensor>& inputs,
           const Type& out_type, const Target& target)

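Here, Attrs holds the operator's attributes, Array<Tensor> is the array of input tensors, Type is the output type, and Target is the target platform. Because it is a GenericFunc, the function can be redefined (specialized) for different targets.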
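To make the signature concrete, here is a minimal Python sketch of a strategy function taking the same four parameters. Target, OpStrategy, OVERRIDES and the dense_strategy* functions are hypothetical stand-ins, not TVM APIs. The sketch assumes that a target carries an ordered list of keys (for example ["cuda", "gpu"]) and that the first key with a target-specific version is the one used.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List


@dataclass
class Target:
    # Hypothetical stand-in: only the ordered target keys matter here.
    kind: str
    keys: List[str]


@dataclass
class OpStrategy:
    # Hypothetical stand-in: just the names of the chosen implementations.
    implementations: List[str] = field(default_factory=list)


# (attrs, inputs, out_type, target) -> OpStrategy, mirroring FTVMStrategy
StrategyFunc = Callable[[Dict[str, Any], List[Any], Any, Target], OpStrategy]


def dense_strategy_generic(attrs, inputs, out_type, target):
    # Fallback used when no target-specific version matches.
    return OpStrategy(["dense.generic"])


def dense_strategy_cuda(attrs, inputs, out_type, target):
    # A target-specific redefinition with the same signature.
    return OpStrategy(["dense.cuda"])


# Hypothetical per-target table: target key -> specialized strategy.
OVERRIDES: Dict[str, StrategyFunc] = {"cuda": dense_strategy_cuda}


def dense_strategy(attrs, inputs, out_type, target: Target) -> OpStrategy:
    # Try the target keys in order; the first key with an override wins.
    for key in target.keys:
        if key in OVERRIDES:
            return OVERRIDES[key](attrs, inputs, out_type, target)
    return dense_strategy_generic(attrs, inputs, out_type, target)


cpu = Target("llvm", ["cpu"])
gpu = Target("cuda", ["cuda", "gpu"])
print(dense_strategy({}, [], None, cpu).implementations)  # ['dense.generic']
print(dense_strategy({}, [], None, gpu).implementations)  # ['dense.cuda']
```

Callers always use one function name and one signature; only the target decides which body runs.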

from tvm.relay.op.op import register_strategy

register_strategy??
Signature: register_strategy(op_name, fstrategy=None, level=10)
Source:   
def register_strategy(op_name, fstrategy=None, level=10):
    """Register strategy function for an op.

    Parameters
    ----------
    op_name : str
        The name of the op.

    fstrategy : function (attrs: Attrs, inputs: List[Tensor], out_type: Type,
                          target:Target) -> OpStrategy
        The strategy function. Need to be native GenericFunc.

    level : int
        The priority level
    """
    if not isinstance(fstrategy, GenericFunc):
        assert hasattr(fstrategy, "generic_func_node")
        fstrategy = fstrategy.generic_func_node
    return tvm.ir.register_op_attr(op_name, "FTVMStrategy", fstrategy, level)
File:      /media/pc/data/lxw/ai/tvm/python/tvm/relay/op/op.py
Type:      function
class GenericFuncNode;

/*!
 * \brief Generic function that can be specialized on a per-target basis.
 */
class GenericFunc : public ObjectRef {
 public:
  GenericFunc() {}
  explicit GenericFunc(ObjectPtr<Object> n) : ObjectRef(n) {}

  /*!
   * \brief Set the default function implementation.
   * \param value The default function
   * \param allow_override If true, this call may override a previously registered function. If
   * false, an error will be logged if the call would override a previously registered function.
   * \return reference to self.
   */
  TVM_DLL GenericFunc& set_default(const runtime::PackedFunc value, bool allow_override = false);
  /*!
   * \brief Register a specialized function
   * \param tags The tags for this specialization
   * \param value The specialized function
   * \param allow_override If true, this call may override previously registered tags. If false,
   * an error will be logged if the call would override previously registered tags.
   * \return reference to self.
   */
  TVM_DLL GenericFunc& register_func(const std::vector<std::string>& tags,
                                     const runtime::PackedFunc value, bool allow_override = false);
  /*!
   * \brief Call generic function by directly passing in unpacked format.
   * \param args Arguments to be passed.
   * \tparam Args arguments to be passed.
   *
   * \code
   *   // Example code on how to call generic function
   *   void CallGeneric(GenericFunc f) {
   *     // call like normal functions by pass in arguments
   *     // return value is automatically converted back
   *     int rvalue = f(1, 2.0);
   *   }
   * \endcode
   */
  template <typename... Args>
  inline runtime::TVMRetValue operator()(Args&&... args) const;
  /*!
   * \brief Invoke the relevant function for the current target context, set by set_target_context.
   * Arguments are passed in packed format.
   * \param args The arguments to pass to the function.
   * \param ret The return value
   */
  TVM_DLL void CallPacked(runtime::TVMArgs args, runtime::TVMRetValue* ret) const;
  /*!
   * \brief Get the packed function specified for the current target context.
   */
  TVM_DLL PackedFunc GetPacked() const;
  /*!
   * \brief Find or register the GenericFunc instance corresponding to the given name
   * \param name The name of the registered GenericFunc
   * \return The GenericFunc instance
   */
  TVM_DLL static GenericFunc Get(const std::string& name);

  /*!
   * \brief Add a GenericFunc instance to the registry
   * \param func The GenericFunc instance
   * \param name The name of the registered GenericFunc
   */
  TVM_DLL static void RegisterGenericFunc(GenericFunc func, const std::string& name);

  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline GenericFuncNode* operator->();

  // declare container type
  using ContainerType = GenericFuncNode;

  // Internal class.
  struct Manager;

 private:
  friend struct Manager;
};

This code defines a class named GenericFunc that derives from ObjectRef. It represents a generic function that can be specialized on a per-target basis.

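The class declares the following members:

  • set_default sets the default implementation and returns a reference to self. If allow_override is false and a default already exists, an error is logged.

  • register_func registers a specialized function for a list of tags and returns a reference to self. If allow_override is false, re-registering an existing tag logs an error.

  • operator() calls the generic function with arguments passed directly in unpacked form; the return value is converted back automatically.

  • CallPacked invokes the function selected for the current target context, with the arguments passed in packed form.

  • GetPacked returns the PackedFunc selected for the current target context.

  • Get finds the GenericFunc registered under the given name, registering a new one if none exists.

  • RegisterGenericFunc adds a GenericFunc instance to the registry under the given name.

  • operator-> gives access to the internal node container.

  • ContainerType is a type alias for the internal node container type, GenericFuncNode.

  • Manager is an internal struct; it is declared a friend of GenericFunc.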
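The member list above can be summarized as a small Python model of the dispatch behaviour. Everything here (Target and its context-manager stack, GenericFunc.get, register_generic_func, my_add) is a hypothetical stand-in, not TVM's implementation. The sketch assumes that the current target exposes an ordered list of keys and that the first key with a registered specialization is chosen, with a fallback to the default otherwise. Conflicting registrations raise an exception here, whereas the C++ comments speak of logging an error.

```python
from typing import Callable, Dict, List, Optional


class Target:
    """Hypothetical target: an ordered list of keys, usable as a context manager."""

    _stack: List["Target"] = []

    def __init__(self, keys: List[str]):
        self.keys = keys

    def __enter__(self):
        Target._stack.append(self)
        return self

    def __exit__(self, *exc):
        Target._stack.pop()
        return False

    @staticmethod
    def current() -> Optional["Target"]:
        return Target._stack[-1] if Target._stack else None


class GenericFunc:
    """Simplified Python model of the C++ GenericFunc interface (not TVM's code)."""

    _registry: Dict[str, "GenericFunc"] = {}  # assumed stand-in for Manager's table

    def __init__(self, name: str = ""):
        self.name = name
        self.default: Optional[Callable] = None
        self.dispatch: Dict[str, Callable] = {}

    def set_default(self, value: Callable, allow_override: bool = False) -> "GenericFunc":
        if self.default is not None and not allow_override:
            raise RuntimeError(f"default of {self.name!r} is already set")
        self.default = value
        return self  # returning self allows chained calls, as in C++

    def register_func(self, tags: List[str], value: Callable,
                      allow_override: bool = False) -> "GenericFunc":
        for tag in tags:
            if tag in self.dispatch and not allow_override:
                raise RuntimeError(f"tag {tag!r} is already registered")
            self.dispatch[tag] = value
        return self

    def get_packed(self) -> Optional[Callable]:
        # Like GetPacked: the first target key with a specialization wins,
        # otherwise fall back to the default implementation.
        target = Target.current()
        if target is not None:
            for key in target.keys:
                if key in self.dispatch:
                    return self.dispatch[key]
        return self.default

    def __call__(self, *args):
        # Like operator()/CallPacked: resolve for the current target, then call.
        func = self.get_packed()
        if func is None:
            raise RuntimeError(f"no implementation for {self.name!r}")
        return func(*args)

    @classmethod
    def get(cls, name: str) -> "GenericFunc":
        # Like Get: find the instance, creating and registering it if missing.
        if name not in cls._registry:
            cls._registry[name] = GenericFunc(name)
        return cls._registry[name]

    @classmethod
    def register_generic_func(cls, func: "GenericFunc", name: str) -> None:
        # Like RegisterGenericFunc: add an existing instance under `name`.
        func.name = name
        cls._registry[name] = func


add = GenericFunc.get("my_add")
add.set_default(lambda a, b: "generic").register_func(["cuda", "gpu"], lambda a, b: "gpu")
print(add(1, 2))          # generic (no target context)
with Target(["cuda", "gpu"]):
    print(add(1, 2))      # gpu
with Target(["cpu"]):
    print(add(1, 2))      # generic
```

Because set_default and register_func both return self, a default and its specializations can be declared in one chained expression, mirroring the `GenericFunc&` return type in the C++ header.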