注册算子

import tvm

def register(op_name, describe=""):
    """Register a new, empty operator under ``op_name`` and return it.

    When ``op_name`` is not registered yet, a new empty op with that name is
    created. When ``op_name`` has already been registered, the call aborts
    with an error message.

    Parameters
    ----------
    op_name : str
        The operator name

    describe : Optional[str]
        The operator description

    Returns
    -------
    op : tvm.ir.Op
        The operator obtained by looking ``op_name`` up after registration.
    """
    tvm.ir._ffi_api.RegisterOp(op_name, describe)
    # The result of RegisterOp was previously discarded, so callers writing
    # `op = register(name)` ended up with None. Look the op up explicitly so
    # the return value matches the "get the Op" contract.
    return tvm.ir.Op.get(op_name)

1. 获取或创建算子

op_name = "my.operator"
try:
    # Reuse the operator when it is already known to the registry.
    op = tvm.ir.Op.get(op_name)
except Exception:
    # NOTE(review): the exact exception type raised for an unknown operator
    # depends on the TVM build — narrow this once confirmed. `Exception`
    # (rather than a bare `except`) keeps KeyboardInterrupt/SystemExit intact.
    register(op_name)
    # Look the op up again so `op` is always an Op object, independent of
    # what `register` itself returns.
    op = tvm.ir.Op.get(op_name)

2. 设置算子属性

# Look the operator up once, then declare that it takes one input and
# describe that input argument ("data", of type "Tensor").
my_op = tvm.ir.Op.get(op_name)
my_op.set_num_inputs(1)
my_op.add_argument("data", "Tensor", "输入数据")

3. 定义类型关系

def my_op_type_rel(types, num_inputs, attrs, reporter):
    """Validate the input type and tie the output type to it.

    Parameters
    ----------
    types : sequence
        ``types[0]`` is the input type and must provide ``as_tensor_type()``;
        ``types[1]`` is the output type and is only read when the input
        resolves.

    num_inputs : Any
        Not used by this relation.

    attrs : Any
        Not used by this relation.

    reporter : object
        Receives ``assoc(types[1], <resolved input type>)`` on success.

    Returns
    -------
    resolved : bool
        ``False`` when ``types[0].as_tensor_type()`` yields ``None`` (nothing
        is reported in that case), ``True`` once the association was reported.
    """
    tensor_type = types[0].as_tensor_type()
    resolved = tensor_type is not None
    if resolved:
        # The output type is declared to be the same as the resolved input.
        reporter.assoc(types[1], tensor_type)
    return resolved
# The Python `Op` class does not expose `add_type_rel` on every TVM build (the
# traceback recorded right after this statement shows the AttributeError that
# the unconditional call produced). Probe for the method so the remaining
# steps can still run, and warn loudly when the step has to be skipped.
_add_type_rel = getattr(tvm.ir.Op.get(op_name), "add_type_rel", None)
if _add_type_rel is not None:
    _add_type_rel("MyOpTypeRel", my_op_type_rel)
else:
    import warnings

    warnings.warn(
        "tvm.ir.Op has no 'add_type_rel' in this TVM build; "
        f"type relation 'MyOpTypeRel' was not registered for {op_name!r}"
    )
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[4], line 8
      6     reporter.assoc(types[1], data)
      7     return True
----> 8 tvm.ir.Op.get(op_name).add_type_rel("MyOpTypeRel", my_op_type_rel)

AttributeError: 'Op' object has no attribute 'add_type_rel'

4. 设置计算函数

def my_op_compute(attrs, inputs, out_type):
    """Build the operator's compute as a single extern call.

    Parameters
    ----------
    attrs : Any
        Not used by this compute function.

    inputs : sequence
        Only ``inputs[0]`` is read: its ``shape`` and ``dtype`` describe the
        output, and it is the sole input handed to the extern call.

    out_type : Any
        Not used by this compute function.

    Returns
    -------
    outputs : list
        One-element list holding the result of the extern call.
    """
    # Fetch the module holding the compute function (a TVMScript module
    # according to the original note).
    # NOTE(review): `my_op_module` is not defined in the visible part of this
    # file — confirm it is provided elsewhere before running this step.
    mod = my_op_module()
    source = inputs[0]
    # Create the extern call. `te` is never imported in the visible file, so
    # reach it through the already imported `tvm` package instead of relying
    # on a bare `te` name.
    out = tvm.te.extern(
        [source.shape],
        [source],
        lambda ins, outs: mod["my_op_func"](ins[0], outs[0]),
        name="my_op",
        dtype=source.dtype,
    )
    return [out]
# Attach `my_op_compute` to the operator under the "FTVMCompute" attribute.
compute_owner = tvm.ir.Op.get(op_name)
compute_owner.set_attr("FTVMCompute", my_op_compute)