测试无琐碎绑定的重写

测试无琐碎绑定的重写

from tvm.script import relax as R
from tvm.script import tir as T
from tvm import relax as rx
from tvm import relay, tir
from tvm.relax.analysis import get_var2val
import tvm.testing
from tvm.relax.dpl import *

# Fixture parameter: every test that accepts ``bind_to_dataflow_var`` runs
# once with ``False`` (id "var-to-var") and once with ``True``
# (id "var-to-dataflow-var"); the dict keys serve as readable test ids.
bind_to_dataflow_var = tvm.testing.parameter(
    by_dict={
        "var-to-var": False,
        "var-to-dataflow-var": True,
    },
)

备注

rewrite_call 应避免生成琐碎的 "y = x" 绑定

这并非在所有情况下都能做到，具体遵循与 CanonicalizeBindings 相同的规则。例如，当 relax.Var 绑定到 relax.DataflowVar 时（即 `y = x` 中 y 为 relax.Var、x 为 relax.DataflowVar），该绑定可能无法被移除，以确保 relax.DataflowVar 只在 DataflowBlock 内部使用。

def test_rewrite_without_trivial_binding(bind_to_dataflow_var):
    """``rewrite_call`` must not leave trivial ``y = x`` bindings behind.

    Eliminating such a binding is not always possible; the rules are the
    same ones used by ``CanonicalizeBindings``.  For example, when a
    ``relax.Var`` is bound to a ``relax.DataflowVar``, the binding cannot
    be removed by substituting the ``relax.DataflowVar`` for the
    ``relax.Var``, because a ``relax.DataflowVar`` may only be used within
    a ``DataflowBlock``.  In that case the expected output instead binds
    the call directly to the ``relax.Var`` emitted from the block.

    Parameters
    ----------
    bind_to_dataflow_var : bool
        ``False`` builds the IR from plain bindings; ``True`` builds it
        inside ``R.dataflow()``, exercising the Var-to-DataflowVar case.
    """

    if not bind_to_dataflow_var:
        # Plain bindings: the no-op reshape disappears entirely and the
        # function returns the result of the add directly.

        @R.function(private=True)
        def before(x: R.Tensor((1024,))):
            a = R.add(x, x)
            b = R.reshape(a, (1024,))
            return b

        @R.function(private=True)
        def expected(x: R.Tensor((1024,))):
            a = R.add(x, x)
            return a

    else:
        # Dataflow block: the output variable ``b`` is returned outside the
        # block, so it cannot be replaced by the block-local ``a``; instead
        # ``b`` is bound directly to the add.

        @R.function(private=True)
        def before(x: R.Tensor((1024,))):
            with R.dataflow():
                a = R.add(x, x)
                b = R.reshape(a, (1024,))
                R.output(b)
            return b

        @R.function(private=True)
        def expected(x: R.Tensor((1024,))):
            with R.dataflow():
                b = R.add(x, x)
                R.output(b)
            return b

    # Match any ``relax.reshape(data, shape)`` call.
    data_pat = wildcard()
    shape_pat = wildcard()
    reshape_pat = is_op("relax.reshape")(data_pat, shape_pat)

    def drop_noop_reshape(call, matches):
        """Replace a reshape to the input's own shape with the input itself.

        Any other reshape is returned unchanged.
        """
        data = matches[data_pat]
        target_shape = matches[shape_pat]
        same_shape = tvm.ir.structural_equal(data.struct_info.shape, target_shape)
        return data if same_shape else call

    tvm.ir.assert_structural_equal(
        rewrite_call(reshape_pat, drop_noop_reshape, before), expected
    )