Interpreting Type Inference#
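This section walks Relay's type inference through a series of small test programs: monomorphic lets, single operators, broadcasting, recursion, higher-order functions, references, algebraic data types, let polymorphism, and dynamic shapes. It begins with a few helper utilities built on the `InferType` pass and `InferTypeLocal`.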
import numpy as np
import pytest

import tvm
import tvm.testing
from tvm import IRModule, relay
from tvm.relay import op, transform
def infer_mod(mod, annotate_spans=True):
    if annotate_spans:
        mod = relay.transform.AnnotateSpans()(mod)
    mod = transform.InferType()(mod)
    return mod


def infer_expr(expr):
    transform.InferTypeLocal(expr)
    return expr


def assert_has_type(expr, typ, mod=None):
    if not mod:
        mod = tvm.IRModule({})
    mod["main"] = expr
    mod = infer_mod(mod)
    checked_expr = mod["main"]
    checked_type = checked_expr.checked_type
    if checked_type != typ:
        raise RuntimeError("Type mismatch %s vs %s" % (checked_type, typ))


def initialize_box_adt(mod):
    # Initialize a simple ADT for the tests: box[tv] with a single
    # constructor wrapping one value of type tv.
    box = relay.GlobalTypeVar("box")
    tv = relay.TypeVar("tv")
    constructor = relay.Constructor("constructor", [tv], box)
    data = relay.TypeData(box, [tv], [constructor])
    mod[box] = data
    return box, constructor
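As a quick sanity check (a minimal sketch, not part of the original tests), the helpers can be exercised on a bare constant:

# Hypothetical usage of the helpers above: InferTypeLocal fills in
# checked_type on the expression in place.
e = infer_expr(relay.const(1.0, "float32"))
assert e.checked_type == relay.scalar_type("float32")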
Test monomorphic let type inference#
monomorphic: the expression has a single concrete type, with no type parameters involved.
Program: let %x = 1; %x
sb = relay.ScopeBuilder()
x = relay.var("x", dtype="float64", shape=())
x = sb.let(x, relay.const(1.0, "float64"))
sb.ret(x)
out = sb.get()
print(out)
let %x: float64 = 1f64;
%x
xchecked = infer_expr(out)
assert xchecked.checked_type == relay.scalar_type("float64")
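`InferTypeLocal` checks the expression in place and populates its `checked_type` field, so the whole let-expression is inferred to the scalar type `float64` carried by `%x`.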
Test single-operator type inference#
x = relay.var("x", shape=[])
func = relay.Function([x], op.log(x))
ttype = relay.TensorType([], dtype="float32")
assert_has_type(func, relay.FuncType([ttype], ttype))
print(func)
fn (%x: float32) {
  log(%x)
}
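The source function carries no return annotation, but inference determines the full signature `fn (float32) -> float32`, which is exactly what `assert_has_type` verified above.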
Other type inference tests#
def test_add_broadcast_op():
    """
    Program:
        fn (%x: Tensor[(10, 4), float32], %y: Tensor[(5, 10, 1), float32])
            -> Tensor[(5, 10, 4), float32] {
            %x + %y
        }
    """
    x = relay.var("x", shape=(10, 4))
    y = relay.var("y", shape=(5, 10, 1))
    z = x + y
    func = relay.Function([x, y], z)
    t1 = relay.TensorType((10, 4), "float32")
    t2 = relay.TensorType((5, 10, 1), "float32")
    t3 = relay.TensorType((5, 10, 4), "float32")
    expected_ty = relay.FuncType([t1, t2], t3)
    assert_has_type(func, expected_ty)
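The expected result shape follows NumPy-style broadcasting, which Relay's broadcast type relation mirrors. A quick NumPy cross-check (a sketch, not part of the original tests):

# np.broadcast_shapes applies the same right-aligned broadcasting rules.
assert np.broadcast_shapes((10, 4), (5, 10, 1)) == (5, 10, 4)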
def test_dual_op():
    """Program:
    fn (%x : Tensor[(10, 10), float32]) {
      let %t1 = log(%x);
      let %t2 = add(%t1, %x);
      %t2
    }
    """
    tp = relay.TensorType((10, 10), "float32")
    x = relay.var("x", tp)
    sb = relay.ScopeBuilder()
    t1 = sb.let("t1", relay.log(x))
    t2 = sb.let("t2", relay.add(t1, x))
    sb.ret(t2)
    f = relay.Function([x], sb.get())
    fchecked = infer_expr(f)
    assert fchecked.checked_type == relay.FuncType([tp], tp)
def test_decl():
    """Program:
    def @f(%x : Tensor[(10, 10), float32]) {
        log(%x)
    }
    """
    tp = relay.TensorType((10, 10))
    x = relay.var("x", tp)
    f = relay.Function([x], relay.log(x))
    fchecked = infer_expr(f)
    assert fchecked.checked_type == relay.FuncType([tp], tp)
def test_recursion():
    """
    Program:
        def @f(%n: int32, %data: float32) -> float32 {
            if (%n == 0) {
                %data
            } else {
                @f(%n - 1, log(%data))
            }
        }
    """
    sb = relay.ScopeBuilder()
    f = relay.GlobalVar("f")
    ti32 = relay.scalar_type("int32")
    tf32 = relay.scalar_type("float32")
    n = relay.var("n", ti32)
    data = relay.var("data", tf32)
    with sb.if_scope(relay.equal(n, relay.const(0, ti32))):
        sb.ret(data)
    with sb.else_scope():
        sb.ret(f(relay.subtract(n, relay.const(1, ti32)), relay.log(data)))
    mod = tvm.IRModule()
    mod[f] = relay.Function([n, data], sb.get())
    mod = infer_mod(mod)
    assert "@f(%1, %2)" in mod.astext()
    assert mod["f"].checked_type == relay.FuncType([ti32, tf32], tf32)
def test_incomplete_call():
    tt = relay.scalar_type("int32")
    x = relay.var("x", tt)
    f_type = relay.FuncType([tt], tt)
    # f carries no annotation; its function type is inferred from the call site.
    f = relay.var("f")
    func = relay.Function([x, f], relay.Call(f, [x]), tt)
    ft = infer_expr(func)
    assert ft.checked_type == relay.FuncType([tt, f_type], tt)
def test_higher_order_argument():
a = relay.TypeVar("a")
x = relay.Var("x", a)
id_func = relay.Function([x], x, a, [a])
b = relay.TypeVar("b")
f = relay.Var("f", relay.FuncType([b], b))
y = relay.Var("y", b)
ho_func = relay.Function([f, y], f(y), b, [b])
# id func should be an acceptable argument to the higher-order
# function even though id_func takes a type parameter
ho_call = ho_func(id_func, relay.const(0, "int32"))
hc = infer_expr(ho_call)
expected = relay.scalar_type("int32")
assert hc.checked_type == expected
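At the call site both type parameters end up instantiated to `int32`: `b` from the constant argument, and `a` when the polymorphic `id_func` is passed where a monomorphic `fn (b) -> b` is expected.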
def test_higher_order_return():
    a = relay.TypeVar("a")
    x = relay.Var("x", a)
    id_func = relay.Function([x], x, a, [a])
    b = relay.TypeVar("b")
    nested_id = relay.Function([], id_func, relay.FuncType([b], b), [b])
    ft = infer_expr(nested_id)
    assert ft.checked_type == relay.FuncType([], relay.FuncType([b], b), [b])
def test_higher_order_nested():
    a = relay.TypeVar("a")
    x = relay.Var("x", a)
    id_func = relay.Function([x], x, a, [a])
    choice_t = relay.FuncType([], relay.scalar_type("bool"))
    f = relay.Var("f", choice_t)
    b = relay.TypeVar("b")
    z = relay.Var("z")
    top = relay.Function(
        [f], relay.If(f(), id_func, relay.Function([z], z)), relay.FuncType([b], b), [b]
    )
    expected = relay.FuncType([choice_t], relay.FuncType([b], b), [b])
    ft = infer_expr(top)
    assert ft.checked_type == expected
def test_tuple():
    tp = relay.TensorType((10,))
    x = relay.var("x", tp)
    res = relay.Tuple([x, x])
    assert infer_expr(res).checked_type == relay.TupleType([tp, tp])
def test_ref():
    x = relay.var("x", "float32")
    y = relay.var("y", "float32")
    r = relay.RefCreate(x)
    st = relay.scalar_type("float32")
    assert infer_expr(r).checked_type == relay.RefType(st)
    g = relay.RefRead(r)
    assert infer_expr(g).checked_type == st
    w = relay.RefWrite(r, y)
    assert infer_expr(w).checked_type == relay.TupleType([])
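The reference operations type as `RefCreate : T -> Ref[T]`, `RefRead : Ref[T] -> T`, and `RefWrite : (Ref[T], T) -> ()`, where the empty `TupleType` serves as Relay's unit type.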
def test_free_expr():
    x = relay.var("x", "float32")
    y = relay.add(x, x)
    yy = infer_expr(y)
    assert tvm.ir.structural_equal(yy.args[0], x, map_free_vars=True)
    assert yy.checked_type == relay.scalar_type("float32")
    assert x.vid.same_as(yy.args[0].vid)
def test_type_args():
    x = relay.var("x", shape=(10, 10))
    y = relay.var("y", shape=(1, 10))
    z = relay.add(x, y)
    # InferTypeLocal does not support populating the type_args field
    mod = infer_mod(IRModule.from_expr(z))
    mod = infer_mod(mod, annotate_spans=False)
    ty_args = mod["main"].body.type_args
    assert len(ty_args) == 2
    assert ty_args[0].dtype == "float32"
    assert ty_args[1].dtype == "float32"
    sh1 = ty_args[0].shape
    sh2 = ty_args[1].shape
    assert sh1[0].value == 10
    assert sh1[1].value == 10
    assert sh2[0].value == 1
    assert sh2[1].value == 10
def test_global_var_recursion():
    mod = tvm.IRModule({})
    gv = relay.GlobalVar("main")
    x = relay.var("x", shape=[])
    tt = relay.scalar_type("float32")
    func = relay.Function([x], relay.Call(gv, [x]), tt)
    mod[gv] = func
    mod = infer_mod(mod)
    func_ty = mod["main"].checked_type
    assert func_ty == relay.FuncType([tt], tt)
def test_equal():
    i = relay.var("i", shape=[], dtype="int32")
    eq = op.equal(i, relay.const(0, dtype="int32"))
    func = relay.Function([i], eq)
    ft = infer_expr(func)
    expected = relay.FuncType([relay.scalar_type("int32")], relay.scalar_type("bool"))
    assert ft.checked_type == expected
def test_constructor_type():
    mod = tvm.IRModule()
    box, constructor = initialize_box_adt(mod)
    a = relay.TypeVar("a")
    x = relay.Var("x", a)
    func = relay.Function([x], constructor(x), box(a), [a])
    mod["main"] = func
    mod = infer_mod(mod)
    func_ty = mod["main"].checked_type
    box = mod.get_global_type_var("box")
    expected = relay.FuncType([a], box(a), [a])
    assert func_ty == expected
def test_constructor_call():
    mod = tvm.IRModule()
    box, constructor = initialize_box_adt(mod)
    box_unit = constructor(relay.Tuple([]))
    box_constant = constructor(relay.const(0, "float32"))
    func = relay.Function([], relay.Tuple([box_unit, box_constant]))
    mod["main"] = func
    mod = infer_mod(mod)
    ret_type = mod["main"].checked_type.ret_type.fields
    # NB(@jroesch): once spans are annotated, AST fragments built before
    # annotation are no longer directly equal to the annotated ones, so
    # re-fetch box from the module before building the expected types.
    box = mod.get_global_type_var("box")
    expected1 = box(relay.TupleType([]))
    expected2 = box(relay.TensorType((), "float32"))
    assert ret_type[0] == expected1
    assert ret_type[1] == expected2
def test_adt_match():
    mod = tvm.IRModule()
    box, constructor = initialize_box_adt(mod)
    v = relay.Var("v", relay.TensorType((), "float32"))
    match = relay.Match(
        constructor(relay.const(0, "float32")),
        [
            relay.Clause(
                relay.PatternConstructor(constructor, [relay.PatternVar(v)]), relay.Tuple([])
            ),
            # redundant but shouldn't matter to typechecking
            relay.Clause(relay.PatternWildcard(), relay.Tuple([])),
        ],
    )
    func = relay.Function([], match)
    mod["main"] = func
    mod = infer_mod(mod)
    actual = mod["main"].checked_type.ret_type
    assert actual == relay.TupleType([])
def test_adt_match_type_annotations():
    mod = tvm.IRModule()
    box, constructor = initialize_box_adt(mod)
    # the only type annotation is inside the match pattern var,
    # but that should be enough info
    tt = relay.TensorType((2, 2), "float32")
    x = relay.Var("x")
    mv = relay.Var("mv", tt)
    match = relay.Match(
        constructor(x),
        [
            relay.Clause(
                relay.PatternConstructor(constructor, [relay.PatternVar(mv)]), relay.Tuple([])
            )
        ],
    )
    mod["main"] = relay.Function([x], match)
    mod = infer_mod(mod)
    ft = mod["main"].checked_type
    assert ft == relay.FuncType([tt], relay.TupleType([]))
def test_let_polymorphism():
    id = relay.Var("id")
    xt = relay.TypeVar("xt")
    x = relay.Var("x", xt)
    body = relay.Tuple([id(relay.const(1)), id(relay.Tuple([]))])
    body = relay.Let(id, relay.Function([x], x, xt, [xt]), body)
    body = infer_expr(body)
    int32 = relay.TensorType((), "int32")
    tvm.ir.assert_structural_equal(body.checked_type, relay.TupleType([int32, relay.TupleType([])]))
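Because `id` is let-bound to a function that is explicitly polymorphic in `xt`, its two uses in the body instantiate it independently, once at `int32` and once at the unit tuple type.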
def test_type_arg_infer():
    code = """
#[version = "0.0.5"]
def @id[A](%x: A) -> A {
  %x
}
def @main(%f: float32) -> float32 {
  @id(%f)
}
"""
    mod = tvm.relay.fromtext(code)
    mod = transform.InferType()(mod)
    tvm.ir.assert_structural_equal(mod["main"].body.type_args, [relay.TensorType((), "float32")])
def test_dynamic_function():
    dy_tt = relay.TensorType([relay.Any()], "float32")
    s_tt = relay.TensorType([10], "float32")
    x = relay.Var("x", dy_tt)
    f = relay.Function([x], x + x)
    y = relay.Var("y", s_tt)
    c = f(y)
    mod = tvm.IRModule()
    mod["main"] = relay.Function([y], c)
    mod = transform.InferType()(mod)
    assert mod["main"].params[0].checked_type == s_tt
data = relay.var(
    "data", shape=(relay.Any(), relay.Any(), relay.Any(), relay.Any()), dtype="float32"
)
weight = relay.const(np.full((16, 16, 3, 3), 0.25), dtype="float32")
x = relay.nn.conv2d(data, weight, kernel_size=(3, 3), channels=16, groups=2)
mod = tvm.IRModule.from_expr(x)
mod = transform.InferType()(mod)
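Here every input dimension is left as `Any`; the conv2d type relation keeps those dimensions dynamic and only constrains what follows from the constant weight and the operator attributes.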
if __name__ == "__main__":
    tvm.testing.main()
Test arg-reduce operators to infer the return type#
x_shape = (1, 1)
broadcast_shape = [1, 1]
shape_dtypes = [("int32", lambda x: np.int32(x)), ("int64", lambda x: np.int64(x))]

# Testing with argmax (argreduce)
for sdtype, conv in shape_dtypes:
    x = relay.var("data", relay.TensorType(x_shape, "float32"))
    broadcast_to = relay.op.broadcast_to(x, relay.const(broadcast_shape, dtype=sdtype))
    argmax = relay.op.argmax(broadcast_to, axis=[1])
    f = relay.Function([x], argmax)
    assert_has_type(
        f,
        relay.FuncType(
            [relay.TensorType(broadcast_shape, "float32")],
            relay.TensorType([conv(1)], dtype=sdtype),
        ),
    )

# Testing with argmin
for sdtype, conv in shape_dtypes:
    x = relay.var("data", relay.TensorType(x_shape, "float32"))
    broadcast_to = relay.op.broadcast_to(x, relay.const(broadcast_shape, dtype=sdtype))
    argmin = relay.op.argmin(broadcast_to, axis=[1])
    f = relay.Function([x], argmin)
    assert_has_type(
        f,
        relay.FuncType(
            [relay.TensorType(broadcast_shape, "float32")],
            relay.TensorType([conv(1)], dtype=sdtype),
        ),
    )
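In both loops the asserted return type is `Tensor[(1,), sdtype]`: because the shape fed to the dynamic `broadcast_to` is a constant of dtype `int32` or `int64`, the inferred arg-reduce result carries that same shape dtype.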