UMAPartitioner

UMAPartitioner

import set_env
import tvm
import tvm.relay as relay
from tvm.relay.backend.contrib.uma import uma_available
from tvm.relay.backend.contrib.uma.api import UMAPartitioner
from tvm.relay.op.contrib.register import get_pattern_table
from tvm.relay.testing import mlp, resnet

测试 partition_table

# Registration check: a freshly constructed partitioner must not expose a
# pattern table under its target name until register() has been called.
target_name = "test_partition"
partitioner = UMAPartitioner(target_name)
assert get_pattern_table(target_name) is None

partitioner.register()

# After registration a table exists for the target name.
assert get_pattern_table(target_name) is not None
# Bare expression kept last so an interactive session displays the table;
# the output recorded below it is an empty list.
get_pattern_table(target_name)
[]
# Compare UMAPartitioner against a hand-built BYOC pass pipeline: for every
# (workload, backend, merge) combination, both flows must split the module
# into the same number of functions.
for workload, backend, merge in [
    ("resnet", "dnnl", False),
    ("resnet", "dnnl", True),
    ("mlp", "dnnl", False),
    ("mlp", "dnnl", True),
    ("resnet", "cutlass", False),
    ("resnet", "cutlass", True),
    ("mlp", "cutlass", False),
    ("mlp", "cutlass", True),
]:
    partitioner = UMAPartitioner(backend, merge)
    pattern_table = get_pattern_table(backend)
    # get_pattern_table yields None for a name nothing has registered; fail
    # with a clear message instead of a TypeError from iterating None.
    if pattern_table is None:
        raise RuntimeError(f"no pattern table registered for backend {backend!r}")

    # Feed every entry of the backend's existing table to the UMA partitioner.
    for entry in pattern_table:
        partitioner.add_pattern(*entry)

    if workload == "resnet":
        net = resnet.get_net(1, 10)
    elif workload == "mlp":
        net = mlp.get_net(1, 10)
    else:
        # Raise explicitly: an `assert False` is stripped under `python -O`,
        # which would let the loop silently reuse the previous iteration's net.
        raise ValueError(f"don't know how to find workload for {workload}")

    mod = tvm.ir.IRModule()
    mod["main"] = net

    # Partition through the UMA flow.
    partitioner.register()
    partitioned_mod = partitioner.partition(mod)

    def partition_default(mod):
        """Partition ``mod`` with the default BYOC pass sequence.

        Uses the current iteration's ``pattern_table``, ``backend`` and
        ``merge``: MergeComposite, then AnnotateTarget, optionally
        MergeCompilerRegions (only when ``merge`` is true), then
        PartitionGraph. Returns the result of running that sequence on
        ``mod``.
        """

        sequence = [
            relay.transform.MergeComposite(pattern_table),
            relay.transform.AnnotateTarget(backend),
        ]

        if merge:
            sequence.append(relay.transform.MergeCompilerRegions())

        sequence.append(relay.transform.PartitionGraph())
        sequential = tvm.transform.Sequential(sequence)

        return sequential(mod)

    # Partition the same module through the reference flow.
    default_partitioned_mod = partition_default(mod)

    # Both flows must produce the same number of functions.
    assert len(partitioned_mod.functions) == len(default_partitioned_mod.functions)