Pattern Matching#

import set_env  # notebook-local environment setup
import numpy as np

import tvm
from tvm import relay
from tvm.relay.build_module import bind_params_by_name
from tvm.relay.dataflow_pattern import *
from tvm.relay.testing import run_opt_pass

# NB: these values mirror the C++ OpPatternKind enum;
# we lose the type safety due to the Python/C++ calling
# convention.
K_ELEMWISE = 0
K_BROADCAST = 1

Matching Operators#

assert is_op("add").match(relay.op.op.get("add"))
assert not is_op("add").match(relay.op.op.get("subtract"))
is_add_or_sub = is_op("add") | is_op("subtract")
assert is_add_or_sub.match(relay.op.op.get("add"))
assert is_add_or_sub.match(relay.op.op.get("subtract"))
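
A pattern can also pin down one exact expression via structural equality. As a minimal sketch, is_expr matching the very same variable node:

x = relay.var("x", shape=(4, 1))
assert is_expr(x).match(x)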

Matching Calls#

x = relay.var("x")
y = relay.var("y")
# add and multiply are commutative, so both argument orders match
add_pattern = is_op("add")(is_var("x"), is_var("y"))
assert add_pattern.match(x + y)
assert add_pattern.match(y + x)
mul_pattern = is_op("multiply")(is_var("x"), is_var("y"))
assert mul_pattern.match(x * y)
assert mul_pattern.match(y * x)
x = relay.var("x")
y = relay.var("y")
# subtract and divide are not commutative, so argument order matters
sub_pattern = is_op("subtract")(is_var("x"), is_var("y"))
assert sub_pattern.match(x - y)
assert not sub_pattern.match(y - x)
div_pattern = is_op("divide")(is_var("x"), is_var("y"))
assert div_pattern.match(x / y)
assert not div_pattern.match(y / x)
x = relay.var("x")
y = relay.var("y")
add_pattern = is_op("add")(wildcard(), wildcard())
assert add_pattern.match(x + y)

# Match call with any number of inputs
call_pattern = wildcard()(None)
assert call_pattern.match(relay.op.nn.relu(x))
assert call_pattern.match(relay.op.add(x, y))
x = relay.var("x")
y = relay.var("y")
add_pattern = is_op("add")(wildcard(), wildcard())
assert not add_pattern.match(x - y)
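
A call pattern need not name the operator; as a sketch, a wildcard call that fixes the arity at exactly two arguments:

two_arg_call = wildcard()(wildcard(), wildcard())
assert two_arg_call.match(relay.op.add(x, y))
assert not two_arg_call.match(relay.op.nn.relu(x))  # only one argument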

Matching Functions#

x = relay.var("x")
y = relay.var("y")
wc1 = wildcard()
wc2 = wildcard()
func_pattern = FunctionPattern([wc1, wc2], wc1 + wc2)
assert func_pattern.match(relay.Function([x, y], x + y))

# Match Function with any number of inputs
func_pattern = FunctionPattern(None, wildcard())
assert func_pattern.match(relay.Function([x], x))
assert func_pattern.match(relay.Function([x, y], x + y))
x = relay.var("x")
y = relay.var("y")
wc1 = wildcard()
wc2 = wildcard()
func_pattern = FunctionPattern([wc1, wc2], wc1 + wc2)
assert not func_pattern.match(relay.Function([x, y], x - y))

Matching if#

x = is_var("x")
y = is_var("y")
pat = is_if(is_op("less")(x, y), x, y)

x = relay.var("x")
y = relay.var("y")
cond = x < y

assert pat.match(relay.expr.If(cond, x, y))
x = is_var("x")
y = is_var("y")
pat = is_if(is_op("less")(x, y), x, y)

x = relay.var("x")
y = relay.var("y")

assert not pat.match(relay.expr.If(x > y, x, y))
assert not pat.match(relay.expr.If(x < y, y, x))
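
With every sub-pattern a wildcard, is_if should accept any If expression; a minimal sketch:

any_if = is_if(wildcard(), wildcard(), wildcard())
assert any_if.match(relay.expr.If(x < y, x, y))
assert any_if.match(relay.expr.If(x > y, y, x))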

Matching let#

x = is_var("x")
y = is_var("y")
let_var = is_var("let")
pat = is_let(let_var, is_op("less")(x, y), let_var)

x = relay.var("x")
y = relay.var("y")
lv = relay.var("let")
cond = x < y
assert pat.match(relay.expr.Let(lv, cond, lv))
x = is_var("x")
y = is_var("y")
let_var = is_var("let")
pat = is_let(let_var, is_op("less")(x, y), let_var)

x = relay.var("x")
y = relay.var("y")
lv = relay.var("let")

assert not pat.match(relay.expr.Let(lv, x > y, lv))
assert not pat.match(relay.expr.Let(lv, x < y, lv * x))
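
Likewise, an is_let built entirely from wildcards should accept any Let binding; a minimal sketch:

any_let = is_let(wildcard(), wildcard(), wildcard())
assert any_let.match(relay.expr.Let(lv, x + y, lv))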

Optional Matching#

x = relay.var("x")
w = relay.var("w")
b = relay.var("b")
pattern = is_op("nn.relu")(
    is_op("nn.conv2d")(wildcard(), wildcard()).optional(
        lambda x: is_op("nn.bias_add")(x, wildcard())
    )
)

conv2d = relay.op.nn.conv2d(x, w)
relu = relay.op.nn.relu(conv2d)
assert pattern.match(relu)

conv2d = relay.op.nn.conv2d(x, w)
bias_add = relay.op.nn.bias_add(conv2d, b)
relu = relay.op.nn.relu(bias_add)
assert pattern.match(relu)

pattern = is_op("nn.conv2d")(wildcard(), wildcard())
pattern = pattern.optional(is_op("nn.relu")).optional(is_op("tanh"))

conv2d = relay.op.nn.conv2d(x, w)
relu = relay.op.nn.relu(conv2d)
tanh = relay.op.tanh(conv2d)
tanh2 = relay.op.tanh(relu)
relu2 = relay.op.nn.relu(tanh)
assert pattern.match(conv2d)
assert pattern.match(relu)
assert pattern.match(tanh)
assert pattern.match(tanh2)
assert not pattern.match(relu2)
x = relay.var("x")
w = relay.var("w")
b = relay.var("b")
pattern = is_op("nn.relu")(
    is_op("nn.conv2d")(wildcard(), wildcard()).optional(
        lambda x: is_op("nn.bias_add")(x, wildcard())
    )
)

# tanh instead of relu
conv2d = relay.op.nn.conv2d(x, w)
tanh = relay.op.tanh(conv2d)
assert not pattern.match(tanh)

# dense instead of conv2d
dense = relay.op.nn.dense(x, w)
tanh = relay.op.tanh(dense)
assert not pattern.match(tanh)

# dense instead of conv2d, even with the optional bias_add present
dense = relay.op.nn.dense(x, w)
bias_add = relay.op.nn.bias_add(dense, b)
relu = relay.op.nn.relu(bias_add)
assert not pattern.match(relu)

# plain add instead of bias_add
conv2d = relay.op.nn.conv2d(x, w)
bias_add = conv2d + w
relu = relay.op.nn.relu(bias_add)
assert not pattern.match(relu)
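
Under the hood, optional expands to an alternation (roughly self | option_constructor(self)), so the single-step case can be sketched explicitly:

conv2d_pat = is_op("nn.conv2d")(wildcard(), wildcard())
# a rough equivalent of conv2d_pat.optional(lambda p: is_op("nn.relu")(p))
conv_or_relu = conv2d_pat | is_op("nn.relu")(conv2d_pat)

conv2d = relay.op.nn.conv2d(x, w)
assert conv_or_relu.match(conv2d)
assert conv_or_relu.match(relay.op.nn.relu(conv2d))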

Matching Constants#

conv2d = is_op("nn.conv2d")(wildcard(), is_constant())
pattern = is_op("nn.bias_add")(conv2d, wildcard())

x = relay.var("x", shape=(1, 3, 224, 224))
w = relay.var("w", shape=(3, 3, 3, 3))
b = relay.var("b", shape=(3,))
conv2d = relay.op.nn.conv2d(x, w)
out = relay.op.nn.bias_add(conv2d, b)
func = relay.Function([x, w, b], out)
mod = tvm.IRModule.from_expr(func)

# w is still a free parameter, so is_constant() has nothing to match
assert not pattern.match(mod["main"].body)
# binding w to concrete data turns it into a relay.Constant
mod["main"] = bind_params_by_name(mod["main"], {"w": tvm.nd.array(np.ones(shape=(3, 3, 3, 3)))})
assert pattern.match(mod["main"].body)
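
For reference, is_constant() should also match a relay.Constant built directly; a minimal sketch:

assert is_constant().match(relay.const(np.ones((3, 3), dtype="float32")))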

Matching Tuples#

x = relay.var("x")
y = relay.var("y")
z = relay.op.op.get("add")
tuple_pattern = is_tuple((is_var("x"), wildcard(), is_op("add")))
assert tuple_pattern.match(relay.expr.Tuple((x, y, z)))

tuple_pattern = is_tuple((is_var("x"), wildcard(), is_op("add")))
tuple_get_item_pattern = is_tuple_get_item(tuple_pattern, 1)
assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, z)), 1))

tuple_get_item_pattern = is_tuple_get_item(tuple_pattern)  # Match any index
assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, z)), 0))
assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, z)), 1))
assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, z)), 2))

# Match tuple with any inputs
tuple_pattern = is_tuple(None)
concat_pattern = is_op("concatenate")(tuple_pattern)
assert concat_pattern.match(relay.op.concatenate(relay.expr.Tuple((x,)), axis=0))
assert concat_pattern.match(relay.op.concatenate(relay.expr.Tuple((x, y)), axis=0))
assert concat_pattern.match(relay.op.concatenate(relay.expr.Tuple((x, y, z)), axis=0))
x = relay.var("x")
y = relay.var("y")
z = relay.op.op.get("add")
tuple_pattern = is_tuple((is_var("x"), wildcard(), is_op("add"), wildcard()))
assert not tuple_pattern.match(relay.expr.Tuple((x, y, z)))

tuple_pattern = is_tuple((is_var("x"), wildcard(), is_op("add")))
tuple_get_item_pattern = is_tuple_get_item(tuple_pattern, 1)
assert not tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, z)), 2))
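
An is_tuple_get_item over a wildcard should accept a projection from any tuple at any index; a minimal sketch:

any_get_item = is_tuple_get_item(wildcard())
assert any_get_item.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y)), 0))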

Matching Types#

x = relay.var("x", shape=(10, 10), dtype="float32")
ty_pat = has_type(relay.TensorType((10, 10), "float32"))
assert ty_pat.match(x)
x = relay.var("x", shape=(10, 10), dtype="int32")
ty_pat = has_type(relay.TensorType((10, 10), "float32"))
assert not ty_pat.match(x)
x = relay.var("x", shape=(10, 10), dtype="float32")
ty_pat = has_dtype("float32")
assert ty_pat.match(x)
x = relay.var("x", shape=(10, 10), dtype="int32")
ty_pat = has_dtype("float32")
assert not ty_pat.match(x)

Matching Shapes#

x = relay.var("x", shape=(10, 10), dtype="float32")
ty_pat = has_shape((10, 10))
assert ty_pat.match(x)
x = relay.var("x", shape=(10, 10), dtype="int32")
ty_pat = has_shape((10, 5))
assert not ty_pat.match(x)

Matching Operator Attributes#

op = is_op("add").has_attr({"TOpPattern": K_BROADCAST})
op_pat = op(wildcard(), wildcard())
x = relay.var("x")
y = relay.var("y")
assert op_pat.match(x + y)
op = is_op("nn.dense").has_attr({"TOpPattern": K_ELEMWISE})
op_pat = op(wildcard(), wildcard())
x = relay.var("x")
y = relay.var("y")
assert not op_pat.match(relay.op.nn.dense(x, y))
op = is_op("add").has_attr({"TOpPattern": K_BROADCAST})
op_pat = op(wildcard(), wildcard())
x = relay.var("x")
y = relay.var("y")
assert not op_pat.match(x - y)  # subtract, not add
z = relay.var("z")
assert not op_pat.match(relay.Let(z, x + y, z))  # a Let node, not a Call

Matching Function Attributes#

pattern = wildcard().has_attr({"Composite": "add"})
x = relay.var("x")
y = relay.var("y")
f = relay.Function([x, y], x + y).with_attr("Composite", "add")
assert pattern.match(f)
pattern = wildcard().has_attr({"Composite": "add"})
x = relay.var("x")
y = relay.var("y")

f = relay.Function([x, y], x + y).with_attr("RandomTest", "add")
assert not pattern.match(f)
f = relay.Function([x, y], x + y).with_attr("Composite", "conv_bias")
assert not pattern.match(f)

Matching Call Attributes#

# String attr
is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()).has_attr({"data_layout": "NCHW"})
x = relay.var("x")
y = relay.var("y")
assert is_conv2d.match(relay.op.nn.conv2d(x, y))

# Array attr
is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()).has_attr({"kernel_size": [3, 3]})
out = relay.op.nn.conv2d(x, y, kernel_size=[3, 3])
assert is_conv2d.match(out)

# non-operator call
attr_dict = {"call_attr": "attr"}
call_has_attr = wildcard()(wildcard()).has_attr(attr_dict)
call_attr = tvm.ir.make_node("DictAttrs", **attr_dict)
a = relay.Var("a")
b = relay.Var("b")
assert call_has_attr.match(relay.Call(a, [b], attrs=call_attr))

# empty attrs should match anything
empty_attrs = tvm.ir.make_node("DictAttrs", **{})
call_has_empty_attrs = wildcard()(wildcard()).has_attr({})
assert call_has_empty_attrs.match(relay.Call(a, [b], attrs=empty_attrs))
assert call_has_empty_attrs.match(relay.Call(a, [b], attrs=call_attr))
x = relay.var("x")
y = relay.var("y")

is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()).has_attr({"data_layout": "NHWC"})
assert not is_conv2d.match(relay.op.nn.conv2d(x, y))

is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()).has_attr({"RandomAttr": "NCHW"})
assert not is_conv2d.match(relay.op.nn.conv2d(x, y))

# Array attr
is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()).has_attr({"kernel_size": [3, 3]})
out = relay.op.nn.conv2d(x, y, kernel_size=[2, 1])
assert not is_conv2d.match(out)

# non-operator calls
call_has_attr = wildcard()(wildcard()).has_attr({"call_attr": "attr"})
wrong_key = tvm.ir.make_node("DictAttrs", **{"wrong": "attr"})
wrong_value = tvm.ir.make_node("DictAttrs", **{"call_attr": "wrong"})
empty_attrs = tvm.ir.make_node("DictAttrs", **{})

a = relay.Var("a")
b = relay.Var("b")
# attrs left undefined
assert not call_has_attr.match(relay.Call(a, [b]))
# wrong attrs
assert not call_has_attr.match(relay.Call(a, [b], attrs=wrong_key))
assert not call_has_attr.match(relay.Call(a, [b], attrs=wrong_value))
assert not call_has_attr.match(relay.Call(a, [b], attrs=empty_attrs))
is_cast = is_op("cast")(wildcard()).has_attr({"dtype": "float32"})
x = relay.var("x")
assert is_cast.match(relay.op.cast(x, "float32"))
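
has_attr can also check several attributes at once; a sketch, assuming every key in the supplied dict must match (NCHW is conv2d's default data_layout):

x = relay.var("x")
y = relay.var("y")
is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()).has_attr(
    {"data_layout": "NCHW", "kernel_size": [3, 3]}
)
assert is_conv2d.match(relay.op.nn.conv2d(x, y, kernel_size=[3, 3]))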

Matching Diamonds#

# Pattern
is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard())
path1 = is_op("nn.relu")(is_conv2d)
path2 = is_op("nn.leaky_relu")(is_conv2d)
diamond = is_op("add")(path1, path2)

# Expr
inp = relay.var("input")
weight = relay.var("weight")
conv2d = relay.op.nn.conv2d(inp, weight)
relu = relay.op.nn.relu(conv2d)
leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
out = relu + leaky_relu

# Check
assert diamond.match(out)
# Pattern
is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard())
path1 = is_op("nn.relu")(is_conv2d)
path2 = is_op("nn.leaky_relu")(is_conv2d)
diamond = is_op("add")(path1, path2)

# Expr
inp = relay.var("input")
weight = relay.var("weight")
conv2d = relay.op.nn.conv2d(inp, weight)
relu = relay.op.nn.relu(conv2d)
leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)

# Check
assert not diamond.match(leaky_relu)
assert not diamond.match(relu)

Fake diamond:

# Pattern
is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard())
path1 = is_op("nn.relu")(is_conv2d)
path2 = is_op("nn.leaky_relu")(is_conv2d)
diamond = is_op("add")(path1, path2)

# Expr
input1 = relay.var("input1")
weight1 = relay.var("weight1")
conv2d1 = relay.op.nn.conv2d(input1, weight1)
inp2 = relay.var("input2")
weight2 = relay.var("weight2")
conv2d2 = relay.op.nn.conv2d(inp2, weight2)
relu = relay.op.nn.relu(conv2d1)
leaky_relu = relay.op.nn.leaky_relu(conv2d2, alpha=0)
out = relu + leaky_relu

# Check
assert not diamond.match(out)

Matching at_most_one_parent#

# Pattern
P = is_op("nn.conv2d")(wildcard(), wildcard())  # 'parent'
I = is_op("nn.relu")(wildcard())  # 'intermediate' ('path' in the code)
C = is_op("add")(wildcard(), wildcard())  # 'child'
pattern = dominates(P, I, C)

#       n6(P)
#      /  \
#     n7   \
#    /      \
#    n8(P)  n10(I)
#    \      /
#    n9(I) /
#      \  /
#      n11(C)

x = relay.var("x")
w = relay.var("w")
n6 = relay.op.nn.conv2d(x, w)  # matches P
n7 = relay.op.tanh(n6)  # does not match I
n8 = relay.op.nn.conv2d(n7, w)  # matches P
n9 = relay.op.nn.relu(n8)  # matches I
n10 = relay.op.nn.relu(n6)  # matches I
n11 = relay.add(n9, n10)  # matches C

# Does not match: the parent pattern P cannot be matched at both n8 and n6.
# Note that if we did allow P to be used twice, the implementation would
# need to be changed to not 'jump over' n7.
assert not pattern.match(n11)

Matching Dominators#

# Pattern
# dominates(parent, path, child): match the child when every data-flow path
# from the child up to the parent passes only through ops matching the path pattern
is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard())
is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard())
reduction = is_op("add")(wildcard(), wildcard())
diamond = dominates(is_conv2d, is_unary_elemwise, reduction)

# Classic Diamond
inp = relay.var("input")
weight = relay.var("weight")
conv2d = relay.op.nn.conv2d(inp, weight)
relu = relay.op.nn.relu(conv2d)
relu = relay.op.nn.relu(relu)
leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
out = relu + leaky_relu

# Check
assert diamond.match(out)

# Deeper Branch
inp = relay.var("input")
weight = relay.var("weight")
conv2d = relay.op.nn.conv2d(inp, weight)
relu = relay.op.nn.relu(conv2d)
relu = relay.op.nn.relu(relu)
relu = relay.op.tanh(relu)
leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
out = relu + leaky_relu

# Check
assert diamond.match(out)

# Single Branch
inp = relay.var("input")
weight = relay.var("weight")
conv2d = relay.op.nn.conv2d(inp, weight)
relu = relay.op.nn.relu(conv2d)
relu = relay.op.nn.relu(relu)
tanh = relay.op.tanh(relu)
out = relu + tanh

# Check
assert diamond.match(out)

# Fuzzy path/nested Diamond
is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard())
is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()) | is_op(
    "add"
)(wildcard(), wildcard())
reduction = is_op("add")(wildcard(), wildcard())
diamond = dominates(is_conv2d, is_unary_elemwise, reduction)

inp = relay.var("input")
weight = relay.var("weight")
conv2d = relay.op.nn.conv2d(inp, weight)
relu = relay.op.nn.relu(conv2d)
relu = relu + relu
tanh = relay.op.tanh(relu)
leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
out = tanh + leaky_relu

assert diamond.match(out)
is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard())
is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard())
reduction = is_op("add")(wildcard(), wildcard())
diamond = dominates(is_conv2d, is_unary_elemwise, reduction)

# Fake Diamond
input1 = relay.var("input1")
weight1 = relay.var("weight1")
conv2d1 = relay.op.nn.conv2d(input1, weight1)
inp2 = relay.var("input2")
weight2 = relay.var("weight2")
conv2d2 = relay.op.nn.conv2d(inp2, weight2)
relu = relay.op.nn.relu(conv2d1)
leaky_relu = relay.op.nn.leaky_relu(conv2d2, alpha=0)
out = relu + leaky_relu

# Check
assert not diamond.match(out)

# Add op that doesn't match K_ELEMWISE
inp = relay.var("input")
weight = relay.var("weight")
conv2d = relay.op.nn.conv2d(inp, weight)
relu = relay.op.nn.relu(conv2d)
relu = relu + relu
leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
out = relu + leaky_relu

# Check
assert not diamond.match(out)

# Relu on the input instead of the conv
inp = relay.var("input")
weight = relay.var("weight")
conv2d = relay.op.nn.conv2d(inp, weight)
relu = relay.op.nn.relu(inp)
leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
out = relu + leaky_relu

# Check
assert not diamond.match(out)

# No conv
inp = relay.var("input")
relu = relay.op.nn.relu(inp)
relu = relay.op.nn.relu(relu)
tanh = relay.op.tanh(relu)
out = relu + tanh

# Check
assert not diamond.match(out)
# Pattern
P = is_op("nn.conv2d")(wildcard(), wildcard())  # 'parent'
I = is_op("nn.relu")(wildcard())  # 'intermediate' ('path' in the code)
C = is_op("add")(wildcard(), wildcard())  # 'child'
pattern = dominates(P, I, C)

#       n6(P)
#      /  \
#     n7   \
#    /      \
#    n8(P)  n9(I)
#    \      /
#     \    /
#      \  /
#      n10(C)

x = relay.var("x")
w = relay.var("w")
n6 = relay.op.nn.conv2d(x, w)  # matches P
n7 = relay.op.tanh(n6)  # does not match I
n8 = relay.op.nn.conv2d(n7, w)  # matches P
n9 = relay.op.nn.relu(n6)  # matches I
n10 = relay.add(n8, n9)  # matches C

# Does not match: the parent pattern P cannot be matched at both n8 and n6.
# Note that if we did allow P to be used twice, the implementation would
# need to be changed to not 'jump over' n7.
assert not pattern.match(n10)

With types:

# Pattern
is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard())
is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()).has_dtype(
    "float32"
)
reduction = is_op("add")(wildcard(), wildcard()).has_shape([1, 3, 10, 10])
diamond = dominates(is_conv2d, is_unary_elemwise, reduction)

# Classic Diamond
inp = relay.var("input", relay.TensorType((1, 3, 12, 12), "float32"))
weight = relay.var("weight", relay.TensorType((3, 3, 3, 3), "float32"))
conv2d = relay.op.nn.conv2d(inp, weight)
relu = relay.op.nn.relu(conv2d)
relu = relay.op.nn.relu(relu)
leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
out = relu + leaky_relu

# Check
assert diamond.match(out)
# Classic Diamond
inp = relay.var("input", relay.TensorType((1, 3, 12, 12), "float32"))
weight = relay.var("weight", relay.TensorType((3, 3, 3, 3), "float32"))
conv2d = relay.op.nn.conv2d(inp, weight)
relu = relay.op.nn.relu(conv2d)
relu = relay.op.nn.relu(relu)
leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
out = relu + leaky_relu

# Pattern
is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard())
is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()).has_dtype(
    "float32"
)
reduction = is_op("add")(wildcard(), wildcard()).has_shape([1, 1, 10, 10])
diamond = dominates(is_conv2d, is_unary_elemwise, reduction)

# Check
assert not diamond.match(out)

# Pattern
is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard())
is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()).has_dtype(
    "float16"
)
reduction = is_op("add")(wildcard(), wildcard()).has_shape([1, 3, 10, 10])
diamond = dominates(is_conv2d, is_unary_elemwise, reduction)

# Check
assert not diamond.match(out)