为算子添加类型关系#

import testing
import numpy as np
import tvm
from tvm import relay
from tvm.ir.attrs import DictAttrs
from tvm.relay.op import op as _op
    
def infer_mod(mod, annotate_spans=True):
    """Run Relay type inference over an IRModule.

    Parameters
    ----------
    mod : tvm.IRModule
        The module to type-check.
    annotate_spans : bool
        When true, the module is first passed through
        ``relay.transform.AnnotateSpans`` before ``InferType`` runs.

    Returns
    -------
    tvm.IRModule
        The module produced by the ``InferType`` pass.
    """
    prepared = relay.transform.AnnotateSpans()(mod) if annotate_spans else mod
    return relay.transform.InferType()(prepared)


def infer_expr(expr):
    """Type-check a single Relay expression and hand the same object back.

    Whatever ``InferTypeLocal`` returns is discarded; callers read
    ``expr.checked_type`` on the returned object afterwards.
    """
    infer_local = relay.transform.InferTypeLocal
    infer_local(expr)
    return expr


def assert_has_type(expr, typ, mod=None):
    """Type-check ``expr`` as ``@main`` of a module and compare it with ``typ``.

    Parameters
    ----------
    expr : relay.Expr
        Expression installed as the module's ``main`` function (a
        ``relay.Function`` in this file's usage).
    typ : relay.Type
        Expected ``checked_type`` of ``main`` after inference.
    mod : Optional[tvm.IRModule]
        Module to install ``expr`` into. A fresh empty module is created when
        omitted. A module passed in is modified: its ``main`` is overwritten.

    Raises
    ------
    RuntimeError
        If the inferred type differs from ``typ``.
    """
    # Compare with None explicitly so a caller-supplied module is never
    # replaced by a fresh one because of how it evaluates in a boolean context.
    if mod is None:
        mod = tvm.IRModule({})

    mod["main"] = expr
    checked_type = infer_mod(mod)["main"].checked_type
    if checked_type != typ:
        raise RuntimeError("Type mismatch %s vs %s" % (checked_type, typ))
tvm.ir.Op.add_type_rel??
Signature: tvm.ir.Op.add_type_rel(self, rel_name, type_rel_func=None)
Source:   
    def add_type_rel(self, rel_name, type_rel_func=None):
        """Attach the type function corresponding to the return type.

        Parameters
        ----------
        rel_name : str
            The type relation name to register.

        type_rel_func : Optional[function (args: List[Type], attrs: Attrs) -> Type]
            The backing relation function which can solve an arbitrary relation on variables.
            Differences with type_rel_func in C++:

            1) When type_rel_func is not None

               a) OpAddTypeRel on C++ side will adjust type_rel_func with TypeReporter to
                  calling convention of relay type system.

               b) type_rel_func returns output argument's type, return None means can't
                  infer output's type.

               c) only support single output operators for now, the last argument is output tensor.

            2) when type_rel_func is None, will call predefined type_rel_funcs in relay
                   according to ``tvm.relay.type_relation.`` + rel_name.

        """
        _ffi_api.OpAddTypeRel(self, rel_name, type_rel_func)
File:      /media/pc/data/lxw/ai/tvm/python/tvm/ir/op.py
Type:      function

tvm.ir.Op.add_type_rel() 的作用是为算子附加用于推断其返回类型的类型关系函数,并以 rel_name 作为该类型关系的注册名称。

参数:

  • rel_name: str,要注册的类型关系名称。

  • type_rel_func: 可选的函数,签名为 (args: List[Type], attrs: Attrs) -> Type,即接受输入参数的类型列表和算子属性,返回输出的类型。它作为底层的关系函数,可以求解变量上的任意关系。与 C++ 中的 type_rel_func 的区别如下:

    1. type_rel_func 不为 None 时:

      • C++ 端的 OpAddTypeRel 将使用 TypeReporter 调整 type_rel_func 以适应 relay 类型系统的调用约定。

      • type_rel_func 返回输出参数的类型,返回 None 表示无法推断输出的类型。

      • 目前仅支持单个输出的算子,最后一个参数是输出张量。

    2. type_rel_func 为 None 时,将根据 tvm.relay.type_relation. + rel_name 调用 relay 中预定义的 type_rel_func。

自定义算子类型推断#

# Register "custom_log", a one-input operator. No Python relation function is
# passed to add_type_rel, so it uses the predefined relation named "Identity".
# The operator registry is global: registering the same name twice raises an
# error, so this cell can only run once per process.
op_name = "custom_log"
_op.register(op_name, r"code(cal log of a tensor.)code")
_op.get(op_name).set_num_inputs(1)
_op.get(op_name).add_argument("data_0", "Tensor", "The input data tensor.")
# call default relation functions (the predefined "Identity" relation)
_op.get(op_name).add_type_rel("Identity")
_op.get(op_name).set_support_level(1)
# Register the op pattern as element-wise and mark the op as not stateful.
_op.register_pattern(op_name, _op.OpPattern.ELEMWISE)
_op.register_stateful(op_name, False)
def clog(x, _custom_op=_op.get(op_name)):
    """Build a Relay call to the ``custom_log`` operator on ``x``.

    The operator is resolved once, when this function is defined (the default
    argument is evaluated at definition time). It is not looked up through the
    global ``op_name`` on every call, because later cells reassign ``op_name``
    to other operators. A lazily resolved name would make this helper depend on
    whatever ``op_name`` holds at call time.
    """
    return relay.Call(_custom_op, [x])
# Build fn(x) { let t1 = custom_log(x); let t2 = add(t1, x); t2 } over a
# 10x10 float32 input, using ScopeBuilder let-bindings.
tp = relay.TensorType((10, 10), "float32")
x = relay.var("x", tp)
sb = relay.ScopeBuilder()
t1 = sb.let("t1", clog(x))
t2 = sb.let("t2", relay.add(t1, x))
sb.ret(t2)
f = relay.Function([x], sb.get())
print(f)
fn (%x: Tensor[(10, 10), float32]) {
  let %t1 = custom_log(%x);
  let %t2 = add(%t1, %x);
  %t2
}
# infer_expr type-checks f and returns the same object. The expected function
# type maps Tensor[(10, 10), float32] to Tensor[(10, 10), float32].
fchecked = infer_expr(f)
assert fchecked.checked_type == relay.FuncType([tp], tp)
print(fchecked)
fn (%x: Tensor[(10, 10), float32]) {
  let %t1 = custom_log(%x);
  let %t2 = add(%t1, %x);
  %t2
} /* ty=fn (Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] */
# Same inference through the module-level InferType pass. show() prints the
# module with the inferred type annotated on each binding.
mod = relay.transform.InferType()(tvm.IRModule.from_expr(f))
mod.show()
def @main(%x: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */) -> Tensor[(10, 10), float32] {
  let %t1: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */ = custom_log(%x) /* ty=Tensor[(10, 10), float32] */;
  let %t2: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */ = add(%t1, %x) /* ty=Tensor[(10, 10), float32] */;
  %t2
}

推断广播 custom_op 的类型#

# Register "custom_broadcast_add", a two-input operator typed by the predefined
# "Broadcast" relation. Unlike custom_log, no op pattern is registered here.
op_name = "custom_broadcast_add"
_op.register(op_name, r"code(Add two tensor with inner broadcasting.)code")
_op.get(op_name).set_num_inputs(2)
_op.get(op_name).add_argument("data_0", "Tensor", "The input data tensor.")
_op.get(op_name).add_argument("data_1", "Tensor", "The input data tensor.")
# call default relation functions (the predefined "Broadcast" relation)
_op.get(op_name).add_type_rel("Broadcast")
_op.get(op_name).set_support_level(1)
_op.register_stateful(op_name, False)
def broadcast_add(x, y, _custom_op=_op.get(op_name)):
    """Build a Relay call to the ``custom_broadcast_add`` operator on ``x, y``.

    The operator is resolved once, when this function is defined. It is not
    resolved through the global ``op_name`` at call time, because later cells
    reassign ``op_name`` (for example to single-input operators). Calling this
    helper afterwards would otherwise build a call to the wrong operator.
    """
    return relay.Call(_custom_op, [x, y])

# x: (10, 4) and y: (5, 10, 1). No dtype is given to relay.var; the printed
# signature shows both as float32.
x = relay.var("x", shape=(10, 4))
y = relay.var("y", shape=(5, 10, 1))
z = broadcast_add(x, y)
func = relay.Function([x, y], z)
print(func)
fn (%x: Tensor[(10, 4), float32], %y: Tensor[(5, 10, 1), float32]) {
  custom_broadcast_add(%x, %y)
}
# Broadcasting (10, 4) against (5, 10, 1) should give (5, 10, 4).
# Note: t1/t2 are rebound here to TensorTypes, replacing the earlier
# let-bound expressions of the same names.
t1 = relay.TensorType((10, 4), "float32")
t2 = relay.TensorType((5, 10, 1), "float32")
t3 = relay.TensorType((5, 10, 4), "float32")
expected_ty = relay.FuncType([t1, t2], t3)
assert_has_type(func, expected_ty)
# Print the type-inferred module for the broadcast example.
mod = relay.transform.InferType()(tvm.IRModule.from_expr(func))
mod.show()
def @main(%x: Tensor[(10, 4), float32] /* ty=Tensor[(10, 4), float32] */, %y: Tensor[(5, 10, 1), float32] /* ty=Tensor[(5, 10, 1), float32] */) -> Tensor[(5, 10, 4), float32] {
  custom_broadcast_add(%x, %y) /* ty=Tensor[(5, 10, 4), float32] */
}

推断 custom_op 的类型关系#

def custom_log1_rel(arg_types, attrs):
    """Python type relation for a one-input element-wise operator.

    Parameters
    ----------
    arg_types : sequence of relay.TensorType
        Types of the operator's inputs; exactly one is expected.
    attrs : Attrs or None
        Operator attributes; when present (truthy) they must be ``DictAttrs``.

    Returns
    -------
    relay.TensorType
        A tensor type with the same shape and dtype as the single input.
    """
    assert len(arg_types) == 1, "type relation arg number mismatch!"
    # Only inspect attrs when they are present.
    assert not attrs or isinstance(attrs, DictAttrs)
    in_ty = arg_types[0]
    return relay.TensorType(in_ty.shape, in_ty.dtype)

# Register "custom_log1" with the Python relation custom_log1_rel instead of a
# predefined one, and set its attrs type key to "DictAttrs" (the relation
# asserts that any attrs it receives are DictAttrs).
op_name = "custom_log1"
_op.register(op_name, r"code(cal log of a tensor.)code")
_op.get(op_name).set_num_inputs(1)
_op.get(op_name).add_argument("data_0", "Tensor", "The input data tensor.")
_op.get(op_name).set_attrs_type_key("DictAttrs")
# call customized relation functions (custom_log1_rel computes the output type)
_op.get(op_name).add_type_rel("custom_log1", custom_log1_rel)
_op.get(op_name).set_support_level(1)
_op.register_pattern(op_name, _op.OpPattern.ELEMWISE)
_op.register_stateful(op_name, False)

def clog(x, _custom_op=_op.get(op_name)):
    """Build a Relay call to the ``custom_log1`` operator on ``x``.

    The operator is resolved once, when this function is defined (the default
    argument is evaluated at definition time). It is not looked up through the
    global ``op_name`` on every call, because later cells reassign ``op_name``
    to other operators.
    """
    return relay.Call(_custom_op, [x])

# Same let-bound program as for custom_log, now typed through the Python
# relation. It is checked locally (assert on checked_type), then printed after
# the module-level InferType pass.
tp = relay.TensorType((10, 10), "float32")
x = relay.var("x", tp)
sb = relay.ScopeBuilder()
t1 = sb.let("t1", clog(x))
t2 = sb.let("t2", relay.add(t1, x))
sb.ret(t2)
f = relay.Function([x], sb.get())
fchecked = infer_expr(f)
assert fchecked.checked_type == relay.FuncType([tp], tp)
mod = relay.transform.InferType()(tvm.IRModule.from_expr(f))
mod.show()
def @main(%x: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */) -> Tensor[(10, 10), float32] {
  let %t1: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */ = custom_log1(%x) /* ty=Tensor[(10, 10), float32] */;
  let %t2: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */ = add(%t1, %x) /* ty=Tensor[(10, 10), float32] */;
  %t2
}

处理 custom_op 类型关系推断中的异常情况#

参数数量不匹配:

import pytest
def custom_log1_rel(arg_types, attrs):
    """Deliberately wrong type relation used to provoke an inference failure.

    This definition reuses (shadows) the name of the earlier, correct relation.
    It insists on two input types, while the operator it is attached to
    (``custom_log2``) is registered with one input, so the assertion fails
    during type inference. ``attrs`` is ignored. If the check ever passes,
    ``None`` is returned implicitly, which per ``add_type_rel`` means "output
    type cannot be inferred".
    """
    expected_arity = 2
    assert len(arg_types) == expected_arity, "type relation arg number mismatch!"

# Register "custom_log2" (one input) with the deliberately wrong relation
# defined just before, which expects two input types, so type inference on a
# call to this op is expected to fail.
op_name = "custom_log2"
_op.register(op_name, r"code(cal log of a tensor.)code")
_op.get(op_name).set_num_inputs(1)
_op.get(op_name).add_argument("data_0", "Tensor", "The input data tensor.")
_op.get(op_name).set_attrs_type_key("DictAttrs")
# call customized relation functions
_op.get(op_name).add_type_rel("custom_log2", custom_log1_rel)
_op.get(op_name).set_support_level(1)
_op.register_pattern(op_name, _op.OpPattern.ELEMWISE)
_op.register_stateful(op_name, False)

def clog(x, _custom_op=_op.get(op_name)):
    """Build a Relay call to the ``custom_log2`` operator on ``x``.

    The operator is resolved once, when this function is defined (the default
    argument is evaluated at definition time), rather than through the global
    ``op_name`` at call time. The following cell reassigns ``op_name``.
    """
    return relay.Call(_custom_op, [x])

# Build the same let-bound program on top of custom_log2, whose relation
# rejects the single input type.
tp = relay.TensorType((10, 10), "float32")
x = relay.var("x", tp)
sb = relay.ScopeBuilder()
t1 = sb.let("t1", clog(x))
t2 = sb.let("t2", relay.add(t1, x))
sb.ret(t2)
f = relay.Function([x], sb.get())
# Type inference must fail with the relation's AssertionError. The message
# check has to come *after* the ``with`` block, because any statement following
# the raising call inside the block never executes. The captured exception is
# available as ``cm.value``.
with pytest.raises(AssertionError) as cm:
    infer_expr(f)
assert "type relation arg number mismatch" in str(cm.value)

重复注册:

# Registering an operator name that already exists must be rejected.
op_name = "custom_log3"
_op.register(op_name, r"code(cal log of a tensor.)code")
# Only the raising call belongs inside the ``with`` block; the message is
# checked afterwards on the captured exception (``cm.value``).
with pytest.raises(tvm.error.TVMError) as cm:
    _op.register(op_name)
assert "Operator custom_log3 is registered before" in str(cm.value)