# 测试无琐碎绑定的重写
from tvm.script import relax as R
from tvm.script import tir as T
from tvm import relax as rx
from tvm import relay, tir
from tvm.relax.analysis import get_var2val
import tvm.testing
from tvm.relax.dpl import *
# Fixture shared by tests that take a ``bind_to_dataflow_var`` argument.
# Each such test runs once per entry below; the keys are used as the test
# ids (via ``by_dict`` of tvm.testing.parameter) and the values are the
# argument passed in:
#   "var-to-var"          -> False
#   "var-to-dataflow-var" -> True
bind_to_dataflow_var = tvm.testing.parameter(
    by_dict=dict(
        [
            ("var-to-var", False),
            ("var-to-dataflow-var", True),
        ]
    ),
)
# 备注
# rewrite_call 应避免生成琐碎的 "y = x" 绑定。
# 这可能并非在所有情况下都可行,并且遵循与 CanonicalizeBindings 相同的规则。
# 例如,将 relax.Var 绑定到 relax.DataflowVar 的绑定可能无法移除,
# 以确保 relax.DataflowVar 仅在 DataflowBlock 内使用。
def test_rewrite_without_trivial_binding(bind_to_dataflow_var):
    """Check that rewrite_call does not leave trivial "y = x" bindings.

    The rewriter below replaces an ``R.reshape`` whose target shape is
    structurally equal to its input's shape with that input.  Done
    naively, this would leave a trivial binding ``b = a`` behind.

    Avoiding such a binding may not be possible in all cases, and follows
    the same rules as CanonicalizeBindings.  For example, a binding of a
    `relax.Var` to a `relax.DataflowVar` may not be removed, to ensure
    that the `relax.DataflowVar` is only used within a `DataflowBlock`.
    In that case ``expected`` binds the ``R.add`` result directly to the
    output variable ``b`` instead.
    """
    if not bind_to_dataflow_var:
        # "var-to-var": no dataflow block; the reshaped value `b` is
        # dropped and `a` is returned directly.

        @R.function(private=True)
        def before(x: R.Tensor((1024,))):
            a = R.add(x, x)
            b = R.reshape(a, (1024,))
            return b

        @R.function(private=True)
        def expected(x: R.Tensor((1024,))):
            a = R.add(x, x)
            return a

    else:
        # "var-to-dataflow-var": `a` is local to the dataflow block while
        # `b` is the block's output (see R.output below).

        @R.function(private=True)
        def before(x: R.Tensor((1024,))):
            with R.dataflow():
                a = R.add(x, x)
                b = R.reshape(a, (1024,))
                R.output(b)
            return b

        @R.function(private=True)
        def expected(x: R.Tensor((1024,))):
            with R.dataflow():
                b = R.add(x, x)
                R.output(b)
            return b

    # Match any reshape call, capturing its data argument and target shape.
    reshape_input = wildcard()
    reshape_shape = wildcard()
    reshape_pattern = is_op("relax.reshape")(reshape_input, reshape_shape)

    def drop_noop_reshape(call, matched):
        """Return the reshape's input if the reshape keeps its shape."""
        data = matched[reshape_input]
        target_shape = matched[reshape_shape]
        # Reshaping to the shape the input already has changes nothing,
        # so the input itself can stand in for the call.
        same_shape = tvm.ir.structural_equal(data.struct_info.shape, target_shape)
        return data if same_shape else call

    after = rewrite_call(reshape_pattern, drop_noop_reshape, before)
    tvm.ir.assert_structural_equal(after, expected)