中间表示中的类型节点
参考：`tvm/tests/python/ir/test_ir_type.py`
import sys
from pathlib import Path
ROOT = Path(".").resolve().parents[2]
sys.path.extend([f"{ROOT}/tests", f"{ROOT}/src"])
# # from tools.tag_span import _create_span, _set_span, _verify_structural_equal_with_span
from tools.torch_utils import verify_model
import tvm
def check_json_roundtrip(node):
    """Assert that *node* survives a save_json/load_json round trip.

    The node is serialized with ``tvm.ir.save_json`` and then rebuilt with
    ``tvm.ir.load_json``. The rebuilt copy must compare structurally equal to
    the original. ``map_free_vars=True`` is passed to the comparison.
    """
    restored = tvm.ir.load_json(tvm.ir.save_json(node))
    assert tvm.ir.structural_equal(restored, node, map_free_vars=True)
def test_prim_type():
    """PrimType is constructible and reports the dtype it was built with."""
    expected_dtype = "int32"
    prim = tvm.ir.PrimType(expected_dtype)
    assert isinstance(prim, tvm.ir.PrimType)
    assert prim.dtype == expected_dtype
def test_tensor_type_bad_constructor():
    """Constructing a TensorType from invalid shape/dtype strings must fail.

    Raises:
        AssertionError: if ``tvm.ir.TensorType("xx", "xx")`` does not raise
            ``tvm.error.TVMError``.
    """
    try:
        tvm.ir.TensorType("xx", "xx")
    except tvm.error.TVMError:
        pass  # expected: the bad arguments were rejected
    else:
        # Without this branch the test would also pass when no error is raised,
        # so it could never catch a regression.
        raise AssertionError(
            "tvm.ir.TensorType('xx', 'xx') was expected to raise tvm.error.TVMError"
        )
def test_tensor_type():
    """TensorType keeps its shape and dtype, has no span, prints, and round-trips."""
    shape = tvm.runtime.convert([1, 2, 3])
    dtype = "float32"
    tt = tvm.ir.TensorType(shape, dtype)
    assert tt.dtype == dtype
    assert tt.shape == shape
    # No span was passed to the constructor; compare to the None singleton by identity.
    assert tt.span is None
    str(tt)  # smoke test: the node must be printable
    check_json_roundtrip(tt)
def test_type_param():
    """TypeVar keeps its kind, prints, and survives a JSON round trip."""
    kind = tvm.ir.TypeKind.Type
    type_var = tvm.ir.TypeVar("name", kind)
    assert type_var.kind == kind
    # TODO: also check type_var.span once this test can supply a span.
    str(type_var)  # smoke test: the node must be printable
    check_json_roundtrip(type_var)
def test_func_type():
    """FuncType keeps each component, has no span, prints, and round-trips."""
    type_params = tvm.runtime.convert([])
    # TODO: exercise non-empty type constraints.
    type_constraints = tvm.runtime.convert([])
    arg_types = tvm.runtime.convert([])
    ret_type = tvm.ir.TensorType((1, 2, 3), "float32")
    tf = tvm.ir.FuncType(arg_types, ret_type, type_params, type_constraints)
    assert tf.type_params == type_params
    assert tf.type_constraints == type_constraints
    assert tf.arg_types == arg_types
    assert tf.ret_type == ret_type
    # No span was passed to the constructor; compare to the None singleton by identity.
    # TODO: also cover a FuncType built with an explicit span.
    assert tf.span is None
    str(tf)  # smoke test: the node must be printable
    check_json_roundtrip(tf)
def test_tuple_type():
    """TupleType keeps its field list, prints, and round-trips through JSON."""
    members = [
        tvm.ir.TypeVar("tp", tvm.ir.TypeKind.Type),
        tvm.ir.FuncType([], tvm.ir.TupleType([]), [], []),
        tvm.ir.TensorType(tvm.runtime.convert([1, 2, 3]), "float32"),
    ]
    fields = tvm.runtime.convert(members)
    tuple_ty = tvm.ir.TupleType(fields)
    assert tuple_ty.fields == fields
    str(tuple_ty)  # smoke test: the node must be printable
    check_json_roundtrip(tuple_ty)
def test_type_relation():
    """TypeRelation keeps its args and input count, prints, and round-trips."""
    relation_args = tvm.runtime.convert(
        [
            tvm.ir.TypeVar("tp", tvm.ir.TypeKind.Type),
            tvm.ir.FuncType([], None, [], []),
            tvm.ir.TensorType(tvm.runtime.convert([1, 2, 3]), "float32"),
        ]
    )
    input_count = 2
    broadcast_fn = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast")
    test_attrs = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3, 4))

    relation = tvm.ir.TypeRelation(broadcast_fn, relation_args, input_count, test_attrs)
    assert relation.args == relation_args
    assert relation.num_inputs == input_count
    str(relation)  # smoke test: the node must be printable
    check_json_roundtrip(relation)