解读 _test_type_solver

参考:tvm/tests/python/relay/test_type_solver.py & tvm/src/relay/analysis/type_solver.h

import testing
import tvm
from tvm import relay
from tvm.relay import testing

import numpy as np
def make_rel(name, args, num_inputs=None, attrs=None):
    """Build a ``relay.ty.TypeRelation`` for a registered relation.

    The relation function is fetched as the global ``EnvFunc`` named
    ``"tvm.relay.type_relation." + name``. When *num_inputs* is omitted,
    every element of *args* except the last is counted as an input.
    """
    relation_fn = tvm.ir.EnvFunc.get("tvm.relay.type_relation." + name)
    n_inputs = len(args) - 1 if num_inputs is None else num_inputs
    return relay.ty.TypeRelation(relation_fn, args, n_inputs, attrs)
relay.ty.TypeRelation??
Init signature: relay.ty.TypeRelation(func, args, num_inputs, attrs)
Source:        
@tvm._ffi.register_object("TypeRelation")
class TypeRelation(TypeConstraint):
    """User defined type relation, it is an input-output relation on types.

    TypeRelation is more general than TypeCall as it allows inference
    of both inputs and outputs.

    Parameters
    ----------
    func : EnvFunc
        User defined relation function.

    args : [tvm.ir.Type]
        List of types to the func.

    num_inputs : int
        Number of input arguments in args;
        this acts as a hint for type inference.

    attrs : Attrs
        The attributes attached to the relation information.

    Returns
    -------
    type_relation : tvm.ir.TypeRelation
        The type relation.
    """

    def __init__(self, func, args, num_inputs, attrs):
        # Create the backing object through the FFI constructor
        # registered as _ffi_api.TypeRelation.
        self.__init_handle_by_constructor__(_ffi_api.TypeRelation, func, args, num_inputs, attrs)
File:           /media/pc/data/lxw/ai/tvm/python/tvm/ir/type_relation.py
Type:           type
Subclasses:     

TypeRelation 类,继承自 TypeConstraint。它表示用户定义的类型关系,是类型上的输入输出关系。

TypeRelation 比 TypeCall 更通用,因为它既允许推断输入,也允许推断输出。

参数:

  • func: EnvFunc,用户定义的关系函数。

  • args: [tvm.ir.Type],要传递给 func 的类型列表。

  • num_inputs: int,args 中输入参数的数量,用作类型推断的提示。

  • attrs: Attrs,附加到关系信息的属性。

返回值:

  • type_relation: tvm.ir.TypeRelation,类型关系。

def make_solver():
    """Create a test type solver with its packed methods bound as attributes.

    ``relay.analysis._ffi_api._test_type_solver()`` returns a callable that,
    given a method name, yields the corresponding packed function. Those are
    attached as ``Solve``, ``Unify``, ``Resolve`` and ``AddConstraint``, and
    a ``gen_type`` convenience helper is attached as well.

    Returns
    -------
    The solver object with the attributes above.
    """
    solver = relay.analysis._ffi_api._test_type_solver()
    for method in ("Solve", "Unify", "Resolve", "AddConstraint"):
        setattr(solver, method, solver(method))

    def gen_type(name, args, out=None):
        """Register relation ``name(*args, out)`` and return ``out``.

        If *out* is omitted, a fresh ``IncompleteType`` is created, so the
        returned type is the (initially unknown) output of the relation.
        *args* may be any sequence of types.
        """
        # Explicit None check: only "not supplied" should create a new type.
        if out is None:
            out = relay.ty.IncompleteType()
        # list(...) so callers may pass a tuple or other sequence, not only a list.
        solver.AddConstraint(make_rel(name, list(args) + [out]))
        return out

    solver.gen_type = gen_type
    return solver
/*!
 * \brief Interface of type solver used in type inference.
 *
 * TypeSolver works on a list of constraints among incomplete types.
 * The user will populate the constraints by AddConstraint and Assign.
 * Then we can call Solve to try to resolve the unknowns.
 *
 * This can be viewed as "type program(computational graph)" of types, where
 * the type constraints are operators of the graph and the incomplete
 * types are intermediate values of the graph.
 * If all the input types are concretely known, we should be able to
 * just run a forward pass on the "type program" to get all the types.
 *
 * The list of constraints representation means we are storing it as a bipartite
 * graph instead of a DAG. This is because some constraints might go in both directions.
 * TypeSolver could take advantage of bidirectional constraints to deduce input
 * values given output ones. Nevertheless, we should keep in mind that
 * there is a "forward direction" that the TypeSolver should take advantage of.
 */
class TypeSolver {
 public:
  TypeSolver(const GlobalVar& current_func, DiagnosticContext diag_ctx);
  ~TypeSolver();
  /*!
   * \brief Add a type constraint to the solver.
   * \param constraint The constraint to be added.
   * \param span The location (span) at which the constraint was incurred.
   */
  void AddConstraint(const TypeConstraint& constraint, const Span& span);
  /*!
   * \brief Resolve type to the solution type in the solver.
   * \param type The type to be resolved.
   * \return The resolved type.
   */
  Type Resolve(const Type& type);
  /*!
   * \brief Start to solve the types using the current known information.
   * \return Whether all the incomplete types have been fully resolved.
   */
  bool Solve();
  /*!
   * \brief Unify lhs and rhs.
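   * \param lhs The left operand.
   * \param rhs The right operand.
   * \param span The location (span) at which the unification problem arose.
   * \param assign_lhs Whether lhs may be assigned during unification (default true).
   * \param assign_rhs Whether rhs may be assigned during unification (default true).
   * \return The unified type.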
   */
  Type Unify(const Type& lhs, const Type& rhs, const Span& span, bool assign_lhs = true,
             bool assign_rhs = true);
  /*!
   * \brief Report a diagnostic.
   * \param diag The diagnostic to report.
   */
  void Emit(const Diagnostic& diag) { diag_ctx_.Emit(diag); }

这段代码是类型求解器(TypeSolver)的接口定义。类型求解器用于求解类型之间的约束关系,以确定未知(不完整)类型的具体值。

该接口包含以下方法:

  • AddConstraint: 向求解器中添加类型约束。参数包括要添加的约束和约束发生的位置。

  • Resolve: 将给定类型解析为求解器中对应的解类型。参数是要解析的类型,返回解析后的类型。

  • Solve: 使用当前已知信息开始解决类型。返回是否所有不完整类型都已完全解析。

  • Unify: 合一(统一)lhs 和 rhs。参数包括左操作数、右操作数以及问题出现的位置(span);还可以通过 assign_lhs / assign_rhs 指定是否允许对 lhs 和 rhs 赋值。返回合一后的类型。

  • Emit: 报告诊断信息。参数是要报告的诊断信息。

广播的类型推断(solver)

# Forward solving: broadcast (10, 20) with (10, 1), pass the result through
# Identity, then broadcast it with (10, 1, 1). All inputs are concrete, so
# Solve is expected to resolve every intermediate type.
solver = make_solver()
t0 = relay.ty.TensorType((10, 20), "float32")
t1 = relay.ty.TensorType((10, 1), "float32")
tc = relay.ty.TensorType((10, 1, 1), "float32")
t2 = solver.gen_type("Broadcast", [t0, t1])  # expected: (10, 20)
t3 = solver.gen_type("Identity", [t2])  # t3 is constrained to equal t2
t4 = solver.gen_type("Broadcast", [t3, tc])  # expected: (10, 10, 20)
assert solver.Solve()
assert solver.Resolve(t2) == relay.ty.TensorType((10, 20), "float32")
assert solver.Resolve(t4) == relay.ty.TensorType((10, 10, 20), "float32")

其他类型推断(solver)

def test_backward_solving():
    """Infer an unknown input from a known output, then solve forward.

    ``t1`` is unknown; the ``Identity`` relation with ``out=tc`` ties it to
    ``(10, 1, 1)``, after which ``Broadcast(t0, t1)`` resolves to
    ``(10, 10, 20)``.
    """
    solver = make_solver()
    t0 = relay.ty.TensorType((10, 20), "float32")
    tc = relay.ty.TensorType((10, 1, 1), "float32")
    t1 = relay.ty.IncompleteType()
    t3 = solver.gen_type("Broadcast", [t0, t1])
    # Only the registered constraint matters; gen_type would just return tc.
    solver.gen_type("Identity", [t1], out=tc)
    assert solver.Solve()
    assert solver.Resolve(t3) == relay.ty.TensorType((10, 10, 20), "float32")


def test_unify_tuple():
    """Unifying (?, ?) with (T, T) fills both holes with T."""
    solver = make_solver()
    hole_a = relay.ty.IncompleteType()
    hole_b = relay.ty.IncompleteType()
    concrete = relay.ty.TensorType((10, 20), "float32")

    with_holes = relay.ty.TupleType([hole_a, hole_b])
    filled = relay.ty.TupleType([concrete, concrete])

    assert solver.Unify(with_holes, filled) == filled


def test_unify_global_type_var():
    """A global type var unifies with itself (and should only unify with itself)."""
    solver = make_solver()
    var = relay.GlobalTypeVar("gtv")
    result = solver.Unify(var, var)
    assert result == var


def test_unify_typecall():
    """Type calls with the same callee unify argument-wise, like tuples."""
    solver = make_solver()
    callee = relay.GlobalTypeVar("gtv")

    holes = [relay.ty.IncompleteType(), relay.ty.IncompleteType()]
    concrete = relay.ty.TensorType((10, 20), "float32")

    partial_call = relay.ty.TypeCall(callee, holes)
    full_call = relay.ty.TypeCall(callee, [concrete, concrete])
    assert solver.Unify(partial_call, full_call) == full_call


def test_unify_functype():
    """Unify fn(?, ?) -> ? with a fully concrete function type."""
    solver = make_solver()
    arg0, arg1, ret = (relay.ty.IncompleteType() for _ in range(3))

    concrete = relay.ty.FuncType(
        [
            relay.ty.TensorType((10, 20), "float32"),
            relay.ty.TensorType((10,), "float32"),
        ],
        relay.ty.TupleType([]),
    )
    pending = relay.ty.FuncType([arg0, arg1], ret)

    assert solver.Unify(pending, concrete) == concrete


def test_recursive_unify():
    """Unification recurses through nested tuples inside function types."""
    solver = make_solver()
    x, y, z = (relay.ty.IncompleteType() for _ in range(3))

    def f32(*shape):
        return relay.ty.TensorType(shape, "float32")

    big, mid, small = f32(10, 10, 20), f32(10, 20), f32(10)

    nested_holes = relay.ty.TupleType([relay.ty.TupleType([x, y]), y])
    nested_concrete = relay.ty.TupleType([relay.ty.TupleType([big, mid]), mid])

    lhs = relay.ty.FuncType([nested_holes, z], z)
    rhs = relay.ty.FuncType([nested_concrete, small], small)
    assert solver.Unify(lhs, rhs) == rhs


def test_unify_vars_under_tuples():
    """Incomplete types nested inside tuples unify with each other."""
    solver = make_solver()

    first = relay.ty.IncompleteType()
    pair_first = relay.ty.TupleType([first, first])
    assert solver.Unify(pair_first, pair_first) == pair_first

    second = relay.ty.IncompleteType()
    pair_second = relay.ty.TupleType([second, second])

    result = solver.Unify(
        relay.ty.TupleType([first, second]),
        relay.ty.TupleType([second, first]),
    )
    # Either variable may be chosen as the representative.
    assert result == pair_first or result == pair_second


def test_binding_over_typevars():
    """Quantified function types unify when their type params line up."""
    solver = make_solver()

    arg_hole, ret_hole = relay.ty.IncompleteType(), relay.ty.IncompleteType()
    a, b, c, d = (relay.ty.TypeVar(name) for name in ("a", "b", "c", "d"))

    lhs = relay.ty.FuncType([arg_hole], ret_hole, [c, d])
    rhs = relay.ty.FuncType([a], b, [a, b])

    result = solver.Unify(lhs, rhs)
    assert result == solver.Resolve(lhs)


def test_recursive_backward_solving():
    """Identity propagates concrete types backwards through nested tuples."""
    solver = make_solver()

    shapes = [(10, 20), (10, 1, 1), (10,)]
    ten_a, ten_b, ten_c = (relay.ty.TensorType(s, "float32") for s in shapes)
    hole_a, hole_b, hole_c = (relay.ty.IncompleteType() for _ in shapes)

    concrete = relay.ty.TupleType([relay.ty.TupleType([ten_a, ten_b]), ten_c])
    pending = relay.ty.TupleType([relay.ty.TupleType([hole_a, hole_b]), hole_c])
    solver.gen_type("Identity", [concrete], out=pending)

    assert solver.Solve()
    assert solver.Resolve(pending) == concrete


def test_backward_solving_after_child_update():
    """Check that binding variables inside tuples re-resolves derived types.

    ``t4``/``t5`` are outputs of ``Identity`` relations over ``tup1`` and
    ``tup2``, and a third ``Identity`` links ``t4`` to ``t5``. After the first
    ``Solve`` the tuples are only partially known. Binding ``t1``/``t2`` to
    concrete tensors afterwards must make both ``t4`` and ``t5`` resolve to
    ``tup_concrete``.
    """
    solver = make_solver()

    tensor1 = relay.ty.TensorType((10, 20), "float32")
    tensor2 = relay.ty.TensorType((10, 1, 1), "float32")

    t1 = relay.ty.IncompleteType()
    t2 = relay.ty.IncompleteType()
    t3 = relay.ty.IncompleteType()

    tup1 = relay.ty.TupleType([t1, t2])
    tup2 = relay.ty.TupleType([t1, t3])

    tup_concrete = relay.ty.TupleType([tensor1, tensor2])

    t4 = solver.gen_type("Identity", [tup1])
    t5 = solver.gen_type("Identity", [tup2])

    # Linking t4 and t5 presumably unifies (t1, t2) with (t1, t3), i.e. t2 with
    # t3 -- hence the "t3 or t2" check below.
    solver.gen_type("Identity", [t4], out=t5)
    assert solver.Solve()
    # Either variable may end up as the representative after unification.
    assert solver.Resolve(t3) == t3 or solver.Resolve(t3) == t2
    assert solver.Resolve(t4) == tup1 or solver.Resolve(t4) == tup2
    assert solver.Resolve(t5) == tup1 or solver.Resolve(t5) == tup2

    # updating the variables *inside* tup1 and tup2 should update t4 and t5
    solver.gen_type("Identity", [t1], out=tensor1)
    solver.gen_type("Identity", [t2], out=tensor2)
    assert solver.Solve()
    assert solver.Resolve(t4) == tup_concrete
    assert solver.Resolve(t5) == tup_concrete


def test_unify_quantified_funcs():
    """Quantified function types unify after renaming their bound params."""
    solver = make_solver()
    a, b, c = (relay.TypeVar(name) for name in ("a", "b", "c"))

    general = relay.FuncType([a, b], c, [a, b, c])
    diagonal = relay.FuncType([a, a], a, [a])
    assert solver.Unify(general, diagonal) == diagonal

    identity = relay.FuncType([a], a, [a])
    unary = relay.FuncType([b], c, [b, c])
    assert solver.Unify(identity, unary) == identity


def test_unify_quantified_func_and_concrete():
    """A generic unary function type unifies with one returning the unit tuple."""
    solver = make_solver()
    a = relay.TypeVar("a")
    b = relay.TypeVar("b")

    generic = relay.FuncType([a], b, [a, b])
    to_unit = relay.FuncType([b], relay.TupleType([]), [b])

    result = solver.Unify(generic, to_unit)
    assert result == to_unit


def test_unify_quantified_funcs_nesting():
    """Bound type params are matched through nested tuple types."""
    solver = make_solver()
    a, b, c = (relay.TypeVar(name) for name in ("a", "b", "c"))
    tup = relay.TupleType

    general = relay.FuncType([a, tup([b, c])], tup([a, b, c]), [a, b, c])
    specific = relay.FuncType([a, tup([a, a])], tup([a, a, a]), [a])

    assert solver.Unify(general, specific) == specific


def test_unify_quantified_funcs_var_order():
    """Unify quantified function types whose bound params differ in order.

    NOTE(review): this test currently asserts nothing; it only builds
    ``ft1``/``ft2``. The ``Unify`` call is commented out because, per the
    note below, it crashes even though it should succeed. Re-enable the two
    commented lines once the solver handles this case.
    """
    solver = make_solver()
    a, b, c = relay.TypeVar("a"), relay.TypeVar("b"), relay.TypeVar("c")

    ft1 = relay.FuncType([a, relay.TupleType([b, c])], relay.TupleType([a, b, c]), [a, b, c])
    ft2 = relay.FuncType([a, relay.TupleType([a, c])], relay.TupleType([a, a, c]), [a, c])
    # unified = solver.Unify(ft1, ft2) # crashes here but it shouldn't
    # assert unified == ft2

不兼容的类型推断

import pytest

@pytest.mark.xfail(raises=tvm._ffi.base.TVMError)
def test_incompatible_tuple_unification():
    """((?x, ?x), ?y) cannot unify with ((A, B), C) because A and B differ."""
    solver = make_solver()
    hole_x, hole_y = relay.ty.IncompleteType(), relay.ty.IncompleteType()

    shapes = [(1, 2, 3), (2, 3), (3,)]
    ten_a, ten_b, ten_c = (relay.ty.TensorType(s, "float32") for s in shapes)

    pattern = relay.ty.TupleType([relay.ty.TupleType([hole_x, hole_x]), hole_y])
    concrete = relay.ty.TupleType([relay.ty.TupleType([ten_a, ten_b]), ten_c])
    solver.Unify(pattern, concrete)


@pytest.mark.xfail(raises=tvm._ffi.base.TVMError)
def test_bad_recursive_unification():
    """A variable must not unify with a tuple that contains itself."""
    solver = make_solver()
    hole = relay.ty.IncompleteType()
    self_pair = relay.ty.TupleType([hole, hole])
    solver.Unify(hole, self_pair)


@pytest.mark.xfail(raises=tvm._ffi.base.TVMError)
def test_unify_invalid_global_typevars():
    """Two distinct global type vars must not unify."""
    solver = make_solver()
    first, second = relay.GlobalTypeVar("gtv1"), relay.GlobalTypeVar("gtv2")
    solver.Unify(first, second)


@pytest.mark.xfail(raises=tvm._ffi.base.TVMError)
def test_incompatible_typecall_var_unification():
    """Type calls with different callees must not unify."""
    solver = make_solver()
    callee_one = relay.GlobalTypeVar("gtv1")
    callee_two = relay.GlobalTypeVar("gtv2")

    hole_one = relay.IncompleteType()
    hole_two = relay.IncompleteType()

    solver.Unify(
        relay.TypeCall(callee_one, [hole_one]),
        relay.TypeCall(callee_two, [hole_two]),
    )


@pytest.mark.xfail(raises=tvm._ffi.base.TVMError)
def test_incompatible_typecall_args_unification():
    """Same callee, but the arguments ((?x, ?x), ?y) vs ((A, B), C) clash."""
    solver = make_solver()
    callee = relay.GlobalTypeVar("gtv1")
    hole_x, hole_y = relay.IncompleteType(), relay.IncompleteType()

    shapes = [(1, 2, 3), (2, 3), (3,)]
    ten_a, ten_b, ten_c = (relay.TensorType(s, "float32") for s in shapes)

    lhs = relay.TypeCall(callee, [relay.TupleType([hole_x, hole_x]), hole_y])
    rhs = relay.TypeCall(callee, [relay.TupleType([ten_a, ten_b]), ten_c])
    solver.Unify(lhs, rhs)


@pytest.mark.xfail(raises=tvm._ffi.base.TVMError)
def test_incompatible_quantified_func_unification():
    """Quantified function types with incompatible signatures must not unify."""
    solver = make_solver()
    a, b, c = (relay.TypeVar(name) for name in ("a", "b", "c"))

    lhs = relay.FuncType([a, b], c, [a, b, c])
    rhs = relay.FuncType([b, c], relay.TupleType([a]), [a, b, c])
    solver.Unify(lhs, rhs)

测试在布局转换过程中整数的兼容性

# Build conv2d -> bias_add -> broadcast_to -> add, where the broadcast target
# shape is an int64 constant, then run CanonicalizeOps and AlterOpLayout on it.
x = relay.var("data", shape=(2, 3, 48, 48), dtype="float32")
conv_out = relay.nn.conv2d(
    x,
    relay.var("weight", shape=(1, 3, 1, 1), dtype="float32"),
    kernel_size=[1, 1],
    channels=1,
    strides=[47, 47],
)
# No shape is given for "bias"; it is left to type inference.
bias_out = relay.nn.bias_add(conv_out, relay.var("bias"))
broadcast_out = relay.op.broadcast_to(
    bias_out, relay.const([2, 1, 2, 2], dtype="int64")
)
y = relay.add(bias_out, broadcast_out)

mod, _ = testing.create_workload(y)
# Entered in the same order as nested blocks: PassContext first, then Target.
with tvm.transform.PassContext(opt_level=3), tvm.target.Target("llvm"):
    mod = relay.transform.CanonicalizeOps()(mod)
    mod = relay.transform.AlterOpLayout()(mod)
One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
  1. 首先定义了一个变量 x,它是一个形状为 (2, 3, 48, 48) 的浮点数张量。

  2. 然后使用 relay.nn.conv2d 函数进行卷积操作,输入是 x,权重的形状为 (1, 3, 1, 1),步长为 [47, 47],通道数为 1,卷积核大小为 [1, 1]。

  3. 接下来使用 relay.nn.bias_add 函数将偏置项添加到卷积输出上,得到 bias_out

  4. 使用 relay.op.broadcast_to 函数将 bias_out 广播到形状为 [2, 1, 2, 2] 的张量,得到 broadcast_out

  5. 最后,将 bias_out 与 broadcast_out 相加,得到结果 y。

  6. 创建一个工作负载 mod,其中包含计算图 y

  7. 在优化级别(opt_level)为 3 的 PassContext 中,将目标设置为 "llvm"。

  8. 先对 mod 运行 CanonicalizeOps(算子规范化),再运行 AlterOpLayout(布局转换)。

这段代码的目的是测试布局转换过程中的整数类型兼容性:broadcast_to 的目标形状由 int64 常量给出,需确保规范化和布局转换在处理这样的图时不会因整数类型不一致而报错或抛出异常。