为算子添加类型关系#
import testing
import numpy as np
import tvm
from tvm import relay
from tvm.ir.attrs import DictAttrs
from tvm.relay.op import op as _op
def infer_mod(mod, annotate_spans=True):
    """Run relay type inference over *mod* and return the inferred module.

    When *annotate_spans* is true, AnnotateSpans runs first so that type
    errors point at source spans.
    """
    passes = [relay.transform.InferType()]
    if annotate_spans:
        passes.insert(0, relay.transform.AnnotateSpans())
    for transform in passes:
        mod = transform(mod)
    return mod
def infer_expr(expr):
    """Locally infer types for *expr*; the same expression object is returned."""
    _ = relay.transform.InferTypeLocal(expr)
    return expr
def assert_has_type(expr, typ, mod=None):
    """Type-check *expr* as ``main`` of *mod* and require its type to equal *typ*.

    Raises RuntimeError when the inferred type differs from *typ*.
    """
    if not mod:
        mod = tvm.IRModule({})
    mod["main"] = expr
    inferred = infer_mod(mod)
    actual = inferred["main"].checked_type
    if actual != typ:
        raise RuntimeError("Type mismatch %s vs %s" % (actual, typ))
tvm.ir.Op.add_type_rel??
Signature: tvm.ir.Op.add_type_rel(self, rel_name, type_rel_func=None)
Source:
def add_type_rel(self, rel_name, type_rel_func=None):
"""Attach the type function corresponding to the return type.
Parameters
----------
rel_name : str
The type relation name to register.
type_rel_func : Optional[function (args: List[Type], attrs: Attrs) -> Type]
The backing relation function which can solve an arbitrary relation on variables.
Differences with type_rel_func in C++:
1) When type_rel_func is not None
a) OpAddTypeRel on C++ side will adjust type_rel_func with TypeReporter to
calling convention of relay type system.
b) type_rel_func returns output argument's type, return None means can't
infer output's type.
c) only support single output operators for now, the last argument is output tensor.
2) when type_rel_func is None, will call predefined type_rel_funcs in relay
according to ``tvm.relay.type_relation.`` + rel_name.
"""
_ffi_api.OpAddTypeRel(self, rel_name, type_rel_func)
File: /media/pc/data/lxw/ai/tvm/python/tvm/ir/op.py
Type: function
tvm.ir.Op.add_type_rel()
`tvm.ir.Op.add_type_rel()` 函数的作用是将返回类型对应的类型函数附加到关系名称上。
参数:
- `rel_name`: str,要注册的类型关系名称。
- `type_rel_func`: 可选的函数,接受参数类型列表和属性作为输入,返回输出类型。该函数可以求解变量上的任意关系。与 C++ 中的 `type_rel_func` 的区别如下:
  1. 当 `type_rel_func` 不为 `None` 时:
     a. C++ 端的 `OpAddTypeRel` 将使用 `TypeReporter` 调整 `type_rel_func`,以适应 relay 类型系统的调用约定;
     b. `type_rel_func` 返回输出参数的类型,返回 `None` 表示无法推断输出的类型;
     c. 目前仅支持单个输出的算子,最后一个参数是输出张量。
  2. 当 `type_rel_func` 为 `None` 时,将根据 ``tvm.relay.type_relation.`` + `rel_name` 调用 relay 中预定义的 `type_rel_funcs`。
自定义算子类型推断#
op_name = "custom_log"
_op.register(op_name, r"code(cal log of a tensor.)code")
# Fetch the op handle once and configure it in place.
custom_log_op = _op.get(op_name)
custom_log_op.set_num_inputs(1)
custom_log_op.add_argument("data_0", "Tensor", "The input data tensor.")
# Use relay's predefined "Identity" type relation (output type == input type).
custom_log_op.add_type_rel("Identity")
custom_log_op.set_support_level(1)
_op.register_pattern(op_name, _op.OpPattern.ELEMWISE)
_op.register_stateful(op_name, False)
def clog(x):
    """Build a relay Call node invoking the op currently named by ``op_name``."""
    target_op = _op.get(op_name)
    return relay.Call(target_op, [x])
# Build `f(x) = add(custom_log(x), x)` over a 10x10 float32 tensor.
tp = relay.TensorType((10, 10), "float32")
x = relay.var("x", tp)
builder = relay.ScopeBuilder()
log_val = builder.let("t1", clog(x))
sum_val = builder.let("t2", relay.add(log_val, x))
builder.ret(sum_val)
f = relay.Function([x], builder.get())
print(f)
fn (%x: Tensor[(10, 10), float32]) {
let %t1 = custom_log(%x);
let %t2 = add(%t1, %x);
%t2
}
# Inference should resolve `f` to fn(Tensor[(10,10),f32]) -> Tensor[(10,10),f32].
fchecked = infer_expr(f)
expected = relay.FuncType([tp], tp)
assert fchecked.checked_type == expected
print(fchecked)
fn (%x: Tensor[(10, 10), float32]) {
let %t1 = custom_log(%x);
let %t2 = add(%t1, %x);
%t2
} /* ty=fn (Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] */
# Wrap `f` in a module, run InferType, and display the annotated IR.
base_mod = tvm.IRModule.from_expr(f)
mod = relay.transform.InferType()(base_mod)
mod.show()
def @main(%x: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */) -> Tensor[(10, 10), float32] {
let %t1: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */ = custom_log(%x) /* ty=Tensor[(10, 10), float32] */;
let %t2: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */ = add(%t1, %x) /* ty=Tensor[(10, 10), float32] */;
%t2
}
推断广播 custom_op 的类型#
op_name = "custom_broadcast_add"
_op.register(op_name, r"code(Add two tensor with inner broadcasting.)code")
# Configure the two-input op via a single cached handle.
bcast_op = _op.get(op_name)
bcast_op.set_num_inputs(2)
bcast_op.add_argument("data_0", "Tensor", "The input data tensor.")
bcast_op.add_argument("data_1", "Tensor", "The input data tensor.")
# Reuse relay's predefined "Broadcast" type relation.
bcast_op.add_type_rel("Broadcast")
bcast_op.set_support_level(1)
_op.register_stateful(op_name, False)
def broadcast_add(x, y):
    """Build a relay Call node applying the custom broadcasting add to *x*, *y*."""
    target_op = _op.get(op_name)
    return relay.Call(target_op, [x, y])
# (10, 4) broadcast against (5, 10, 1) should yield (5, 10, 4).
x = relay.var("x", shape=(10, 4))
y = relay.var("y", shape=(5, 10, 1))
func = relay.Function([x, y], broadcast_add(x, y))
print(func)
fn (%x: Tensor[(10, 4), float32], %y: Tensor[(5, 10, 1), float32]) {
custom_broadcast_add(%x, %y)
}
# Expected function type: (10,4) x (5,10,1) -> (5,10,4), all float32.
in0_ty = relay.TensorType((10, 4), "float32")
in1_ty = relay.TensorType((5, 10, 1), "float32")
out_ty = relay.TensorType((5, 10, 4), "float32")
assert_has_type(func, relay.FuncType([in0_ty, in1_ty], out_ty))
mod = relay.transform.InferType()(tvm.IRModule.from_expr(func))
mod.show()
def @main(%x: Tensor[(10, 4), float32] /* ty=Tensor[(10, 4), float32] */, %y: Tensor[(5, 10, 1), float32] /* ty=Tensor[(5, 10, 1), float32] */) -> Tensor[(5, 10, 4), float32] {
custom_broadcast_add(%x, %y) /* ty=Tensor[(5, 10, 4), float32] */
}
推断 custom_op 的类型关系#
def custom_log1_rel(arg_types, attrs):
    """Custom type relation: the output type mirrors the single input type.

    Receives the list of argument types and (optionally) the call attrs;
    returns the inferred output TensorType.
    """
    assert len(arg_types) == 1, "type relation arg number mismatch!"
    if attrs:
        # When attrs are present they arrive as a DictAttrs object.
        assert isinstance(attrs, DictAttrs)
    in_ty = arg_types[0]
    return relay.TensorType(in_ty.shape, in_ty.dtype)
op_name = "custom_log1"
_op.register(op_name, r"code(cal log of a tensor.)code")
log1_op = _op.get(op_name)
log1_op.set_num_inputs(1)
log1_op.add_argument("data_0", "Tensor", "The input data tensor.")
log1_op.set_attrs_type_key("DictAttrs")
# Attach the user-defined relation instead of a predefined one.
log1_op.add_type_rel("custom_log1", custom_log1_rel)
log1_op.set_support_level(1)
_op.register_pattern(op_name, _op.OpPattern.ELEMWISE)
_op.register_stateful(op_name, False)
def clog(x):
    """Build a relay Call node invoking the op currently named by ``op_name``."""
    target_op = _op.get(op_name)
    return relay.Call(target_op, [x])
# Rebuild `f(x) = add(custom_log1(x), x)` and verify the custom relation works.
tp = relay.TensorType((10, 10), "float32")
x = relay.var("x", tp)
builder = relay.ScopeBuilder()
log_val = builder.let("t1", clog(x))
sum_val = builder.let("t2", relay.add(log_val, x))
builder.ret(sum_val)
f = relay.Function([x], builder.get())
fchecked = infer_expr(f)
assert fchecked.checked_type == relay.FuncType([tp], tp)
mod = relay.transform.InferType()(tvm.IRModule.from_expr(f))
mod.show()
def @main(%x: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */) -> Tensor[(10, 10), float32] {
let %t1: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */ = custom_log1(%x) /* ty=Tensor[(10, 10), float32] */;
let %t2: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */ = add(%t1, %x) /* ty=Tensor[(10, 10), float32] */;
%t2
}
处理推断 custom_op 的类型关系的异常事件#
参数数量不匹配:
import pytest
def custom_log1_rel(arg_types, attrs):
    """Deliberately mismatched relation: demands two args for a 1-input op.

    The assertion is expected to fire during type inference, which is what
    the surrounding example demonstrates.
    """
    assert len(arg_types) == 2, "type relation arg number mismatch!"
    return None
op_name = "custom_log2"
_op.register(op_name, r"code(cal log of a tensor.)code")
log2_op = _op.get(op_name)
log2_op.set_num_inputs(1)
log2_op.add_argument("data_0", "Tensor", "The input data tensor.")
log2_op.set_attrs_type_key("DictAttrs")
# Attach the deliberately mismatched relation defined above.
log2_op.add_type_rel("custom_log2", custom_log1_rel)
log2_op.set_support_level(1)
_op.register_pattern(op_name, _op.OpPattern.ELEMWISE)
_op.register_stateful(op_name, False)
def clog(x):
    """Build a relay Call node invoking the op currently named by ``op_name``."""
    target_op = _op.get(op_name)
    return relay.Call(target_op, [x])
# Build the same shaped function again; inference on it must trip the
# mismatched relation's assertion.
tp = relay.TensorType((10, 10), "float32")
x = relay.var("x", tp)
builder = relay.ScopeBuilder()
log_val = builder.let("t1", clog(x))
sum_val = builder.let("t2", relay.add(log_val, x))
builder.ret(sum_val)
f = relay.Function([x], builder.get())
# The relation registered for custom_log2 asserts on the arg count, so
# type inference must raise AssertionError with the relation's message.
with pytest.raises(AssertionError) as cm:
    fchecked = infer_expr(f)
# Bug fix: the original read `cm.execption`, which is a typo —
# pytest's ExceptionInfo exposes the raised exception as `cm.value`,
# so the original line itself raised AttributeError and never checked
# the message.
assert "type relation arg number mismatch" in str(cm.value)
重复注册:
# Registering the same op name twice must fail with a TVMError.
op_name = "custom_log3"
_op.register(op_name, r"code(cal log of a tensor.)code")
with pytest.raises(tvm.error.TVMError) as cm:
    _op.register(op_name)
# Bug fix: `cm.execption` was a typo (AttributeError at runtime);
# pytest's ExceptionInfo stores the raised exception in `cm.value`.
assert "Operator custom_log3 is registered before" in str(cm.value)