Interpreting _test_type_solver
References: tvm/tests/python/relay/test_type_solver.py and tvm/src/relay/analysis/type_solver.h
```python
import numpy as np
import tvm
from tvm import relay
from tvm.relay import testing


def make_rel(name, args, num_inputs=None, attrs=None):
    # Look up the relation function registered in C++ under the
    # "tvm.relay.type_relation." prefix. By convention, the last element
    # of `args` is the output type, so num_inputs defaults to len(args) - 1.
    func = tvm.ir.EnvFunc.get("tvm.relay.type_relation." + name)
    if num_inputs is None:
        num_inputs = len(args) - 1
    return relay.ty.TypeRelation(func, args, num_inputs, attrs)
```
```python
relay.ty.TypeRelation??
```

```
Init signature: relay.ty.TypeRelation(func, args, num_inputs, attrs)
Source:
@tvm._ffi.register_object("TypeRelation")
class TypeRelation(TypeConstraint):
    """User defined type relation, it is an input-output relation on types.

    TypeRelation is more generalized than TypeCall as it allows inference
    of both inputs and outputs.

    Parameters
    ----------
    func : EnvFunc
        User defined relation function.

    args : [tvm.ir.Type]
        List of types to the func.

    num_inputs : int
        Number of input arguments in args,
        this act as a hint for type inference.

    attrs : Attrs
        The attribute attached to the relation information

    Returns
    -------
    type_relation : tvm.ir.TypeRelation
        The type relation.
    """

    def __init__(self, func, args, num_inputs, attrs):
        self.__init_handle_by_constructor__(
            _ffi_api.TypeRelation, func, args, num_inputs, attrs
        )
File:           /media/pc/data/lxw/ai/tvm/python/tvm/ir/type_relation.py
Type:           type
```
TypeRelation is a class that inherits from TypeConstraint. It represents a user-defined type relation, i.e. an input-output relation over types. TypeRelation is more general than TypeCall because it allows both inputs and outputs to be inferred.

Parameters:

- func: EnvFunc, the user-defined relation function.
- args: [tvm.ir.Type], the list of types passed to func.
- num_inputs: int, the number of input arguments in args; this serves as a hint for type inference.
- attrs: Attrs, the attributes attached to the relation.

Returns:

- type_relation: tvm.ir.TypeRelation, the type relation.
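As a quick illustration, here is a minimal sketch using the make_rel helper defined above (the concrete shapes are arbitrary); it shows how the last element of args is treated as the output slot:

```python
# Illustration: a Broadcast relation over two concrete input types and
# one unknown output type (the last element of args).
t0 = relay.ty.TensorType((10, 20), "float32")
t1 = relay.ty.TensorType((10, 1), "float32")
out = relay.ty.IncompleteType()

# num_inputs defaults to len(args) - 1 == 2, marking `out` as the output.
rel = make_rel("Broadcast", [t0, t1, out])
assert isinstance(rel, tvm.ir.TypeRelation)
```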
```python
def make_solver():
    # _test_type_solver returns a PackedFunc that, when called with a
    # method name, returns the corresponding solver method.
    solver = relay.analysis._ffi_api._test_type_solver()
    solver.Solve = solver("Solve")
    solver.Unify = solver("Unify")
    solver.Resolve = solver("Resolve")
    solver.AddConstraint = solver("AddConstraint")

    def gen_type(name, args, out=None):
        # Append a fresh IncompleteType as the output slot unless one is given.
        out = out if out else relay.ty.IncompleteType()
        solver.AddConstraint(make_rel(name, args + [out]))
        return out

    solver.gen_type = gen_type
    return solver
```
```c++
/*!
 * \brief Interface of the type solver used in type inference.
 *
 * TypeSolver works on a list of constraints among incomplete types.
 * The user populates the constraints via AddConstraint and Assign.
 * Then Solve can be called to try to resolve the unknowns.
 *
 * This can be viewed as a "type program (computational graph)" over types,
 * where the type constraints are the operators of the graph and the
 * incomplete types are its intermediate values.
 * If all the input types are concretely known, we should be able to
 * just run a forward pass on the "type program" to get all the types.
 *
 * The list-of-constraints representation means we store a bipartite
 * graph instead of a DAG, because some constraints can go in both directions.
 * TypeSolver can take advantage of bidirectional constraints to deduce input
 * types given output ones. Nevertheless, we should keep in mind that
 * there is a "forward direction" that the TypeSolver should take advantage of.
 */
class TypeSolver {
 public:
  TypeSolver(const GlobalVar& current_func, DiagnosticContext diag_ctx);
  ~TypeSolver();
  /*!
   * \brief Add a type constraint to the solver.
   * \param constraint The constraint to be added.
   * \param span The location at which the constraint was incurred.
   */
  void AddConstraint(const TypeConstraint& constraint, const Span& span);
  /*!
   * \brief Resolve a type to its solution type in the solver.
   * \param type The type to be resolved.
   * \return The resolved type.
   */
  Type Resolve(const Type& type);
  /*!
   * \brief Start solving the types using the currently known information.
   * \return Whether all incomplete types have been fully resolved.
   */
  bool Solve();
  /*!
   * \brief Unify lhs and rhs.
   * \param lhs The left operand.
   * \param rhs The right operand.
   * \param span The location at which the unification problem arose.
   */
  Type Unify(const Type& lhs, const Type& rhs, const Span& span, bool assign_lhs = true,
             bool assign_rhs = true);
  /*!
   * \brief Report a diagnostic.
   * \param diag The diagnostic to report.
   */
  void Emit(const Diagnostic& diag) { diag_ctx_.Emit(diag); }
  // ... (remaining members omitted)
```
This is the interface definition of the type solver (TypeSolver). The solver resolves constraints among types in order to determine the concrete values of unknown types. The interface exposes the following methods:

- AddConstraint: add a type constraint to the solver; the arguments are the constraint to add and the location where it was incurred.
- Resolve: resolve a given type to its solution type; takes the type to resolve and returns the resolved type.
- Solve: start solving the types using the currently known information; returns whether all incomplete types have been fully resolved.
- Unify: unify lhs and rhs; the arguments are the left operand, the right operand, and the location where the unification problem arose; whether to assign the result to lhs and rhs can also be specified.
- Emit: report a diagnostic; the argument is the diagnostic to report.
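A minimal sketch of these bindings as exposed through the test solver (assuming the make_solver helper defined earlier; the shapes are arbitrary):

```python
solver = make_solver()

# Unifying an unknown with a concrete type should resolve the unknown,
# consistent with the unification tests shown below.
t_unknown = relay.ty.IncompleteType()
t_concrete = relay.ty.TensorType((4, 4), "float32")
unified = solver.Unify(t_unknown, t_concrete)
assert unified == t_concrete
assert solver.Resolve(t_unknown) == t_concrete
```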
Broadcast type inference (solver)
```python
solver = make_solver()
t0 = relay.ty.TensorType((10, 20), "float32")
t1 = relay.ty.TensorType((10, 1), "float32")
tc = relay.ty.TensorType((10, 1, 1), "float32")
t2 = solver.gen_type("Broadcast", [t0, t1])
t3 = solver.gen_type("Identity", [t2])
t4 = solver.gen_type("Broadcast", [t3, tc])
assert solver.Solve()
assert solver.Resolve(t2) == relay.ty.TensorType((10, 20), "float32")
assert solver.Resolve(t4) == relay.ty.TensorType((10, 10, 20), "float32")
```
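The expected shapes follow standard NumPy broadcasting rules; as a sanity check (a hedged aside, not part of the original test; np.broadcast_shapes requires NumPy 1.20+):

```python
# (10, 20) broadcast with (10, 1)    -> (10, 20)
assert np.broadcast_shapes((10, 20), (10, 1)) == (10, 20)
# (10, 20) broadcast with (10, 1, 1) -> (10, 10, 20)
assert np.broadcast_shapes((10, 20), (10, 1, 1)) == (10, 10, 20)
```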
Other type-inference cases (solver)
```python
def test_backward_solving():
    solver = make_solver()
    t0 = relay.ty.TensorType((10, 20), "float32")
    tc = relay.ty.TensorType((10, 1, 1), "float32")
    t1 = relay.ty.IncompleteType()
    t3 = solver.gen_type("Broadcast", [t0, t1])
    t2 = solver.gen_type("Identity", [t1], out=tc)
    assert solver.Solve()
    assert solver.Resolve(t3) == relay.ty.TensorType((10, 10, 20), "float32")


def test_unify_tuple():
    solver = make_solver()
    t1 = relay.ty.IncompleteType()
    t2 = relay.ty.IncompleteType()
    t3 = relay.ty.TensorType((10, 20), "float32")
    tup1 = relay.ty.TupleType([t1, t2])
    tup2 = relay.ty.TupleType([t3, t3])
    unified = solver.Unify(tup1, tup2)
    assert unified == tup2


def test_unify_global_type_var():
    # should only be able to unify if they're the same
    solver = make_solver()
    gtv = relay.GlobalTypeVar("gtv")
    unified = solver.Unify(gtv, gtv)
    assert unified == gtv


def test_unify_typecall():
    solver = make_solver()
    gtv = relay.GlobalTypeVar("gtv")
    # typecalls are shaped like tuples, so the same tests work out
    t1 = relay.ty.IncompleteType()
    t2 = relay.ty.IncompleteType()
    t3 = relay.ty.TensorType((10, 20), "float32")
    tc1 = relay.ty.TypeCall(gtv, [t1, t2])
    tc2 = relay.ty.TypeCall(gtv, [t3, t3])
    unified = solver.Unify(tc1, tc2)
    assert unified == tc2


def test_unify_functype():
    solver = make_solver()
    t1 = relay.ty.IncompleteType()
    t2 = relay.ty.IncompleteType()
    t3 = relay.ty.IncompleteType()
    unit = relay.ty.TupleType([])
    tensor1 = relay.ty.TensorType((10, 20), "float32")
    tensor2 = relay.ty.TensorType((10,), "float32")
    ft1 = relay.ty.FuncType([t1, t2], t3)
    ft2 = relay.ty.FuncType([tensor1, tensor2], unit)
    unified = solver.Unify(ft1, ft2)
    assert unified == ft2
```
```python
def test_recursive_unify():
    solver = make_solver()
    t1 = relay.ty.IncompleteType()
    t2 = relay.ty.IncompleteType()
    t3 = relay.ty.IncompleteType()
    tensor1 = relay.ty.TensorType((10, 10, 20), "float32")
    tensor2 = relay.ty.TensorType((10, 20), "float32")
    tensor3 = relay.ty.TensorType((10,), "float32")
    tup1 = relay.ty.TupleType([relay.ty.TupleType([t1, t2]), t2])
    tup2 = relay.ty.TupleType([relay.ty.TupleType([tensor1, tensor2]), tensor2])
    ft1 = relay.ty.FuncType([tup1, t3], t3)
    ft2 = relay.ty.FuncType([tup2, tensor3], tensor3)
    unified = solver.Unify(ft1, ft2)
    assert unified == ft2


def test_unify_vars_under_tuples():
    solver = make_solver()
    t1 = relay.ty.IncompleteType()
    tup1 = relay.ty.TupleType([t1, t1])
    unified = solver.Unify(tup1, tup1)
    assert unified == tup1

    t2 = relay.ty.IncompleteType()
    tup2 = relay.ty.TupleType([t2, t2])
    tup3 = relay.ty.TupleType([t1, t2])
    tup4 = relay.ty.TupleType([t2, t1])
    unified = solver.Unify(tup3, tup4)
    assert unified == tup1 or unified == tup2


def test_binding_over_typevars():
    solver = make_solver()
    t1 = relay.ty.IncompleteType()
    t2 = relay.ty.IncompleteType()
    a = relay.ty.TypeVar("a")
    b = relay.ty.TypeVar("b")
    c = relay.ty.TypeVar("c")
    d = relay.ty.TypeVar("d")
    ft1 = relay.ty.FuncType([t1], t2, [c, d])
    ft2 = relay.ty.FuncType([a], b, [a, b])
    unified = solver.Unify(ft1, ft2)
    assert unified == solver.Resolve(ft1)


def test_recursive_backward_solving():
    solver = make_solver()
    tensor1 = relay.ty.TensorType((10, 20), "float32")
    tensor2 = relay.ty.TensorType((10, 1, 1), "float32")
    tensor3 = relay.ty.TensorType((10,), "float32")
    t1 = relay.ty.IncompleteType()
    t2 = relay.ty.IncompleteType()
    t3 = relay.ty.IncompleteType()
    tup1 = relay.ty.TupleType([relay.ty.TupleType([tensor1, tensor2]), tensor3])
    tup2 = relay.ty.TupleType([relay.ty.TupleType([t1, t2]), t3])
    solver.gen_type("Identity", [tup1], out=tup2)
    assert solver.Solve()
    assert solver.Resolve(tup2) == tup1
```
```python
def test_backward_solving_after_child_update():
    solver = make_solver()
    tensor1 = relay.ty.TensorType((10, 20), "float32")
    tensor2 = relay.ty.TensorType((10, 1, 1), "float32")
    t1 = relay.ty.IncompleteType()
    t2 = relay.ty.IncompleteType()
    t3 = relay.ty.IncompleteType()
    tup1 = relay.ty.TupleType([t1, t2])
    tup2 = relay.ty.TupleType([t1, t3])
    tup_concrete = relay.ty.TupleType([tensor1, tensor2])

    t4 = solver.gen_type("Identity", [tup1])
    t5 = solver.gen_type("Identity", [tup2])
    solver.gen_type("Identity", [t4], out=t5)
    assert solver.Solve()
    assert solver.Resolve(t3) == t3 or solver.Resolve(t3) == t2
    assert solver.Resolve(t4) == tup1 or solver.Resolve(t4) == tup2
    assert solver.Resolve(t5) == tup1 or solver.Resolve(t5) == tup2

    # updating the variables *inside* tup1 and tup2 should update t4 and t5
    solver.gen_type("Identity", [t1], out=tensor1)
    solver.gen_type("Identity", [t2], out=tensor2)
    assert solver.Solve()
    assert solver.Resolve(t4) == tup_concrete
    assert solver.Resolve(t5) == tup_concrete
```
```python
def test_unify_quantified_funcs():
    solver = make_solver()
    a, b, c = relay.TypeVar("a"), relay.TypeVar("b"), relay.TypeVar("c")
    ft1 = relay.FuncType([a, b], c, [a, b, c])
    ft2 = relay.FuncType([a, a], a, [a])
    unified = solver.Unify(ft1, ft2)
    assert unified == ft2

    ft3 = relay.FuncType([a], a, [a])
    ft4 = relay.FuncType([b], c, [b, c])
    unified = solver.Unify(ft3, ft4)
    assert unified == ft3


def test_unify_quantified_func_and_concrete():
    solver = make_solver()
    a, b = relay.TypeVar("a"), relay.TypeVar("b")
    ft1 = relay.FuncType([a], b, [a, b])
    ft2 = relay.FuncType([b], relay.TupleType([]), [b])
    unified = solver.Unify(ft1, ft2)
    assert unified == ft2


def test_unify_quantified_funcs_nesting():
    solver = make_solver()
    a, b, c = relay.TypeVar("a"), relay.TypeVar("b"), relay.TypeVar("c")
    ft1 = relay.FuncType([a, relay.TupleType([b, c])], relay.TupleType([a, b, c]), [a, b, c])
    ft2 = relay.FuncType([a, relay.TupleType([a, a])], relay.TupleType([a, a, a]), [a])
    unified = solver.Unify(ft1, ft2)
    assert unified == ft2


def test_unify_quantified_funcs_var_order():
    solver = make_solver()
    a, b, c = relay.TypeVar("a"), relay.TypeVar("b"), relay.TypeVar("c")
    ft1 = relay.FuncType([a, relay.TupleType([b, c])], relay.TupleType([a, b, c]), [a, b, c])
    ft2 = relay.FuncType([a, relay.TupleType([a, c])], relay.TupleType([a, a, c]), [a, c])
    # unified = solver.Unify(ft1, ft2)  # crashes here but it shouldn't
    # assert unified == ft2
```
Incompatible type inference
```python
import pytest


@pytest.mark.xfail(raises=tvm._ffi.base.TVMError)
def test_incompatible_tuple_unification():
    solver = make_solver()
    t1 = relay.ty.IncompleteType()
    t2 = relay.ty.IncompleteType()
    tensor1 = relay.ty.TensorType((1, 2, 3), "float32")
    tensor2 = relay.ty.TensorType((2, 3), "float32")
    tensor3 = relay.ty.TensorType((3,), "float32")
    tup1 = relay.ty.TupleType([relay.ty.TupleType([t1, t1]), t2])
    tup2 = relay.ty.TupleType([relay.ty.TupleType([tensor1, tensor2]), tensor3])
    solver.Unify(tup1, tup2)


@pytest.mark.xfail(raises=tvm._ffi.base.TVMError)
def test_bad_recursive_unification():
    solver = make_solver()
    t1 = relay.ty.IncompleteType()
    solver.Unify(t1, relay.ty.TupleType([t1, t1]))


@pytest.mark.xfail(raises=tvm._ffi.base.TVMError)
def test_unify_invalid_global_typevars():
    solver = make_solver()
    gtv1 = relay.GlobalTypeVar("gtv1")
    gtv2 = relay.GlobalTypeVar("gtv2")
    solver.Unify(gtv1, gtv2)


@pytest.mark.xfail(raises=tvm._ffi.base.TVMError)
def test_incompatible_typecall_var_unification():
    solver = make_solver()
    gtv1 = relay.GlobalTypeVar("gtv1")
    gtv2 = relay.GlobalTypeVar("gtv2")
    t1 = relay.IncompleteType()
    t2 = relay.IncompleteType()
    tc1 = relay.TypeCall(gtv1, [t1])
    tc2 = relay.TypeCall(gtv2, [t2])
    solver.Unify(tc1, tc2)


@pytest.mark.xfail(raises=tvm._ffi.base.TVMError)
def test_incompatible_typecall_args_unification():
    solver = make_solver()
    gtv = relay.GlobalTypeVar("gtv1")
    t1 = relay.IncompleteType()
    t2 = relay.IncompleteType()
    tensor1 = relay.TensorType((1, 2, 3), "float32")
    tensor2 = relay.TensorType((2, 3), "float32")
    tensor3 = relay.TensorType((3,), "float32")
    tc1 = relay.TypeCall(gtv, [relay.TupleType([t1, t1]), t2])
    tc2 = relay.TypeCall(gtv, [relay.TupleType([tensor1, tensor2]), tensor3])
    solver.Unify(tc1, tc2)


@pytest.mark.xfail(raises=tvm._ffi.base.TVMError)
def test_incompatible_quantified_func_unification():
    solver = make_solver()
    a, b, c = relay.TypeVar("a"), relay.TypeVar("b"), relay.TypeVar("c")
    ft1 = relay.FuncType([a, b], c, [a, b, c])
    ft2 = relay.FuncType([b, c], relay.TupleType([a]), [a, b, c])
    solver.Unify(ft1, ft2)
```
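The same expectations could also be written with pytest.raises rather than xfail markers; a minimal sketch of one such variant (hypothetical, not part of the original suite):

```python
# Sketch: the occurs-check failure above expressed with pytest.raises.
def test_bad_recursive_unification_raises():
    solver = make_solver()
    t1 = relay.ty.IncompleteType()
    with pytest.raises(tvm._ffi.base.TVMError):
        solver.Unify(t1, relay.ty.TupleType([t1, t1]))
```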
Testing integer compatibility during layout transformation
```python
x = relay.var("data", shape=(2, 3, 48, 48), dtype="float32")
conv_out = relay.nn.conv2d(
    x,
    relay.var("weight", shape=(1, 3, 1, 1), dtype="float32"),
    strides=[47, 47],
    channels=1,
    kernel_size=[1, 1],
)
bias_out = relay.nn.bias_add(conv_out, relay.var("bias"))
broadcast_out = relay.op.broadcast_to(bias_out, relay.const([2, 1, 2, 2], dtype="int64"))
y = relay.add(bias_out, broadcast_out)
mod, _ = testing.create_workload(y)
with tvm.transform.PassContext(opt_level=3):
    with tvm.target.Target("llvm"):
        mod = relay.transform.CanonicalizeOps()(mod)
        mod = relay.transform.AlterOpLayout()(mod)
```
One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
The code first defines a variable x, a float32 tensor of shape (2, 3, 48, 48). relay.nn.conv2d then performs a convolution on x with a weight of shape (1, 3, 1, 1), strides of [47, 47], one output channel, and a [1, 1] kernel. Next, relay.nn.bias_add adds a bias term to the convolution output, yielding bias_out. relay.op.broadcast_to broadcasts bias_out to a tensor of shape [2, 1, 2, 2], yielding broadcast_out. Finally, bias_out and broadcast_out are added together to produce y. A workload mod is created from the computation graph of y. Inside a PassContext at optimization level 3 with the target set to "llvm", the graph is first canonicalized (CanonicalizeOps) and then layout-transformed (AlterOpLayout).

The purpose of this code is to test integer compatibility during layout transformation, making sure that operating across values of different types does not raise errors or exceptions.
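To confirm that the transformed module still type-checks, one could re-run Relay's type inference over it; a minimal sketch (InferType is the standard Relay pass that drives the TypeSolver described above):

```python
# Re-run type inference on the transformed module; this exercises the
# TypeSolver machinery and raises a TVMError on any type conflict.
mod = relay.transform.InferType()(mod)
print(mod["main"].checked_type)
```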