可选函数分区 (Optional function partitioning)

from testing import viz_expr # 可视化 relay
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import *
def test_partition_optional_function():
    """Partition a function call matched by an alternation of FunctionPatterns.

    The pattern accepts a call to either a ``conv2d -> sigmoid`` function or a
    ``conv2d -> relu`` function.

    The expression under test calls the relu variant and then adds a bias. The
    expected partition result is:

    * the matched call wrapped in a new relay function carrying the
      ``PartitionedFromPattern`` attribute;
    * that new function called on the original free variables;
    * the trailing ``+ b`` left outside the partitioned region.
    """
    # Free variables of the outer expression.
    data = relay.var("x")
    weight = relay.var("w")
    bias = relay.var("b")

    # Parameters of the inner relay function that the pattern should match.
    inner_data = relay.var("x1")
    inner_weight = relay.var("w1")

    # The first pair of wildcards matches the arguments passed at the call site.
    # The second pair matches the function's own parameter list; both
    # alternatives below share that second pair.
    call_data_wc, call_weight_wc = wildcard(), wildcard()
    param_data_wc, param_weight_wc = wildcard(), wildcard()

    def conv_then(op_name):
        # FunctionPattern whose body is `op_name(nn.conv2d(param0, param1))`.
        body = is_op(op_name)(is_op("nn.conv2d")(param_data_wc, param_weight_wc))
        return FunctionPattern([param_data_wc, param_weight_wc], body)

    sigmoid_fn_pat = conv_then("sigmoid")
    relu_fn_pat = conv_then("nn.relu")
    pattern = sigmoid_fn_pat(call_data_wc, call_weight_wc) | relu_fn_pat(
        call_data_wc, call_weight_wc
    )

    # Input: call the relu variant, then add the bias outside the call.
    inner_func = relay.Function(
        [inner_data, inner_weight],
        relay.nn.relu(relay.nn.conv2d(inner_data, inner_weight)),
    )
    expr = inner_func(data, weight) + bias

    # Expected: the matched call is wrapped in a fresh, attribute-tagged function.
    # That function is invoked on the original free variables, and the bias add
    # is left untouched.
    part_data = relay.var("x2")
    part_weight = relay.var("w2")
    partitioned = relay.Function(
        [part_data, part_weight], inner_func(part_data, part_weight)
    ).with_attr("PartitionedFromPattern", "nn.conv2d_nn.relu_FunctionCall_")
    expected = partitioned(data, weight) + bias

    actual = pattern.partition(expr)
    assert tvm.ir.structural_equal(actual, expected)