模糊函数参数的分区

模糊函数参数的分区

from testing import viz_expr # 可视化 relay
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import *
def test_partition_fuzzy_function_args():
    """Partition ``fn(...)(...) + w`` where the inner function's arity is open.

    The pattern leaves both the function's parameter list and the call's
    argument list as ``None`` (the "fuzzy" part under test), so the same
    pattern is expected to match inline functions taking one, two and three
    parameters. For each case, the partitioned expression is compared
    structurally against a hand-built expectation.
    """
    # Body pattern: a binary add of two arbitrary sub-expressions. The
    # function parameters and call arguments are passed as ``None``.
    add_body = wildcard() + wildcard()
    fuzzy_call = FunctionPattern(None, add_body)(None)
    func_pattern = fuzzy_call + wildcard()

    # Call arguments (x, y, z) and the outer addend (b).
    x, y, z, b = (relay.var(name) for name in ("x", "y", "z", "b"))
    # Parameters of the inline functions being matched.
    xp, yp, zp = (relay.var(name) for name in ("xp", "yp", "zp"))

    def expected_partition(call):
        """Build the expression ``partition`` should produce for ``call + b``.

        The matched subgraph is rebuilt as a fresh function. That function takes
        one parameter per argument of ``call.op`` plus one for ``b``, and is
        tagged with a ``PartitionedFromPattern`` label. It is then applied to
        the first ``arity`` of ``x, y, z`` followed by ``b``.
        """
        inner_fn = call.op
        arity = len(inner_fn.params)
        # One fresh parameter per inner-function argument, plus one for ``b``.
        outer_params = [relay.var(str(idx)) for idx in range(arity + 1)]
        *arg_params, addend_param = outer_params

        base_label = "add_FunctionCall_add_"
        # The three-parameter body ``(xp + yp) + zp`` contains a nested add;
        # presumably that extra matched add is why the label gains an "add_".
        label = "add_" + base_label if arity == 3 else base_label

        body = relay.Call(inner_fn, arg_params) + addend_param
        wrapper = relay.Function(outer_params, body).with_attr(
            "PartitionedFromPattern", label
        )
        return wrapper(*[x, y, z][:arity], b)

    # Inline function calls of increasing arity; each is added to ``b`` below.
    cases = (
        relay.Function([xp], xp + xp)(x),
        relay.Function([xp, yp], xp + yp)(x, y),
        relay.Function([xp, yp, zp], xp + yp + zp)(x, y, z),
    )
    for call in cases:
        partitioned = func_pattern.partition(call + b)
        assert tvm.ir.structural_equal(partitioned, expected_partition(call))