DPL Pattern Matching#
%%file demo.py
import tvm
from tvm.script import relax as R
from tvm.script import tir as T
@tvm.script.ir_module
class Module:
@T.prim_func
def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
T.func_attr({"global_symbol": "tir_matmul"})
k = T.int32()
A = T.match_buffer(x, (32, 32))
B = T.match_buffer(y, (32, 32))
C = T.match_buffer(z, (32, 32))
for i0, j0, k0 in T.grid(32, 32, 32):
with T.block():
i, j, k = T.axis.remap("SSR", [i0, j0, k0])
with T.init():
C[i, j] = 0.0
C[i, j] += A[i, k] * B[j, k]
@T.prim_func
def tir_relu(x: T.handle, y: T.handle):
T.func_attr({"global_symbol": "tir_relu"})
A = T.match_buffer(x, (32, 32))
B = T.match_buffer(y, (32, 32))
for i, j in T.grid(32, 32):
with T.block():
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = T.max(A[vi, vj], 0.0)
@T.prim_func
def tir_zeros(x: T.handle, n: T.int64):
T.func_attr({"global_symbol": "tir_zeros"})
A = T.match_buffer(x, [n])
for i in range(n):
with T.block():
vi = T.axis.remap("S", [i])
A[vi] = 1.0
@R.function
def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tuple:
cls = Module
with R.dataflow():
lv0 = R.call_tir(cls.tir_matmul, (x, w), R.Tensor((32, 32), dtype="float32"))
lv1 = R.call_tir(cls.tir_relu, (lv0), R.Tensor((32, 32), dtype="float32"))
lv2 = R.call_tir(
cls.tir_zeros, [], R.Tensor((32,), dtype="float32"), tir_vars=R.ShapeExpr([32])
)
gv = (lv1, lv2)
R.output(gv)
return gv
Overwriting demo.py
from demo import Module
main_fn = Module["main"]
bindings = main_fn.body.blocks[0].bindings
Module.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
@T.prim_func
def tir_matmul(A: T.Buffer((32, 32), "float32"), B: T.Buffer((32, 32), "float32"), C: T.Buffer((32, 32), "float32")):
# with T.block("root"):
for i0, j0, k0 in T.grid(32, 32, 32):
with T.block(""):
i, j, k = T.axis.remap("SSR", [i0, j0, k0])
T.reads(A[i, k], B[j, k])
T.writes(C[i, j])
with T.init():
C[i, j] = T.float32(0.0)
C[i, j] = C[i, j] + A[i, k] * B[j, k]
@T.prim_func
def tir_relu(A: T.Buffer((32, 32), "float32"), B: T.Buffer((32, 32), "float32")):
# with T.block("root"):
for i, j in T.grid(32, 32):
with T.block(""):
vi, vj = T.axis.remap("SS", [i, j])
T.reads(A[vi, vj])
T.writes(B[vi, vj])
B[vi, vj] = T.max(A[vi, vj], T.float32(0.0))
@T.prim_func
def tir_zeros(x: T.handle, n: T.int64):
A = T.match_buffer(x, (n,))
# with T.block("root"):
for i in range(n):
with T.block(""):
vi = T.axis.spatial(n, i)
T.reads()
T.writes(A[vi])
A[vi] = T.float32(1.0)
@R.function
def main(x: R.Tensor((32, 32), dtype="float32"), w: R.Tensor((32, 32), dtype="float32")) -> R.Tuple:
cls = Module
with R.dataflow():
lv0 = R.call_tir(cls.tir_matmul, (x, w), out_sinfo=R.Tensor((32, 32), dtype="float32"))
lv1 = R.call_tir(cls.tir_relu, (lv0,), out_sinfo=R.Tensor((32, 32), dtype="float32"))
lv2 = R.call_tir(cls.tir_zeros, R.tuple(), out_sinfo=R.Tensor((32,), dtype="float32"), tir_vars=R.shape([32]))
gv: R.Tuple(R.Tensor((32, 32), dtype="float32"), R.Tensor((32,), dtype="float32")) = lv1, lv2
R.output(gv)
return gv
bindings
[x: R.Tensor((32, 32), dtype="float32")
w: R.Tensor((32, 32), dtype="float32")
lv0 = R.call_tir(tir_matmul, (x, w), out_sinfo=R.Tensor((32, 32), dtype="float32")),
lv0: R.Tensor((32, 32), dtype="float32")
lv1 = R.call_tir(tir_relu, (lv0,), out_sinfo=R.Tensor((32, 32), dtype="float32")),
lv2 = R.call_tir(tir_zeros, R.tuple(), out_sinfo=R.Tensor((32,), dtype="float32"), tir_vars=R.shape([32])),
lv1: R.Tensor((32, 32), dtype="float32")
lv2: R.Tensor((32,), dtype="float32")
gv: R.Tuple(R.Tensor((32, 32), dtype="float32"), R.Tensor((32,), dtype="float32")) = lv1, lv2]
from tvm.script import relax as R
from tvm.script import tir as T
from tvm import relax as rx
from tvm import relay, tir
from tvm.relax.analysis import get_var2val
from tvm.relax.dpl import *
Node-level Matching#
Test the expression pattern:
ep = is_expr(rx.Var("x"))
assert isinstance(ep, ExprPattern)
assert isinstance(ep.expr, rx.Var)
Test variable patterns:
v = is_var("x")
assert isinstance(v, VarPattern)
assert v.name == "x"
assert v.match(rx.Var("x"))
assert is_var().match(rx.Var("x"))
assert is_var().match(rx.DataflowVar("x")) # a DataflowVar is also a Var
assert not v.match(rx.GlobalVar("x"))
v = is_dfv("x")
assert isinstance(v, DataflowVarPattern)
assert v.name == "x"
assert v.match(rx.DataflowVar("x"))
assert not v.match(rx.GlobalVar("x"))
assert is_dfv().match(bindings[0].var)
assert is_gv("x").match(rx.GlobalVar("x"))
# TODO: the regex functionality is temporarily disabled due to a symbol conflict with PyTorch
# assert is_gv("x.*").match(rx.GlobalVar("x_2"))
assert is_gv().match(rx.GlobalVar("x"))
assert not is_gv("x").match(rx.GlobalVar("y"))
assert not is_gv("x").match(rx.Var("x"))
Match constants:
c = is_const()
assert isinstance(c, ConstantPattern)
assert c.match(rx.const([[0.1, 1.1, 2.1], [3.1, 4.1, 5.1]]))
Fuzzy (wildcard) matching:
wc = wildcard()
assert isinstance(wc, WildcardPattern)
assert wc.match(rx.Var("x"))
Match call nodes:
wc1 = wildcard()
wc2 = wildcard()
c = is_op("relax.add")(wc1, wc2)
assert isinstance(c, CallPattern)
assert isinstance(c.args[0], WildcardPattern)
assert isinstance(c.args[1], WildcardPattern)
assert c.match(rx.op.add(rx.Var("x"), rx.Var("y")))
Match functions:
wc1 = wildcard()
wc2 = wildcard()
f = FunctionPattern([wc1, wc2], is_op("relax.add")(wc1, wc2))
assert isinstance(f, FunctionPattern)
assert isinstance(f.params[0], WildcardPattern)
assert isinstance(f.params[1], WildcardPattern)
assert isinstance(f.body, CallPattern)
assert isinstance(f.body.args[0], WildcardPattern)
assert isinstance(f.body.args[1], WildcardPattern)
x = rx.Var("x", R.Tensor("float32"))
y = rx.Var("y", R.Tensor("float32"))
assert f.match(rx.Function([x, y], rx.op.add(x, y), ret_struct_info=R.Tensor("float32")))
assert not f.match(
rx.Function([x, y], rx.op.multiply(x, y), ret_struct_info=R.Tensor("float32"))
)
Tuple matching:
wc1 = wildcard()
wc2 = is_dfv()
t = is_tuple([wc1, wc2])
assert isinstance(t, TuplePattern)
assert isinstance(t.fields[0], WildcardPattern)
assert isinstance(t.fields[1], DataflowVarPattern)
assert t.match(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]))
assert not t.match(rx.Tuple([rx.DataflowVar("x"), rx.GlobalVar("y")]))
assert not t.match(rx.Tuple([]))
assert t[0].match(rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), 0))
assert t[1].match(rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), 1))
# Negative index is also allowed
assert t[-1].match(rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), 1))
# None means any index.
assert t[None].match(rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), 0))
assert t[None].match(rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), 1))
import pytest
with pytest.raises(IndexError):
t[2] # index cannot be greater than or equal to the tuple size.
t = is_tuple([is_const(), is_dfv()], unordered=True)
assert isinstance(t, UnorderedTuplePattern)
assert isinstance(t.fields[0], ConstantPattern)
assert isinstance(t.fields[1], DataflowVarPattern)
assert t.match(rx.Tuple([rx.const([]), rx.DataflowVar("x")]))
assert t.match(rx.Tuple([rx.DataflowVar("x"), rx.const([])]))
assert not t.match(rx.Tuple([rx.DataflowVar("x"), rx.DataflowVar("y")]))
assert not t.match(rx.Tuple([]))
assert is_tuple_get_item(is_tuple([is_gv("x"), is_dfv("y")]), 0).match(
rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), 0)
)
assert is_tuple_get_item(is_tuple([is_gv("x"), is_dfv("y")]), 0).match(
rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), 0)
)
Matching or-patterns:
dfv_or_gv = is_dfv("x") | is_gv("x")
assert isinstance(dfv_or_gv, OrPattern)
assert dfv_or_gv.match(rx.DataflowVar("x"))
assert dfv_or_gv.match(rx.GlobalVar("x"))
assert not dfv_or_gv.match(rx.Var("x"))
assert not dfv_or_gv.match(rx.DataflowVar("y"))
assert not dfv_or_gv.match(rx.GlobalVar("y"))
Matching and-patterns:
# float[2, 3, 3]
f32_233 = wildcard().has_shape((2, 3, 3)) & has_dtype("float32")
assert isinstance(f32_233, AndPattern)
assert f32_233.match(rx.Var("x", R.Tensor((2, 3, 3), "float32")))
assert not f32_233.match(rx.Var("x", R.Tensor((3, 3, 3), "float32")))
assert not f32_233.match(rx.Var("x", R.Tensor("float32", ndim=3)))
Matching not-patterns:
no_shape233 = ~wildcard().has_shape((2, 3, 3))
assert isinstance(no_shape233, NotPattern)
assert no_shape233.match(rx.Var("x", R.Tensor((3, 3, 3), "float32")))
assert not no_shape233.match(rx.Var("x", R.Tensor((2, 3, 3), "float32")))
Matching by type:
assert wildcard().has_type(rx.DynTensorType(2, "float32")).match(bindings[0].var)
Matching by dtype:
dtype = "float16"
pattern = has_dtype(dtype)
assert isinstance(pattern, DataTypePattern)
assert pattern.dtype == dtype
assert has_dtype("float32").match(bindings[0].var)
Matching by shape:
shape = [32, 32]
pattern = wildcard().has_shape(shape)
assert isinstance(pattern, ShapePattern)
tvm.ir.structural_equal(pattern.shape, shape)
assert pattern.match(bindings[0].var)
assert wildcard().has_shape([32, 32]).match(bindings[0].var)
n, m = tir.Var("n", dtype="int64"), tir.Var("m", dtype="int64")
symsh_var = rx.Var("x", R.Tensor([n, m, n + m], "float32"))
assert wildcard().has_shape([n, m, n + m]).match(symsh_var)
assert wildcard().has_shape([n, m, m + n]).match(symsh_var) # + is commutative.
assert not wildcard().has_shape([1, 2, 3]).match(symsh_var)
assert not wildcard().has_shape([m, n, n + m]).match(symsh_var)
Matching a PrimArray:
Note
The difference between is_shape and has_shape is that is_shape matches a shape directly (for example, a shape passed as an argument), while has_shape matches a tensor and makes an assertion about that tensor's shape.
pattern = is_shape([32, 32])
assert pattern[0] == 32
assert pattern[1] == 32
assert isinstance(pattern, PrimArrPattern)
assert pattern.match(rx.get_shape_of(bindings[0].var))
n, m = tir.Var("n", dtype="int64"), tir.Var("m", dtype="int64")
symbolic_shape = rx.ShapeExpr([n, m, n + m])
assert is_shape([n, m, n + m]).match(symbolic_shape)
assert not is_shape([n, m, n * m]).match(symbolic_shape)
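To make the distinction concrete, here is a minimal sketch (reusing the bindings loaded from demo.py above) that applies both kinds of patterns around the same tensor variable:
shape_value = rx.get_shape_of(bindings[0].var)  # the shape expression [32, 32] of lv0
assert is_shape([32, 32]).match(shape_value)  # is_shape matches the shape value itself
assert wildcard().has_shape([32, 32]).match(bindings[0].var)  # has_shape constrains the tensor's shape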
Match external functions:
pattern = ExternFuncPattern("test.blockbuilder.nop")
assert pattern.match(rx.ExternFunc("test.blockbuilder.nop"))
Match operator attributes:
x = rx.Var("x", R.Tensor("float32"))
y = rx.Var("y", R.Tensor("float32"))
conv2d = relay.nn.conv2d(x, y, kernel_size=(3, 3))
xp = is_var("x")
yp = is_var("y")
# TODO(@yuchen): reenable the assert after figuring out why it fails
# assert is_op("nn.conv2d")(xp, yp).has_attr({"kernel_size": [3, 3]}).match(conv2d)
assert not is_op("nn.conv2d")(xp, yp).has_attr({"kernel_size": [4, 3]}).match(conv2d)
assert not is_op("nn.conv2d")(xp, yp).has_attr({"kernel_size_": [3, 3]}).match(conv2d)
Match call attributes:
x = rx.Var("x", R.Tensor("float32"))
y = rx.Var("y", R.Tensor("float32"))
fn = rx.Function([x, y], rx.op.add(x, y), ret_struct_info=R.Tensor("float32"))
annotated_fn = fn.with_attr({"Codegen": "test-codegen", "global_symbol": "test-symbol"})
xp = is_var("x")
yp = is_var("y")
root_pattern = FunctionPattern([xp, yp], is_op("relax.add")(xp, yp))
assert root_pattern.has_attr({"Codegen": "test-codegen", "global_symbol": "test-symbol"}).match(
annotated_fn
)
assert root_pattern.has_attr({"Codegen": "test-codegen"}).match(annotated_fn)
assert not root_pattern.has_attr({"ping": "pong"}).match(annotated_fn)
assert root_pattern.has_attr({}).match(annotated_fn)
Matching with is_call_tir:
lv1_val = bindings[1].value
lv2_val = bindings[2].value
var2val = get_var2val(Module["main"])
assert is_call_tir("tir_relu").match(lv1_val)
assert is_call_tir("tir_relu", [is_call_tir("tir_matmul")]).match(lv1_val, var2val=var2val)
assert not is_call_tir("tir_relu", [is_call_tir("tir_relu")]).match(lv1_val, var2val=var2val)
assert is_call_tir("tir_zeros", wildcard(), wildcard()).match(lv2_val, var2val=var2val)
Matching with call_packed:
@R.function(pure=False)
def simple_call_packed(
x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")
) -> R.Tensor:
gv0 = R.call_packed("test.vm.mul", x, w, sinfo_args=(R.Tensor(ndim=2, dtype="float32")))
return gv0
def test_varg_default_wildcard():
expr = simple_call_packed.body.blocks[0].bindings[0].value
yes_pattern_explicit = ExternFuncPattern("test.vm.mul")(wildcard(), wildcard())
yes_pattern_implicit = ExternFuncPattern("test.vm.mul")(varg_default_wildcard=True)
no_pattern = ExternFuncPattern("test.vm.mul")(wildcard())
assert yes_pattern_explicit.match(expr)
assert yes_pattern_implicit.match(expr)
assert not no_pattern.match(expr)
def test_simple_call_packed():
expr = simple_call_packed.body.blocks[0].bindings[0].value
assert is_call_packed("test.vm.mul").match(expr)
assert is_call_packed("test.vm.mul", [is_var("x"), is_var("w")]).match(expr)
Graph-level Matching#
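Inside a PatternContext, graph-level constraints are expressed as edges between node patterns: n0 ^ n1 (equivalently n0.used_by(n1)) requires that n0's output is used by n1, while n0 >> n1 (equivalently n0.only_used_by(n1)) requires that it is used only by n1; fork_to expresses that one producer feeds several consumers. ctx.match_dfb(dfb) then checks all constraints registered in the context against a dataflow block.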
def test_simple_used_by():
with PatternContext() as ctx:
n0 = is_var("x") # x is a free var (fn arg)
n1 = wildcard()
n0 ^ n1
dfb = main_fn.body.blocks[0]
matched = ctx.match_dfb(dfb)
assert matched
assert matched[n0] == main_fn.params[0]
assert matched[n1] == dfb.bindings[0].var
def test_simple_call_tir_edge():
with PatternContext() as ctx:
n0 = is_call_tir("tir_matmul")
n1 = is_call_tir("tir_relu")
n0.used_by(n1)
dfb = main_fn.body.blocks[0]
matched = ctx.match_dfb(dfb)
assert matched
assert matched[n0] == dfb.bindings[0].var
assert matched[n1] == dfb.bindings[1].var
def test_simple_oub():
with PatternContext() as ctx:
n0 = is_call_tir("tir_matmul")
n1 = is_call_tir("tir_relu")
n0 >> n1
dfb = main_fn.body.blocks[0]
matched = ctx.match_dfb(dfb)
assert matched
assert matched[n0] == dfb.bindings[0].var
assert matched[n1] == dfb.bindings[1].var
def test_counter_syntax_match():
with PatternContext() as ctx:
n0 = is_call_dps_packed("extern_matmul")
n1 = is_call_dps_packed("extern_impossible")
n0 >> n1
dfb = main_fn.body.blocks[0]
assert not ctx.match_dfb(dfb)
with PatternContext() as ctx:
n0 = is_call_dps_packed("extern_matmul")
n1 = is_call_dps_packed("extern_impossible")
n0 ^ n1
dfb = main_fn.body.blocks[0]
assert not ctx.match_dfb(dfb)
@tvm.script.ir_module
class Diamond:
@R.function
def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor:
with R.dataflow():
# matmul
# / \
# relu sigmoid
# \ /
# add
lv0 = R.call_dps_packed("extern_matmul", (x, w), R.Tensor((32, 32), dtype="float32"))
lv1 = R.call_dps_packed("extern_relu", (lv0,), R.Tensor((32, 32), dtype="float32"))
lv2 = R.call_dps_packed("extern_sigmoid", (lv0), R.Tensor((32, 32), dtype="float32"))
lv3 = R.call_dps_packed("extern_add", (lv1, lv2), R.Tensor((32, 32), dtype="float32"))
R.output(lv3)
return lv3
def test_diamond():
with PatternContext() as ctx:
n0 = is_call_dps_packed("extern_matmul")
n1 = is_call_dps_packed("extern_relu")
n2 = is_call_dps_packed("extern_sigmoid")
n3 = is_call_dps_packed("extern_add")
n0 ^ n1
n0 ^ n2
n1 >> n3
n2 >> n3
dfb = Diamond["main"].body.blocks[0]
assert ctx.match_dfb(dfb)
# simplify it with fork_to
with PatternContext() as ctx:
n1 = is_call_dps_packed("extern_relu")
n2 = is_call_dps_packed("extern_sigmoid")
n3 = is_call_dps_packed("extern_add")
is_call_dps_packed("extern_matmul").fork_to(n1, n2)
n1 >> n3
n2 >> n3
dfb = Diamond["main"].body.blocks[0]
assert ctx.match_dfb(dfb)
def test_diamond_counter_oub():
with PatternContext() as ctx:
n0 = is_call_dps_packed("extern_matmul")
n1 = is_call_dps_packed("extern_relu")
n2 = is_call_dps_packed("extern_sigmoid")
n3 = is_call_dps_packed("extern_add")
n0 >> n1
n0 >> n2
n1 >> n3
n2 >> n3
dfb = Diamond["main"].body.blocks[0]
assert not ctx.match_dfb(dfb)
@tvm.script.ir_module
class SmallDiamond:
@R.function
def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
with R.dataflow():
# relu
# / \
# \ /
# add
lv0 = R.call_dps_packed("my_relu", (x,), R.Tensor((32, 32), dtype="float32"))
lv1 = R.call_dps_packed("my_add", (lv0, lv0), R.Tensor((32, 32), dtype="float32"))
R.output(lv1)
return lv1
@tvm.script.ir_module
class SmallParallel:
@R.function
def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
with R.dataflow():
# relu relu
# \ /
# add
lv0 = R.call_dps_packed("my_relu", (x,), R.Tensor((32, 32), dtype="float32"))
lv1 = R.call_dps_packed("my_relu", (x,), R.Tensor((32, 32), dtype="float32"))
lv2 = R.call_dps_packed("my_add", (lv0, lv1), R.Tensor((32, 32), dtype="float32"))
R.output(lv2)
return lv2
def test_distinguish_diamond_and_parallel():
# relay pattern lang cannot distinguish the two cases above.
diamond = SmallDiamond["main"].body.blocks[0]
parallel = SmallParallel["main"].body.blocks[0]
with PatternContext() as ctx:
# describe a diamond pattern
fork = is_call_dps_packed("my_relu")
join = is_call_dps_packed("my_add")
fork.only_used_by(join, index=0)
fork.only_used_by(join, index=1)
assert ctx.match_dfb(diamond)
assert not ctx.match_dfb(parallel)
with PatternContext() as ctx:
# describe a parallel pattern
join = is_call_dps_packed("my_add")
# Due to one-to-one matching:
# the first is_call_dps_packed("my_relu") creates one relu pattern,
is_call_dps_packed("my_relu") >> join
# and the second is_call_dps_packed("my_relu") creates another, distinct
# relu pattern (a different object), so it must match a different relu node
is_call_dps_packed("my_relu") >> join
assert ctx.match_dfb(parallel)
assert not ctx.match_dfb(diamond)
@tvm.script.ir_module
class CBRx2:
@R.function
def main(
x: R.Tensor((32, 32), "float32"),
w0: R.Tensor((1, 1), "float32"),
bias0: R.Tensor((32, 32), "float32"),
w1: R.Tensor((1, 1), "float32"),
bias1: R.Tensor((32, 32), "float32"),
) -> R.Tensor:
# TensorRT's CBR optimization pattern
# input
# / \
# cbr0 cbr1
# \ /
# concat
with R.dataflow():
lv0 = R.call_dps_packed("conv1x1", (x, w0), R.Tensor((32, 32), dtype="float32"))
lv1 = R.call_dps_packed("bias_add", (lv0, bias0), R.Tensor((32, 32), dtype="float32"))
lv2 = R.call_dps_packed("my_relu", (lv1), R.Tensor((32, 32), dtype="float32"))
lv3 = R.call_dps_packed("conv1x1", (x, w1), R.Tensor((32, 32), dtype="float32"))
lv4 = R.call_dps_packed("bias_add", (lv3, bias1), R.Tensor((32, 32), dtype="float32"))
lv5 = R.call_dps_packed("my_relu", (lv4), R.Tensor((32, 32), dtype="float32"))
lv6 = R.call_dps_packed("concat", (lv2, lv5), R.Tensor((32, 64), dtype="float32"))
R.output(lv6)
return lv6
def test_nested_context():
dfb = CBRx2["main"].body.blocks[0]
with PatternContext() as ctx0:
(
is_call_dps_packed("conv1x1")
>> is_call_dps_packed("bias_add")
>> is_call_dps_packed("my_relu")
)
with PatternContext() as ctx1:
is_call_dps_packed("conv1x1") >> is_call_dps_packed("my_relu") # pattern to miss
with PatternContext() as ctx2:
is_call_dps_packed("bias_add") >> is_call_dps_packed("my_relu")
assert ctx2.match_dfb(dfb)
assert PatternContext.current() == ctx2
assert not ctx1.match_dfb(dfb)
assert PatternContext.current() == ctx1
assert ctx0.match_dfb(dfb)
assert PatternContext.current() == ctx0
def test_two_cbr():
with PatternContext() as ctx:
cbr0 = (
is_call_dps_packed("conv1x1")
>> is_call_dps_packed("bias_add")
>> is_call_dps_packed("my_relu")
)
cbr1 = cbr0.dup()
assert cbr0.patterns[0] != cbr1.patterns[0]
assert cbr0.patterns[1] != cbr1.patterns[1]
assert cbr0.patterns[2] != cbr1.patterns[2]
is_var("x").fork_to(cbr0, cbr1)
dfb = CBRx2["main"].body.blocks[0]
assert ctx.match_dfb(dfb)
with PatternContext() as ctx:
# Deny the pattern
cbr0 = (
is_call_dps_packed("conv1x1")
>> is_call_dps_packed("bias_add")
>> is_call_dps_packed("my_relu")
)
cbr1 = cbr0.dup()
# input has no fork at y.
is_var("y").fork_to(cbr0, cbr1)
dfb = CBRx2["main"].body.blocks[0]
assert not ctx.match_dfb(dfb)
def test_two_matmul():
# Same as Figure 2(a) in TASO paper.
@tvm.script.ir_module
class MatMul2:
@R.function
def main(
a: R.Tensor((32, 16), "float32"),
b: R.Tensor((16, 48), "float32"),
c: R.Tensor((48, 32), "float32"),
) -> R.Tensor:
with R.dataflow():
lv0 = R.call_dps_packed("matmul", (a, b), R.Tensor((32, 48), dtype="float32"))
lv1 = R.call_dps_packed("matmul", (lv0, c), R.Tensor((32, 32), dtype="float32"))
R.output(lv1)
return lv1
with PatternContext() as ctx:
is_call_dps_packed("matmul") >> is_call_dps_packed("matmul")
dfb = MatMul2["main"].body.blocks[0]
assert ctx.match_dfb(dfb)
with PatternContext() as ctx:
is_call_dps_packed("matmul").has_shape([32, 48]) >> is_call_dps_packed("matmul").has_shape(
[32, 32]
)
dfb = MatMul2["main"].body.blocks[0]
assert ctx.match_dfb(dfb)
with PatternContext() as ctx:
is_call_dps_packed("matmul") >> is_call_dps_packed("matmul") >> is_call_dps_packed("matmul")
dfb = MatMul2["main"].body.blocks[0]
# A chain of three matmuls cannot be matched (the block only contains two)
assert not ctx.match_dfb(dfb)
def test_concat_mm_split():
# Same as Figure 2(b) in TASO paper.
@tvm.script.ir_module
class CMS:
@R.function
def main(
a: R.Tensor((32, 32), "float32"),
b: R.Tensor((16, 32), "float32"),
c: R.Tensor((16, 32), "float32"),
) -> R.Tensor:
with R.dataflow():
lv0 = R.call_dps_packed("my_concat", (b, c), R.Tensor((32, 32), dtype="float32"))
lv1 = R.call_dps_packed("my_matmul", (a, lv0), R.Tensor((32, 32), dtype="float32"))
lv2 = R.call_dps_packed(
"my_split",
(lv1,),
[R.Tensor((16, 32), dtype="float32"), R.Tensor((16, 32), dtype="float32")],
)
lv3 = R.TupleGetItem(lv2, 0)
lv4 = R.TupleGetItem(lv2, 1)
lv5 = R.add(lv3, lv4)
R.output(lv5)
return lv5
with PatternContext() as ctx:
(
is_call_dps_packed("my_concat")
>> is_call_dps_packed("my_matmul")
>> is_call_dps_packed("my_split")
)
dfb = CMS["main"].body.blocks[0]
assert ctx.match_dfb(dfb)
with PatternContext() as ctx:
split = is_call_dps_packed("my_split")
lv3 = TupleGetItemPattern(split, 0).has_shape([16, 32])
lv4 = TupleGetItemPattern(split, 1).has_shape([16, 32])
split.fork_to(lv3, lv4)
add = is_op("relax.add")(lv3, lv4)
# TODO(@ganler): simplify this through implicit graph pattern.
lv3 >> add
lv4 >> add
dfb = CMS["main"].body.blocks[0]
assert ctx.match_dfb(dfb)
def test_self_attention():
# The example comes from:
# https://developer.nvidia.com/blog/nlu-with-tensorrt-bert/
@tvm.script.ir_module
class SelfAttention:
@R.function
def main(
x: R.Tensor(("b", "s", "n", "h"), "float32"),
wq: R.Tensor(("h", "h"), "float32"),
wk: R.Tensor(("h", "h"), "float32"),
wv: R.Tensor(("h", "h"), "float32"),
) -> R.Tensor:
b, s, n, h = T.int64(), T.int64(), T.int64(), T.int64()
with R.dataflow():
fcq = R.call_dps_packed("my_fc", (x, wq), R.Tensor((b, s, n, h), dtype="float32"))
tpq = R.call_dps_packed(
"my_transpose", (fcq,), R.Tensor((b, s, h, n), dtype="float32")
)
fck = R.call_dps_packed("my_fc", (x, wk), R.Tensor((b, s, n, h), dtype="float32"))
tpk = R.call_dps_packed(
"my_transpose", (fck,), R.Tensor((b, s, h, n), dtype="float32")
)
mul = R.multiply(tpq, tpk)
scale = R.multiply(mul, R.const(1.1, "float32"))
softmax = R.call_dps_packed(
"softmax", (scale,), R.Tensor((b, s, n, h), dtype="float32")
)
fcv = R.call_dps_packed("my_fc", (x, wv), R.Tensor((b, s, n, h), dtype="float32"))
tpv = R.call_dps_packed(
"my_transpose", (fcv,), R.Tensor((b, s, h, n), dtype="float32")
)
out = R.multiply(softmax, tpv)
R.output(out)
return out
with PatternContext() as ctx:
fc_trans_q = is_call_dps_packed("my_fc") >> is_call_dps_packed("my_transpose")
fc_trans_k = fc_trans_q.dup()
fc_trans_v = fc_trans_q.dup()
is_var("x").fork_to(fc_trans_q, fc_trans_k, fc_trans_v)
dfb = SelfAttention["main"].body.blocks[0]
assert ctx.match_dfb(dfb)
def test_nested_diamond():
@tvm.script.ir_module
class DiamondInDiamond:
@R.function
def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor:
with R.dataflow():
# matmul0 matmul1
# / \ / \
# sigmoid2 add4 sigmoid3
# \ / \ /
# add5 add6
# \ /
# add7
lv0 = R.call_dps_packed(
"extern_matmul", (x, w), R.Tensor((32, 32), dtype="float32")
)
lv1 = R.call_dps_packed(
"extern_matmul", (x, w), R.Tensor((32, 32), dtype="float32")
)
lv2 = R.call_dps_packed(
"extern_sigmoid", (lv0), R.Tensor((32, 32), dtype="float32")
)
lv3 = R.call_dps_packed(
"extern_sigmoid", (lv1), R.Tensor((32, 32), dtype="float32")
)
lv4 = R.call_dps_packed(
"extern_add", (lv0, lv1), R.Tensor((32, 32), dtype="float32")
)
lv5 = R.call_dps_packed(
"extern_add", (lv2, lv4), R.Tensor((32, 32), dtype="float32")
)
lv6 = R.call_dps_packed(
"extern_add", (lv3, lv4), R.Tensor((32, 32), dtype="float32")
)
lv7 = R.call_dps_packed(
"extern_add", (lv5, lv6), R.Tensor((32, 32), dtype="float32")
)
R.output(lv7)
return lv7
# match matmul0 diamond
with PatternContext() as ctx:
sigmoid2 = is_call_dps_packed("extern_sigmoid")
add4 = is_call_dps_packed("extern_add")
is_call_dps_packed("extern_matmul").fork_to(sigmoid2, add4)
add5 = is_call_dps_packed("extern_add")
sigmoid2 >> add5
add4 ^ add5
assert ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0])
# counter case: mis-match matmul0 diamond
with PatternContext() as ctx:
sigmoid2 = is_call_dps_packed("extern_sigmoid")
add4 = is_call_dps_packed("extern_add")
is_call_dps_packed("extern_matmul").fork_to(sigmoid2, add4)
add5 = is_call_dps_packed("extern_add")
sigmoid2 >> add5
add4 >> add5 # not only-used-by relation
assert not ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0])
# match matmul1 diamond
with PatternContext() as ctx:
sigmoid3 = is_call_dps_packed("extern_sigmoid")
add4 = is_call_dps_packed("extern_add")
is_call_dps_packed("extern_matmul").fork_to(sigmoid3, add4)
add6 = is_call_dps_packed("extern_add")
sigmoid3 >> add6
add4 ^ add6
assert ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0])
# match add-4-5-6-7
with PatternContext() as ctx:
add5, add6, add7 = (
is_call_dps_packed("extern_add"),
is_call_dps_packed("extern_add"),
is_call_dps_packed("extern_add"),
)
is_call_dps_packed("extern_add").fork_to(add5, add6) # add4
add5 >> add7
add6 >> add7
assert ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0])
def test_incremental_solving():
@R.function
def simple_chain(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
with R.dataflow():
# relu -> sigmoid -> neg
lv0 = R.call_dps_packed("extern_relu", (x), R.Tensor((32, 32), dtype="float32"))
lv1 = R.call_dps_packed("extern_sigmoid", (lv0), R.Tensor((32, 32), dtype="float32"))
lv2 = R.call_dps_packed("extern_neg", (lv1), R.Tensor((32, 32), dtype="float32"))
R.output(lv2)
return lv2
relu = is_call_dps_packed("extern_relu")
sigmoid = is_call_dps_packed("extern_sigmoid")
neg = is_call_dps_packed("extern_neg")
with PatternContext() as ctx0:
relu >> sigmoid
with PatternContext(incremental=True) as ctx1:
# Because we are doing incremental solving,
# relu >> sigmoid is still a constraint in this context.
# That is, the total constraint is:
# relu >> sigmoid >> neg
sigmoid >> neg
assert ctx1.match_dfb(simple_chain.body.blocks[0])
# match relu -> sigmoid
assert ctx0.match_dfb(simple_chain.body.blocks[0])
def test_incremental_solving_counter():
@R.function
def simple_chain(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
with R.dataflow():
# sigmoid -> neg
lv0 = R.call_dps_packed("extern_sigmoid", (x), R.Tensor((32, 32), dtype="float32"))
lv1 = R.call_dps_packed("extern_neg", (lv0), R.Tensor((32, 32), dtype="float32"))
R.output(lv1)
return lv1
relu = is_call_dps_packed("extern_relu")
sigmoid = is_call_dps_packed("extern_sigmoid")
neg = is_call_dps_packed("extern_neg")
with PatternContext() as ctx0:
relu >> sigmoid # cannot match
with PatternContext(incremental=False) as ctx1:
# total constraint: sigmoid >> neg
sigmoid >> neg
assert ctx1.match_dfb(simple_chain.body.blocks[0])
with PatternContext(incremental=True) as ctx1:
# total constraint: relu >> sigmoid >> neg
sigmoid >> neg
assert not ctx1.match_dfb(simple_chain.body.blocks[0])