模糊 tuple 分区

# 模糊 tuple 分区

from testing import viz_expr # 可视化 relay
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import *
def test_partition_fuzzy_tuple():
    """Check partitioning of ``concatenate`` over tuples of differing arity.

    The pattern is ``is_op("concatenate")`` applied to ``is_tuple(None)``.
    It is partitioned against concatenations of one, two and three inputs,
    and each result is compared structurally with a hand-built call to a
    function tagged ``PartitionedFromPattern = "Tuple_concatenate_"``.
    """
    lhs = relay.var("x")
    rhs = relay.var("y")
    # A non-variable expression, used as the third concatenate input.
    total = lhs + rhs

    fuzzy_tuple = is_tuple(None)
    concat_pattern = is_op("concatenate")(fuzzy_tuple)

    # Fresh variables that serve as parameters of the expected functions.
    param_x = relay.var("xp")
    param_y = relay.var("yp")
    param_z = relay.var("zp")

    def concat_axis0(*fields):
        # Concatenate the given expressions, packed into a Relay tuple.
        packed = relay.expr.Tuple(fields)
        return relay.op.concatenate(packed, axis=0)

    def tagged_function(params, body):
        # Build the function the partitioner is expected to emit.
        func = relay.Function(params, body)
        return func.with_attr("PartitionedFromPattern", "Tuple_concatenate_")

    # (call-site inputs, parameters of the expected partitioned function)
    cases = (
        ((lhs,), [param_x]),
        ((lhs, rhs), [param_x, param_y]),
        ((lhs, rhs, total), [param_x, param_y, param_z]),
    )

    for inputs, params in cases:
        partitioned = concat_pattern.partition(concat_axis0(*inputs))
        expected = tagged_function(params, concat_axis0(*params))(*inputs)
        assert tvm.ir.structural_equal(partitioned, expected)