# 注册算子
import tvm
def register(op_name, describe=""):
    """Register a new operator under ``op_name`` and return it.

    The registration itself is delegated to ``tvm.ir._ffi_api.RegisterOp``.
    When ``op_name`` is not registered yet, a new empty op with that name is
    created; when it has already been registered, the call aborts with an
    error message, so callers that may run more than once should look the op
    up first.

    Parameters
    ----------
    op_name : str
        The operator name.
    describe : Optional[str]
        The operator description.

    Returns
    -------
    tvm.ir.Op
        The newly registered operator, fetched back via ``tvm.ir.Op.get``.
    """
    tvm.ir._ffi_api.RegisterOp(op_name, describe)
    # Hand back the Op handle so ``op = register(name)`` yields a usable Op
    # instead of None.
    return tvm.ir.Op.get(op_name)
## 1. 获取或创建算子
# Look the operator up, registering it on first use, so re-running this cell
# does not hit the "already registered" abort inside ``register``.
op_name = "my.operator"
try:
    op = tvm.ir.Op.get(op_name)
except Exception:
    # Op.get is expected to fail for an unknown name; the exact Python
    # exception class TVM raises is not pinned down here, so catch Exception
    # (unlike a bare ``except:``, this still lets KeyboardInterrupt and
    # SystemExit propagate).
    register(op_name)
    # Re-fetch explicitly so ``op`` is an Op regardless of what ``register``
    # returns.
    op = tvm.ir.Op.get(op_name)
## 2. 设置算子属性
# Declare the operator's signature: exactly one input, named "data".
_target_op = tvm.ir.Op.get(op_name)
_target_op.set_num_inputs(1)
_target_op.add_argument("data", "Tensor", "输入数据")
## 3. 定义类型关系
import warnings


def my_op_type_rel(types, attrs):
    """Infer the output type of ``my.operator`` from its input types.

    Written for the Python callback convention of ``Op.add_type_rel``:
    it receives the list of input types plus the call attributes and returns
    the output type, or ``None`` when the output type cannot be inferred.
    The output has exactly the same tensor type as the first input.

    Parameters
    ----------
    types : list of Type
        Input argument types (the output slot is not included).
    attrs : Attrs
        Call attributes (unused here).

    Returns
    -------
    Optional[tvm.ir.TensorType]
        The input tensor type, or ``None`` if the input is missing or is not
        a tensor type yet.
    """
    if not types:
        return None
    data = types[0]
    # Python-side Type objects have no ``as_tensor_type``; use isinstance.
    if not isinstance(data, tvm.ir.TensorType):
        return None
    return data


# ``Op.add_type_rel`` is absent in some TVM builds, where calling it raised
# "AttributeError: 'Op' object has no attribute 'add_type_rel'". Register the
# relation only when the API exists, and say so loudly when it does not.
# NOTE(review): Relax-based TVM infers output types through the
# "FInferStructInfo" op attribute instead -- confirm which frontend this op
# targets.
_rel_target_op = tvm.ir.Op.get(op_name)
if hasattr(_rel_target_op, "add_type_rel"):
    _rel_target_op.add_type_rel("MyOpTypeRel", my_op_type_rel)
else:
    warnings.warn(
        f"tvm.ir.Op has no add_type_rel in this TVM build; "
        f"type relation for {op_name!r} was not registered"
    )
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[4], line 8
6 reporter.assoc(types[1], data)
7 return True
----> 8 tvm.ir.Op.get(op_name).add_type_rel("MyOpTypeRel", my_op_type_rel)
AttributeError: 'Op' object has no attribute 'add_type_rel'
## 4. 设置计算函数
def my_op_compute(attrs, inputs, out_type):
    """Compute rule for ``my.operator``: one extern output shaped like the input.

    Parameters
    ----------
    attrs : Attrs
        Call attributes (unused here).
    inputs : list
        Operator inputs, presumably ``te.Tensor`` objects; only ``inputs[0]``
        is used.
    out_type : Type
        Inferred output type (unused here).

    Returns
    -------
    list
        A single-element list holding the tensor produced by ``te.extern``.
    """
    # The visible imports only bring in ``tvm``; import ``te`` locally so this
    # does not fail with NameError when the compute rule is invoked.
    from tvm import te

    # Fetch the compute function from the TVMScript module.
    # ``my_op_module`` is defined outside this view; presumably it returns an
    # IRModule containing "my_op_func" -- confirm.
    mod = my_op_module()
    data = inputs[0]
    # NOTE(review): te.extern expects fcompute to return a TIR statement
    # (e.g. tvm.tir.call_packed / call_extern). Calling the PrimFunc object
    # ``mod["my_op_func"]`` directly is unlikely to produce one -- verify.
    out = te.extern(
        [data.shape],
        [data],
        lambda ins, outs: mod["my_op_func"](ins[0], outs[0]),
        name="my_op",
        dtype=data.dtype,
    )
    return [out]


tvm.ir.Op.get(op_name).set_attr("FTVMCompute", my_op_compute)