# Node patterns (节点模式)
import set_env
import numpy as np
import tvm
from tvm import relay
from tvm.relay.build_module import bind_params_by_name
from tvm.relay.dataflow_pattern import *
from tvm.relay.testing import run_opt_pass
# NB: these plain integers mirror values of the C++ enum that specifies an
# operator's pattern kind (the value stored in the "TOpPattern" attribute).
# We lose that enum's type safety because of the Python/C++ calling
# convention, so the raw values are spelled out here.
K_ELEMWISE, K_BROADCAST = 0, 1
# ExprPattern: `is_expr` wraps a concrete Relay expression; the wrapped
# expression is exposed again through `.expr`.
ep = is_expr(relay.var("x", shape=(4, 1)))
assert isinstance(ep, ExprPattern)
assert isinstance(ep.expr, relay.Var)
# VarPattern: `is_var` builds a variable pattern that carries the given
# name in `.name`.
v = is_var("x")
assert isinstance(v, VarPattern)
assert v.name == "x"
# ConstantPattern: `is_constant()` takes no arguments and yields a
# ConstantPattern.
c = is_constant()
assert isinstance(c, ConstantPattern)
# WildcardPattern: `wildcard()` takes no arguments and yields a
# WildcardPattern.
wc = wildcard()
assert isinstance(wc, WildcardPattern)
# CallPattern: calling the pattern returned by `is_op(...)` with argument
# patterns builds a CallPattern whose `.args` keeps those patterns in order.
wc1 = wildcard()
wc2 = wildcard()
c = is_op("add")(wc1, wc2)
assert isinstance(c, CallPattern)
assert isinstance(c.args[0], WildcardPattern)
assert isinstance(c.args[1], WildcardPattern)
# FunctionPattern: constructed directly from a list of parameter patterns
# and a body pattern, exposed as `.params` and `.body`. Here the body is a
# CallPattern over the same two wildcards used as parameters.
wc1 = wildcard()
wc2 = wildcard()
c = is_op("add")(wc1, wc2)
f = FunctionPattern([wc1, wc2], c)
assert isinstance(f, FunctionPattern)
assert isinstance(f.params[0], WildcardPattern)
assert isinstance(f.params[1], WildcardPattern)
assert isinstance(f.body, CallPattern)
assert isinstance(f.body.args[0], WildcardPattern)
assert isinstance(f.body.args[1], WildcardPattern)
# TuplePattern: `is_tuple` takes a list of field patterns and exposes them
# in order through `.fields`.
wc1 = wildcard()
wc2 = wildcard()
t = is_tuple([wc1, wc2])
assert isinstance(t, TuplePattern)
assert isinstance(t.fields[0], WildcardPattern)
assert isinstance(t.fields[1], WildcardPattern)
# TupleGetItemPattern: `is_tuple_get_item(tuple_pattern, index)` wraps a
# tuple pattern (reachable again via `.tuple`) together with an item index
# (1 here).
wc1 = wildcard()
wc2 = wildcard()
t = is_tuple([wc1, wc2])
tgi = is_tuple_get_item(t, 1)
assert isinstance(tgi, TupleGetItemPattern)
assert isinstance(tgi.tuple, TuplePattern)
assert isinstance(tgi.tuple.fields[0], WildcardPattern)
assert isinstance(tgi.tuple.fields[1], WildcardPattern)
# AltPattern: combining two patterns with `|` yields an AltPattern
# (alternative between "add" and "subtract" here).
is_add_or_sub = is_op("add") | is_op("subtract")
assert isinstance(is_add_or_sub, AltPattern)
# TypePattern: `has_type` takes a Relay type and stores it in `.type`,
# which must compare equal to the type that was passed in.
ttype = relay.TensorType((10, 10), "float32")
ty_pat = has_type(ttype)
assert isinstance(ty_pat, TypePattern)
assert ty_pat.type == ttype
# DataTypePattern: `has_dtype` takes a dtype string and stores it in
# `.dtype`, which must compare equal to that string.
dtype = "float16"
pattern = has_dtype(dtype)
assert isinstance(pattern, DataTypePattern)
assert pattern.dtype == dtype
# ShapePattern: `has_shape` takes a shape and stores it in `.shape`.
# The check uses tvm.ir.structural_equal rather than `==`, presumably because
# the stored shape is a TVM container rather than the original Python list.
shape = [10, 10]
pattern = has_shape(shape)
assert isinstance(pattern, ShapePattern)
assert tvm.ir.structural_equal(pattern.shape, shape)
# AttrPattern: `.has_attr({...})` wraps an existing pattern together with an
# attribute dict, readable back through `.attrs`. K_ELEMWISE is the
# module-level integer constant for the "TOpPattern" attribute value.
op = is_op("add").has_attr({"TOpPattern": K_ELEMWISE})
assert isinstance(op, AttrPattern)
assert op.attrs["TOpPattern"] == K_ELEMWISE
# IfPattern: `is_if(cond, true_branch, false_branch)` exposes its three
# sub-patterns as `.cond`, `.true_branch` and `.false_branch`.
x = is_var("x")
y = is_var("y")
pat = is_if(is_op("less")(x, y), x, y)
assert isinstance(pat, IfPattern)
assert isinstance(pat.cond, CallPattern)
assert isinstance(pat.true_branch, VarPattern)
assert isinstance(pat.false_branch, VarPattern)
# LetPattern: `is_let(var, value, body)` exposes its sub-patterns as
# `.var`, `.value` and `.body`. Here the body reuses the bound variable
# pattern itself.
x = is_var("x")
y = is_var("y")
let_var = is_var("let")
pat = is_let(let_var, is_op("less")(x, y), let_var)
assert isinstance(pat, LetPattern)
assert isinstance(pat.var, VarPattern)
assert isinstance(pat.value, CallPattern)
assert isinstance(pat.body, VarPattern)