函数分区

函数分区

from testing import viz_expr # 可视化 relay
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import *
def test_partition_function():
    """Partition a call to a conv2d-wrapping function followed by an add.

    The input expression is ``func(x, w) + b + b`` where ``func`` is a Relay
    function whose body is ``nn.conv2d``. The pattern describes "a call to a
    two-parameter function whose body is conv2d, plus anything". The test
    asserts that partitioning yields an expression structurally equal to a
    hand-built one in which the first add (and the function call feeding it)
    is wrapped in a new function carrying the ``PartitionedFromPattern``
    attribute, while the trailing ``+ b`` stays outside that function.
    """
    # --- Pattern -----------------------------------------------------------
    # Wildcards standing in for the inner function's own parameters.
    inner_data_wc, inner_weight_wc = wildcard(), wildcard()
    conv_fn_pattern = FunctionPattern(
        [inner_data_wc, inner_weight_wc],
        is_op("nn.conv2d")(inner_data_wc, inner_weight_wc),
    )
    # Wildcards for the call-site arguments and the operand being added.
    arg_data_wc, arg_weight_wc, addend_wc = wildcard(), wildcard(), wildcard()
    pattern = conv_fn_pattern(arg_data_wc, arg_weight_wc) + addend_wc

    # --- Input expression --------------------------------------------------
    data, weight, bias = relay.var("x"), relay.var("w"), relay.var("b")
    fn_data, fn_weight = relay.var("x1"), relay.var("w1")
    conv_fn = relay.Function(
        [fn_data, fn_weight], relay.nn.conv2d(fn_data, fn_weight)
    )
    conv_call_plus_bias = conv_fn(data, weight) + bias
    expr = conv_call_plus_bias + bias

    # --- Expected result ---------------------------------------------------
    # The matched "call + add" is lifted into a function over fresh variables.
    outer_data = relay.var("x2")
    outer_weight = relay.var("w2")
    outer_bias = relay.var("b2")
    partitioned_fn = relay.Function(
        [outer_data, outer_weight, outer_bias],
        conv_fn(outer_data, outer_weight) + outer_bias,
    ).with_attr("PartitionedFromPattern", "nn.conv2d_FunctionCall_add_")
    expected = partitioned_fn(data, weight, bias) + bias

    actual = pattern.partition(expr)
    assert tvm.ir.structural_equal(actual, expected)