# 解读 GenericFunc
%cd ..
import set_env
/media/pc/data/lxw/ai/tvm-book/doc/read
import tvm
from tvm.target.generic_func import GenericFunc
/*!
 * \brief Generate the strategy of operators. This function is a generic
 * function and can be re-defined (specialized) for different targets.
 *
 * The function signature of the generic function is:
 *   OpStrategy(const Attrs& attrs, const Array<Tensor>& inputs,
 *              const Type& out_type, const Target& target)
 *
 * \note From Python, relay.op.op.register_strategy attaches such a generic
 * function to an operator under the "FTVMStrategy" attribute key.
 */
using FTVMStrategy = GenericFunc;
这段代码为 `GenericFunc` 定义了类型别名 `FTVMStrategy`，用于表示生成算子策略（`OpStrategy`）的通用函数。其函数签名如下：

```cpp
OpStrategy(const Attrs& attrs, const Array<Tensor>& inputs,
           const Type& out_type, const Target& target)
```

其中，`attrs` 表示算子的属性集合，`inputs` 表示输入张量数组，`out_type` 表示输出类型，`target` 表示目标平台。由于它是通用函数，可以针对不同的目标平台重新定义（特化）。
from tvm.relay.op.op import register_strategy
register_strategy??
Signature: register_strategy(op_name, fstrategy=None, level=10)
Source:
def register_strategy(op_name, fstrategy=None, level=10):
    """Register strategy function for an op.

    The strategy is attached to the op as the attribute ``"FTVMStrategy"``
    through ``tvm.ir.register_op_attr``.

    Parameters
    ----------
    op_name : str
        The name of the op.

    fstrategy : function (attrs: Attrs, inputs: List[Tensor], out_type: Type,
                target: Target) -> OpStrategy
        The strategy function. It must either be a native ``GenericFunc``
        or expose one through a ``generic_func_node`` attribute, in which
        case that attribute is what gets registered.
        NOTE(review): despite the ``None`` default, omitting this argument
        fails the check below (``None`` is not a ``GenericFunc`` and has no
        ``generic_func_node``) -- confirm whether a decorator form was intended.

    level : int
        The priority level, forwarded unchanged to ``tvm.ir.register_op_attr``
        (presumably higher levels take precedence -- confirm there).

    Returns
    -------
    The value returned by ``tvm.ir.register_op_attr`` for this registration.

    Raises
    ------
    AssertionError
        If ``fstrategy`` is not a ``GenericFunc`` and has no
        ``generic_func_node`` attribute. Because this is an ``assert``, the
        check is skipped under ``python -O``; the attribute access that
        follows then raises ``AttributeError`` instead.
    """
    if not isinstance(fstrategy, GenericFunc):
        # Not a native GenericFunc: accept a wrapper that carries one in
        # ``generic_func_node`` (presumably produced by the
        # ``tvm.target.generic_func`` decorator -- TODO confirm) and unwrap it.
        assert hasattr(fstrategy, "generic_func_node")
        fstrategy = fstrategy.generic_func_node
    # Attach the GenericFunc to the op under the "FTVMStrategy" attribute key.
    return tvm.ir.register_op_attr(op_name, "FTVMStrategy", fstrategy, level)
File: /media/pc/data/lxw/ai/tvm/python/tvm/relay/op/op.py
Type: function
// Forward declaration of the node type that backs GenericFunc (see ContainerType).
class GenericFuncNode;
/*!
 * \brief Generic function that can be specialized on a per-target basis.
 *
 * A default implementation is installed with set_default() and
 * tag-specific implementations with register_func(); invocation selects the
 * implementation for the current target context. Named instances are kept
 * in a registry reached through Get() and RegisterGenericFunc().
 */
class GenericFunc : public ObjectRef {
 public:
  GenericFunc() {}
  explicit GenericFunc(ObjectPtr<Object> n) : ObjectRef(n) {}
  /*!
   * \brief Set the default function implementation.
   * \param value The default function.
   * \param allow_override If true, this call may override a previously registered function. If
   * false, an error will be logged if the call would override a previously registered function.
   * \return Reference to self, so calls can be chained.
   */
  TVM_DLL GenericFunc& set_default(const runtime::PackedFunc value, bool allow_override = false);
  /*!
   * \brief Register a specialized function.
   * \param tags The tags for this specialization.
   * \param value The specialized function.
   * \param allow_override If true, this call may override previously registered tags. If false,
   * an error will be logged if the call would override previously registered tags.
   * \return Reference to self, so calls can be chained.
   */
  TVM_DLL GenericFunc& register_func(const std::vector<std::string>& tags,
                                     const runtime::PackedFunc value, bool allow_override = false);
  /*!
   * \brief Call the generic function by passing arguments directly in unpacked format.
   * \param args Arguments to be passed.
   * \tparam Args Types of the arguments to be passed.
   * \return The return value of the selected implementation.
   *
   * \code
   *   // Example code on how to call a generic function
   *   void CallGeneric(GenericFunc f) {
   *     // call like a normal function by passing in arguments;
   *     // the return value is automatically converted back
   *     int rvalue = f(1, 2.0);
   *   }
   * \endcode
   */
  template <typename... Args>
  inline runtime::TVMRetValue operator()(Args&&... args) const;
  /*!
   * \brief Invoke the relevant function for the current target context, set by set_target_context.
   * Arguments are passed in packed format.
   * \param args The arguments to pass to the function.
   * \param ret Output location that receives the return value.
   */
  TVM_DLL void CallPacked(runtime::TVMArgs args, runtime::TVMRetValue* ret) const;
  /*!
   * \brief Get the packed function specified for the current target context.
   * \return The PackedFunc selected for the current target context.
   */
  TVM_DLL PackedFunc GetPacked() const;
  /*!
   * \brief Find or register the GenericFunc instance corresponding to the given name.
   * \param name The name of the registered GenericFunc.
   * \return The GenericFunc instance.
   */
  TVM_DLL static GenericFunc Get(const std::string& name);
  /*!
   * \brief Add a GenericFunc instance to the registry.
   * \param func The GenericFunc instance.
   * \param name The name under which to register the GenericFunc.
   */
  TVM_DLL static void RegisterGenericFunc(GenericFunc func, const std::string& name);
  /*!
   * \brief Access the internal node container.
   * \return The pointer to the internal node container.
   */
  inline GenericFuncNode* operator->();
  // Declare the container (node) type referenced by this ObjectRef.
  using ContainerType = GenericFuncNode;
  // Internal class; declared a friend below so it can access private members.
  struct Manager;

 private:
  friend struct Manager;
};
这段代码定义了名为 `GenericFunc` 的类，它继承自 `ObjectRef`，表示可以针对不同目标平台分别特化的通用函数。
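该类声明了以下成员函数、类型别名和内部结构体：

- `set_default`：设置默认的函数实现，并返回对自身的引用。当 `allow_override` 为 `false` 时，若此次调用会覆盖已注册的函数，则记录错误。
- `register_func`：为给定的一组标签（tags）注册特化函数，并返回对自身的引用。`allow_override` 的含义同上。
- `operator()`：以非打包（unpacked）形式直接传入参数来调用通用函数，返回值会自动转换。
- `CallPacked`：针对当前目标上下文（由 `set_target_context` 设置）调用相应的函数，参数以打包（packed）格式传递。
- `GetPacked`：获取为当前目标上下文指定的打包函数（`PackedFunc`）。
- `Get`（静态）：查找或注册给定名称对应的 `GenericFunc` 实例。
- `RegisterGenericFunc`（静态）：以给定名称将 `GenericFunc` 实例加入注册表。
- `operator->`：访问内部节点容器。
- `ContainerType`：类型别名，表示内部节点容器的类型（即 `GenericFuncNode`）。
- `Manager`：内部使用的管理结构体，被声明为友元（`friend`）。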