# Tests for matching StructInfo with DFPatterns
from tvm.script import relax as R
from tvm.script import tir as T
from tvm import relax as rx
from tvm import relay, tir
from tvm.relax.analysis import get_var2val
import tvm.testing
from tvm.relax.dpl import *
# Test: a wildcard restricted by StructInfo is rewritten when it matches
def test_wildcard_with_struct_info_updates_when_matching():
    """A DFPattern can be restricted to a specific StructInfo.

    Both operands of `relax.add` must have StructInfo `R.Tensor([2, 3])`.
    The function below adds two tensors of exactly that shape, so the
    add is rewritten into `R.multiply` of the same operands.
    """
    required_sinfo = R.Tensor([2, 3])
    pat_lhs = wildcard().has_struct_info(required_sinfo)
    pat_rhs = wildcard().has_struct_info(required_sinfo)
    pat = is_op("relax.add")(pat_lhs, pat_rhs)

    def rewriter(expr, matches):
        # Replace the matched add by a multiply of the matched operands.
        return rx.op.multiply(matches[pat_lhs], matches[pat_rhs])

    @R.function(private=True)
    def before():
        with R.dataflow():
            A = R.zeros([2, 3], "int32")
            B = R.ones([2, 3], "int32")
            C = R.add(A, B)
            R.output(C)
        return C

    @R.function(private=True)
    def expected():
        with R.dataflow():
            A = R.zeros([2, 3], "int32")
            B = R.ones([2, 3], "int32")
            C = R.multiply(A, B)
            R.output(C)
        return C

    tvm.ir.assert_structural_equal(expected, rewrite_call(pat, rewriter, before))
def test_wildcard_with_struct_info_is_no_op_when_not_matching():
    """StructInfoPattern requires the StructInfo provided

    Here, the pattern would match, except that the function has
    `R.Tensor([16, 32])`, and the pattern requires `R.Tensor([2, 3])`.
    The function must therefore be returned unchanged.
    """
    pat_lhs = wildcard().has_struct_info(R.Tensor([2, 3]))
    pat_rhs = wildcard().has_struct_info(R.Tensor([2, 3]))
    pat = is_op("relax.add")(pat_lhs, pat_rhs)

    def rewriter(expr, matches):
        lhs = matches[pat_lhs]
        rhs = matches[pat_rhs]
        return rx.op.multiply(lhs, rhs)

    @R.function(private=True)
    def before():
        with R.dataflow():
            # This R.add has a different shape than the pattern requires
            # ([16, 32] vs. [2, 3]), so it must NOT be updated.
            A = R.zeros([16, 32], "int32")
            B = R.ones([16, 32], "int32")
            C = R.add(A, B)
            R.output(C)
        return C

    # No rewrite should happen, so the result must be structurally
    # identical to the input function.
    expected = before

    after = rewrite_call(pat, rewriter, before)
    tvm.ir.assert_structural_equal(expected, after)
def test_wildcard_struct_info_for_unknown_dtype():
    """A TensorStructInfo pattern without a dtype accepts any dtype.

    The pattern constrains only the shape (`[2, 3]`), so the int32 add
    and the float32 add below are both rewritten into `R.multiply`.
    """
    shape_only_sinfo = R.Tensor([2, 3])
    pat_lhs = wildcard().has_struct_info(shape_only_sinfo)
    pat_rhs = wildcard().has_struct_info(shape_only_sinfo)
    pat = is_op("relax.add")(pat_lhs, pat_rhs)

    def rewriter(expr, matches):
        # Same operands as the matched add, combined with multiply instead.
        operands = (matches[pat_lhs], matches[pat_rhs])
        return rx.op.multiply(*operands)

    @R.function(private=True)
    def before():
        with R.dataflow():
            A = R.zeros([2, 3], "int32")
            B = R.ones([2, 3], "int32")
            C = R.add(A, B)
            D = R.zeros([2, 3], "float32")
            E = R.ones([2, 3], "float32")
            F = R.add(D, E)
            output = (C, F)
            R.output(output)
        return output

    @R.function(private=True)
    def expected():
        with R.dataflow():
            A = R.zeros([2, 3], "int32")
            B = R.ones([2, 3], "int32")
            C = R.multiply(A, B)
            D = R.zeros([2, 3], "float32")
            E = R.ones([2, 3], "float32")
            F = R.multiply(D, E)
            output = (C, F)
            R.output(output)
        return output

    rewritten = rewrite_call(pat, rewriter, before)
    tvm.ir.assert_structural_equal(expected, rewritten)
def test_wildcard_struct_info_with_symbolic_vars():
    """A StructInfoPattern may contain symbolic shape variables.

    Both operands must be `R.Tensor([m, n])` with the same `m` and `n`.
    The elementwise `R.add` of two `[64, 128]` tensors therefore matches.
    The broadcasted `R.add` of `[64, 128]` and `[1, 128]` does not match
    and is left as-is, as the expected function shows.
    """
    m, n = (tir.Var(name, "int64") for name in ("m", "n"))
    symbolic_sinfo = R.Tensor([m, n])
    pat_lhs = wildcard().has_struct_info(symbolic_sinfo)
    pat_rhs = wildcard().has_struct_info(symbolic_sinfo)
    pat = is_op("relax.add")(pat_lhs, pat_rhs)

    def rewriter(expr, matches):
        # Turn the matched elementwise add into a multiply.
        return rx.op.multiply(matches[pat_lhs], matches[pat_rhs])

    @R.function(private=True)
    def before():
        with R.dataflow():
            A = R.zeros([64, 128], "int32")
            B = R.ones([64, 128], "int32")
            C = R.add(A, B)
            D = R.zeros([64, 128], "float32")
            E = R.ones([1, 128], "float32")
            F = R.add(D, E)
            output = (C, F)
            R.output(output)
        return output

    @R.function(private=True)
    def expected():
        with R.dataflow():
            A = R.zeros([64, 128], "int32")
            B = R.ones([64, 128], "int32")
            C = R.multiply(A, B)
            D = R.zeros([64, 128], "float32")
            E = R.ones([1, 128], "float32")
            F = R.add(D, E)
            output = (C, F)
            R.output(output)
        return output

    tvm.ir.assert_structural_equal(expected, rewrite_call(pat, rewriter, before))