# 为算子添加类型关系

In [1]:
import testing

In [2]:
import numpy as np
import tvm
from tvm import relay
from tvm.ir.attrs import DictAttrs
from tvm.relay.op import op as _op
    
def infer_mod(mod, annotate_spans=True):
    """Run Relay type inference over a module and return the result.

    Parameters
    ----------
    mod :
        The module to type-check (callers in this notebook pass a
        ``tvm.IRModule``).
    annotate_spans : bool
        When true, the ``AnnotateSpans`` pass is applied before ``InferType``.

    Returns
    -------
    The module produced by the last pass applied.
    """
    # Collect pass constructors in execution order; each pass is built and
    # applied in turn, feeding its output module into the next one.
    pipeline = []
    if annotate_spans:
        pipeline.append(relay.transform.AnnotateSpans)
    pipeline.append(relay.transform.InferType)

    for make_pass in pipeline:
        mod = make_pass()(mod)
    return mod


def infer_expr(expr):
    """Type-check a single Relay expression and hand the same object back.

    The value returned by ``InferTypeLocal`` is discarded and ``expr`` itself
    is returned; callers in this notebook then read ``checked_type`` from it.
    """
    relay.transform.InferTypeLocal(expr)
    return expr


def assert_has_type(expr, typ, mod=None):
    """Check that ``expr`` type-checks to ``typ`` when installed as ``main``.

    Parameters
    ----------
    expr :
        The Relay function stored under the ``"main"`` key of the module.
    typ :
        The type that ``main`` is expected to have after inference.
    mod :
        Optional module to install ``expr`` into; a fresh empty
        ``tvm.IRModule`` is created when this is falsy.

    Raises
    ------
    RuntimeError
        If the inferred type differs from ``typ``.
    """
    module = mod or tvm.IRModule({})
    module["main"] = expr

    inferred_main = infer_mod(module)["main"]
    actual = inferred_main.checked_type
    if actual != typ:
        raise RuntimeError("Type mismatch %s vs %s" % (actual, typ))

In [3]:
# IPython `??` introspection: show the signature and source of `add_type_rel`.
tvm.ir.Op.add_type_rel??

[0;31mSignature:[0m [0mtvm[0m[0;34m.[0m[0mir[0m[0;34m.[0m[0mOp[0m[0;34m.[0m[0madd_type_rel[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mrel_name[0m[0;34m,[0m [0mtype_rel_func[0m[0;34m=[0m[0;32mNone[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mSource:[0m   
    [0;32mdef[0m [0madd_type_rel[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mrel_name[0m[0;34m,[0m [0mtype_rel_func[0m[0;34m=[0m[0;32mNone[0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m        [0;34m"""Attach the type function corresponding to the return type.[0m
[0;34m[0m
[0;34m        Parameters[0m
[0;34m        ----------[0m
[0;34m        rel_name : str[0m
[0;34m            The type relation name to register.[0m
[0;34m[0m
[0;34m        type_rel_func : Optional[function (args: List[Type], attrs: Attrs) -> Type][0m
[0;34m            The backing relation function which can solve an arbitrary relation on variables.[0m
[0;34m            Differences with type_rel_func in C++:

{meth}`tvm.ir.Op.add_type_rel` 用于为算子附加与其返回类型对应的类型函数（即类型关系），并以 `rel_name` 作为该类型关系的注册名称。

参数：
- `rel_name`: `str`，要注册的类型关系名称。
- `type_rel_func`: 可选的函数，接受参数类型列表（`List[Type]`）和属性（`Attrs`）作为输入，返回输出类型。该函数可以解决变量上的任意关系。与 C++ 中的 `type_rel_func` 的区别如下：
    1. 当 `type_rel_func` 不为 `None` 时：
        - C++ 端的 `OpAddTypeRel` 将使用 `TypeReporter` 调整 `type_rel_func` 以适应 `relay` 类型系统的调用约定。
        - `type_rel_func` 返回输出参数的类型，返回 `None` 表示无法推断输出的类型。
        - 目前仅支持单个输出的算子，最后一个参数是输出张量。
    2. 当 `type_rel_func` 为 `None` 时，将按名称 `"tvm.relay.type_relation." + rel_name` 查找并调用 `relay` 中预定义的类型关系函数。

## 自定义算子类型推断

In [4]:
# Register a new operator named "custom_log" that takes a single input.
op_name = "custom_log"
_op.register(op_name, r"code(cal log of a tensor.)code")
_op.get(op_name).set_num_inputs(1)
_op.get(op_name).add_argument("data_0", "Tensor", "The input data tensor.")
# No `type_rel_func` is passed, so the predefined "Identity" relation is used.
_op.get(op_name).add_type_rel("Identity")
_op.get(op_name).set_support_level(1)
_op.register_pattern(op_name, _op.OpPattern.ELEMWISE)
_op.register_stateful(op_name, False)

In [5]:
def clog(x):
    """Build a Relay call that applies the custom op to ``x``.

    The operator is looked up through the module-level ``op_name`` each time
    this function is called, so it follows whatever ``op_name`` holds then.
    """
    custom_op = _op.get(op_name)
    return relay.Call(custom_op, [x])

In [6]:
# Build: fn(x) { let t1 = custom_log(x); let t2 = add(t1, x); t2 }
# where x is a (10, 10) float32 tensor.
tp = relay.TensorType((10, 10), "float32")
x = relay.var("x", tp)
sb = relay.ScopeBuilder()
t1 = sb.let("t1", clog(x))
t2 = sb.let("t2", relay.add(t1, x))
sb.ret(t2)
f = relay.Function([x], sb.get())

In [7]:
# Print the function before type inference has been run on it.
print(f)

fn (%x: Tensor[(10, 10), float32]) {
  let %t1 = custom_log(%x);
  let %t2 = add(%t1, %x);
  %t2
}


In [8]:
# Infer types in place, then pin the function type to fn(tp) -> tp,
# i.e. the result has the same tensor type as the input.
fchecked = infer_expr(f)
assert fchecked.checked_type == relay.FuncType([tp], tp)

In [9]:
# After inference the printed function carries a trailing `ty=` annotation.
print(fchecked)

fn (%x: Tensor[(10, 10), float32]) {
  let %t1 = custom_log(%x);
  let %t2 = add(%t1, %x);
  %t2
} /* ty=fn (Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] */


In [10]:
# Module-level variant: wrap `f` in an IRModule, run InferType, and display it.
mod = relay.transform.InferType()(tvm.IRModule.from_expr(f))
mod.show()

## 推断广播 custom_op 的类型

In [11]:
# Register a two-input operator named "custom_broadcast_add".
op_name = "custom_broadcast_add"
_op.register(op_name, r"code(Add two tensor with inner broadcasting.)code")
_op.get(op_name).set_num_inputs(2)
_op.get(op_name).add_argument("data_0", "Tensor", "The input data tensor.")
_op.get(op_name).add_argument("data_1", "Tensor", "The input data tensor.")
# No `type_rel_func` is passed, so the predefined "Broadcast" relation is used.
_op.get(op_name).add_type_rel("Broadcast")
_op.get(op_name).set_support_level(1)
_op.register_stateful(op_name, False)

In [12]:
def broadcast_add(x, y):
    """Build a Relay call that applies the custom op to ``x`` and ``y``.

    The operator is resolved from the module-level ``op_name`` at call time.
    """
    target_op = _op.get(op_name)
    return relay.Call(target_op, [x, y])

# Two operands of different rank: (10, 4) and (5, 10, 1). No dtype is passed,
# so relay.var's default applies (printed as float32 in the next cell).
x = relay.var("x", shape=(10, 4))
y = relay.var("y", shape=(5, 10, 1))
z = broadcast_add(x, y)
func = relay.Function([x, y], z)

In [13]:
# Print the function before type inference has been run on it.
print(func)

fn (%x: Tensor[(10, 4), float32], %y: Tensor[(5, 10, 1), float32]) {
  custom_broadcast_add(%x, %y)
}


In [14]:
# Expected signature: (10, 4) combined with (5, 10, 1) yields (5, 10, 4).
# Note: `t1`/`t2` are rebound here to tensor types; in earlier cells they
# named let-bound Relay variables.
t1 = relay.TensorType((10, 4), "float32")
t2 = relay.TensorType((5, 10, 1), "float32")
t3 = relay.TensorType((5, 10, 4), "float32")
expected_ty = relay.FuncType([t1, t2], t3)
assert_has_type(func, expected_ty)

In [15]:
# Module-level variant: wrap `func` in an IRModule, run InferType, and display it.
mod = relay.transform.InferType()(tvm.IRModule.from_expr(func))
mod.show()

## 推断 custom_op 的类型关系

In [16]:
def custom_log1_rel(arg_types, attrs):
    """Type relation for ``custom_log1``: the output type mirrors the input.

    Parameters
    ----------
    arg_types :
        Sequence of argument types; exactly one entry is required.
    attrs :
        Call attributes; when truthy they must be a ``DictAttrs``.

    Returns
    -------
    A new ``relay.TensorType`` with the shape and dtype of the single input.
    """
    assert len(arg_types) == 1, "type relation arg number mismatch!"
    assert not attrs or isinstance(attrs, DictAttrs)
    source_type = arg_types[0]
    return relay.TensorType(shape=source_type.shape, dtype=source_type.dtype)

# Register a single-input operator named "custom_log1" whose attributes use
# the "DictAttrs" type key.
op_name = "custom_log1"
_op.register(op_name, r"code(cal log of a tensor.)code")
_op.get(op_name).set_num_inputs(1)
_op.get(op_name).add_argument("data_0", "Tensor", "The input data tensor.")
_op.get(op_name).set_attrs_type_key("DictAttrs")
# A Python function is supplied here, so it is used as the type relation
# instead of a predefined one.
_op.get(op_name).add_type_rel("custom_log1", custom_log1_rel)
_op.get(op_name).set_support_level(1)
_op.register_pattern(op_name, _op.OpPattern.ELEMWISE)
_op.register_stateful(op_name, False)

def clog(x):
    """Build a Relay call that applies the custom op to ``x``.

    The operator is looked up through the module-level ``op_name`` each time
    this function is called.
    """
    custom_op = _op.get(op_name)
    return relay.Call(custom_op, [x])

# Build: fn(x) { let t1 = custom_log1(x); let t2 = add(t1, x); t2 }
# then verify that the Python type relation yields fn(tp) -> tp.
tp = relay.TensorType((10, 10), "float32")
x = relay.var("x", tp)
sb = relay.ScopeBuilder()
t1 = sb.let("t1", clog(x))
t2 = sb.let("t2", relay.add(t1, x))
sb.ret(t2)
f = relay.Function([x], sb.get())
fchecked = infer_expr(f)
assert fchecked.checked_type == relay.FuncType([tp], tp)

In [17]:
# Module-level variant: wrap `f` in an IRModule, run InferType, and display it.
mod = relay.transform.InferType()(tvm.IRModule.from_expr(f))
mod.show()

## 处理推断 custom_op 的类型关系的异常事件

参数数量不匹配：

In [18]:
import pytest
def custom_log1_rel(arg_types, attrs):
    """Type relation that demands two argument types and infers nothing.

    It is attached below to an operator declared with a single input, which
    is how this cell provokes the "arg number mismatch" assertion. Note that
    this definition rebinds the name used by the earlier ``custom_log1``
    relation.

    Returns
    -------
    None
    """
    arg_count = len(arg_types)
    assert arg_count == 2, "type relation arg number mismatch!"
    return None

# Register a single-input operator named "custom_log2" whose attributes use
# the "DictAttrs" type key.
op_name = "custom_log2"
_op.register(op_name, r"code(cal log of a tensor.)code")
_op.get(op_name).set_num_inputs(1)
_op.get(op_name).add_argument("data_0", "Tensor", "The input data tensor.")
_op.get(op_name).set_attrs_type_key("DictAttrs")
# Attach the relation defined above, which asserts on two argument types
# although this op declares only one input.
_op.get(op_name).add_type_rel("custom_log2", custom_log1_rel)
_op.get(op_name).set_support_level(1)
_op.register_pattern(op_name, _op.OpPattern.ELEMWISE)
_op.register_stateful(op_name, False)

def clog(x):
    """Build a Relay call that applies the custom op to ``x``.

    The operator is looked up through the module-level ``op_name`` each time
    this function is called.
    """
    custom_op = _op.get(op_name)
    return relay.Call(custom_op, [x])

# Build: fn(x) { let t1 = custom_log2(x); let t2 = add(t1, x); t2 }
# Type inference is deliberately not run here; it is expected to fail.
tp = relay.TensorType((10, 10), "float32")
x = relay.var("x", tp)
sb = relay.ScopeBuilder()
t1 = sb.let("t1", clog(x))
t2 = sb.let("t2", relay.add(t1, x))
sb.ret(t2)
f = relay.Function([x], sb.get())
# Type inference must fail with the relation's AssertionError.
with pytest.raises(AssertionError) as cm:
    fchecked = infer_expr(f)
# The message check has to live outside the `with` body: anything placed after
# the raising call inside the block is never executed. pytest exposes the
# captured exception as `cm.value`.
# NOTE(review): confirm the propagated TVM error text still contains this message.
assert "type relation arg number mismatch" in str(cm.value)

重复注册：

In [20]:
# Register "custom_log3" once; a second registration of the same name is
# expected to raise a TVMError.
op_name = "custom_log3"
_op.register(op_name, r"code(cal log of a tensor.)code")
with pytest.raises(tvm.error.TVMError) as cm:
    _op.register(op_name)
# The message check has to live outside the `with` body: anything placed after
# the raising call inside the block is never executed. pytest exposes the
# captured exception as `cm.value`.
# NOTE(review): confirm this text matches the error emitted by the TVM build in use.
assert "Operator custom_log3 is registered before" in str(cm.value)