# 节点模式 (Node Patterns)

import set_env
import numpy as np

import tvm
from tvm import relay
from tvm.relay.build_module import bind_params_by_name
from tvm.relay.dataflow_pattern import *
from tvm.relay.testing import run_opt_pass

# NB: these integer values correspond to the C++ enum that specifies them.
# We lose type safety here because of the Python/C++ calling convention,
# so they are spelled out as plain ints. K_ELEMWISE is used in this file
# as a `TOpPattern` attribute value.
K_ELEMWISE = 0
K_BROADCAST = 1

# ExprPattern

# is_expr wraps a concrete Relay expression in an ExprPattern; the wrapped
# expression (here a relay.Var) is exposed through the `expr` field.
ep = is_expr(relay.var("x", shape=(4, 1)))
assert isinstance(ep, ExprPattern) and isinstance(ep.expr, relay.Var)

# VarPattern

# is_var builds a VarPattern and records the requested variable name on it.
v = is_var("x")
assert isinstance(v, VarPattern) and v.name == "x"

# ConstantPattern

# is_constant takes no arguments here and yields a ConstantPattern.
c = is_constant()
assert isinstance(c, ConstantPattern)

# WildcardPattern

# wildcard() yields a WildcardPattern.
wc = wildcard()
assert isinstance(wc, WildcardPattern)

# CallPattern

# Applying an op pattern to argument patterns produces a CallPattern; the
# two wildcard arguments are reachable through its `args` field.
wc1, wc2 = wildcard(), wildcard()
c = is_op("add")(wc1, wc2)
assert isinstance(c, CallPattern)
assert all(isinstance(c.args[i], WildcardPattern) for i in (0, 1))

# FunctionPattern

# A FunctionPattern is constructed from a list of parameter patterns and a
# body pattern; both are exposed afterwards as `params` and `body`.
wc1, wc2 = wildcard(), wildcard()
c = is_op("add")(wc1, wc2)
f = FunctionPattern([wc1, wc2], c)
assert isinstance(f, FunctionPattern)
assert all(isinstance(f.params[i], WildcardPattern) for i in (0, 1))
# The body is the `add` call pattern, itself taking the two wildcards.
assert isinstance(f.body, CallPattern)
assert all(isinstance(f.body.args[i], WildcardPattern) for i in (0, 1))

# TuplePattern

# is_tuple turns a list of field patterns into a TuplePattern whose
# `fields` hold the wildcard sub-patterns.
wc1, wc2 = wildcard(), wildcard()
t = is_tuple([wc1, wc2])
assert isinstance(t, TuplePattern)
assert all(isinstance(t.fields[i], WildcardPattern) for i in (0, 1))

# TupleGetItemPattern

# is_tuple_get_item wraps a tuple pattern together with an index; the
# wrapped tuple pattern stays accessible through the `tuple` field.
wc1, wc2 = wildcard(), wildcard()
t = is_tuple([wc1, wc2])
tgi = is_tuple_get_item(t, 1)
assert isinstance(tgi, TupleGetItemPattern)
assert isinstance(tgi.tuple, TuplePattern)
assert all(isinstance(tgi.tuple.fields[i], WildcardPattern) for i in (0, 1))

# AltPattern

# Combining two patterns with `|` yields an AltPattern.
is_add_or_sub = is_op("add") | is_op("subtract")
assert isinstance(is_add_or_sub, AltPattern)

# TypePattern

# has_type builds a TypePattern that stores the requested type in `type`.
ttype = relay.TensorType((10, 10), "float32")
ty_pat = has_type(ttype)
assert isinstance(ty_pat, TypePattern) and ty_pat.type == ttype

# DataTypePattern

# has_dtype builds a DataTypePattern that stores the requested dtype string.
dtype = "float16"
pattern = has_dtype(dtype)
assert isinstance(pattern, DataTypePattern) and pattern.dtype == dtype

# ShapePattern

# has_shape builds a ShapePattern; its stored shape is compared with the
# input list via structural equality rather than `==`.
shape = [10, 10]
pattern = has_shape(shape)
assert isinstance(pattern, ShapePattern) and tvm.ir.structural_equal(
    pattern.shape, shape
)

# AttrPattern

# Calling has_attr on an op pattern yields an AttrPattern whose `attrs`
# mapping carries the requested attribute values.
op = is_op("add").has_attr({"TOpPattern": K_ELEMWISE})
assert isinstance(op, AttrPattern) and op.attrs["TOpPattern"] == K_ELEMWISE

# IfPattern

# is_if takes a condition pattern and two branch patterns; the resulting
# IfPattern exposes them as `cond`, `true_branch` and `false_branch`.
x, y = is_var("x"), is_var("y")
pat = is_if(is_op("less")(x, y), x, y)

assert isinstance(pat, IfPattern)
assert isinstance(pat.cond, CallPattern)
assert all(
    isinstance(branch, VarPattern)
    for branch in (pat.true_branch, pat.false_branch)
)

# LetPattern

# is_let takes a bound-variable pattern, a value pattern and a body pattern;
# the resulting LetPattern exposes them as `var`, `value` and `body`.
x, y = is_var("x"), is_var("y")
let_var = is_var("let")
pat = is_let(let_var, is_op("less")(x, y), let_var)

assert isinstance(pat, LetPattern)
# Check each field against its expected pattern class, in var/value/body order.
assert all(
    isinstance(getattr(pat, field), kind)
    for field, kind in (
        ("var", VarPattern),
        ("value", CallPattern),
        ("body", VarPattern),
    )
)