# 分析提取融合函数
import tvm
from tvm import relay
from tvm.relay.testing.synthetic import get_workload
def get_conv_net():
    r"""Build the network for the case described in `fuse_ops.cc`.

    One conv2d feeds three parallel conv2d branches whose outputs are
    combined by two elementwise adds:

            conv2d
            /  |  \
           /   |   \
         op    op   op
          \    |    /
           \   |   /
          elemwise add
               |

    The docstring is a raw string on purpose: the diagram contains
    backslashes, which in a normal string literal would act as line
    continuations or invalid escape sequences and corrupt the drawing.

    Returns:
        The ``tvm.IRModule`` created by ``tvm.IRModule.from_expr`` from the
        final add expression.
    """
    dshape = (1, 1, 5, 1)
    x = relay.var("x", shape=dshape)
    # Shared producer: every branch below consumes `y`.
    y = relay.nn.conv2d(x, relay.var("w1"), kernel_size=(3, 3), padding=(1, 1), channels=1)

    # Three sibling convolutions, each with its own weight variable.
    x1 = relay.nn.conv2d(y, relay.var("w2"), kernel_size=(3, 3), padding=(1, 1), channels=1)
    x2 = relay.nn.conv2d(y, relay.var("w3"), kernel_size=(3, 3), padding=(1, 1), channels=1)
    x3 = relay.nn.conv2d(y, relay.var("w4"), kernel_size=(3, 3), padding=(1, 1), channels=1)

    # Merge the branches: (x1 + x2) first, then x3 + that sum.
    z = relay.add(x1, x2)
    z = relay.add(x3, z)

    return tvm.IRModule.from_expr(z)
def get_conv2d():
    """Build a module containing a single 3x3 conv2d in NHWC/HWIO layout.

    The data variable ``x`` has shape (1, 56, 56, 64) and the kernel variable
    ``weight1`` has shape (3, 3, 64, 32); the convolution produces 32 channels
    with padding (1, 1).

    Returns:
        The ``tvm.IRModule`` created from the conv2d expression.
    """
    data = relay.var("x", shape=(1, 56, 56, 64))
    kernel = relay.var("weight1", shape=(3, 3, 64, 32))
    conv_attrs = dict(
        channels=32,
        kernel_size=(3, 3),
        padding=(1, 1),
        data_layout="NHWC",
        kernel_layout="HWIO",
    )
    conv = relay.nn.conv2d(data, kernel, **conv_attrs)
    return tvm.IRModule.from_expr(conv)
def test_extract_identity():
    """Check that a lone conv2d is extracted as exactly one fused function.

    The extracted function is compared against the module's ``main`` after
    ``main`` has been tagged with the ``Primitive`` attribute.
    """
    mod = get_conv2d()
    items = relay.analysis.extract_fused_functions(mod)
    assert len(items) == 1

    # Tag `main` with the Primitive attribute before comparing it against
    # the extracted function.
    mod["main"] = mod["main"].with_attr("Primitive", tvm.tir.IntImm("int32", 1))
    # The comparison result has to be asserted; calling structural_equal and
    # discarding its return value verifies nothing.
    assert tvm.ir.structural_equal(list(items.values())[0], mod["main"])
def test_extract_conv_net():
    """Check the fused functions extracted from the branching conv network.

    Exactly two functions are expected: one whose body is a conv2d call, and
    one whose body is an add call with a conv2d as its first argument.
    """
    fused = relay.analysis.extract_fused_functions(get_conv_net())
    functions = list(fused.values())
    assert len(functions) == 2
    first, second = functions[0], functions[1]

    def is_conv(func):
        """Return whether the body of ``func`` is a call to nn.conv2d."""
        target_op = relay.op.op.get("nn.conv2d")
        root = func.body
        return root.op == target_op

    def is_conv_add(func):
        """Return whether ``func`` is an add whose first argument is a conv2d."""
        target_op = relay.op.op.get("add")
        root = func.body
        # Wrap the first argument in its own module so it can be inspected
        # through the module's "main" function.
        wrapped = tvm.IRModule.from_expr(root.args[0])
        if not (root.op == target_op):
            return False
        return is_conv(wrapped["main"])

    # The order in which the functions come back is not guaranteed, so accept
    # either arrangement.
    conv_then_add = is_conv(first) and is_conv_add(second)
    assert conv_then_add or (is_conv_add(first) and is_conv(second))
def test_extract_resnet():
    """Check the number of fused functions extracted from the test workload.

    The workload comes from ``tvm.relay.testing.synthetic.get_workload``;
    seven fused functions are expected.
    """
    workload_mod, _ = get_workload()
    fused = relay.analysis.extract_fused_functions(workload_mod)
    expected_count = 7
    assert len(fused) == expected_count
if __name__ == "__main__":
    # Run every test in this file, in definition order, when executed directly.
    for run_test in (
        test_extract_identity,
        test_extract_conv_net,
        test_extract_resnet,
    ):
        run_test()
# Demo cell: build the branching conv network and run fused-function
# extraction on it.
mod = get_conv_net()
# NOTE(review): `items` holds the extraction result but is not used again in
# the visible code; only `mod` is displayed below.
items = relay.analysis.extract_fused_functions(mod)
# Display the module held in `mod`.
mod.show()
def @main(%x: Tensor[(1, 1, 5, 1), float32], %w1, %w4, %w2, %w3) {
%0 = nn.conv2d(%x, %w1, padding=[1, 1, 1, 1], channels=1, kernel_size=[3, 3]);
%1 = nn.conv2d(%0, %w2, padding=[1, 1, 1, 1], channels=1, kernel_size=[3, 3]);
%2 = nn.conv2d(%0, %w3, padding=[1, 1, 1, 1], channels=1, kernel_size=[3, 3]);
%3 = nn.conv2d(%0, %w4, padding=[1, 1, 1, 1], channels=1, kernel_size=[3, 3]);
%4 = add(%1, %2);
add(%3, %4)
}