# For every (workload, backend, merge) combination, check that the UMA
# partitioner produces the same number of module functions as the default
# BYOC flow (MergeComposite -> AnnotateTarget -> [MergeCompilerRegions] ->
# PartitionGraph) when both are driven by the same pattern table.
for workload, backend, merge in [
    ("resnet", "dnnl", False),
    ("resnet", "dnnl", True),
    ("mlp", "dnnl", False),
    ("mlp", "dnnl", True),
    ("resnet", "cutlass", False),
    ("resnet", "cutlass", True),
    ("mlp", "cutlass", False),
    ("mlp", "cutlass", True),
]:
    # Build a UMA partitioner that knows every pattern from the backend's
    # existing BYOC pattern table.
    partitioner = UMAPartitioner(backend, merge)
    pattern_table = get_pattern_table(backend)
    for entry in pattern_table:
        partitioner.add_pattern(*entry)

    if workload == "resnet":
        net = resnet.get_net(1, 10)
    elif workload == "mlp":
        net = mlp.get_net(1, 10)
    else:
        # Raise instead of `assert False`: asserts are stripped under
        # `python -O`, which would let `net` silently keep the previous
        # iteration's value (or be unbound on the first iteration).
        raise ValueError(f"don't know how to find workload for {workload}")

    mod = tvm.ir.IRModule()
    mod["main"] = net

    partitioner.register()
    partitioned_mod = partitioner.partition(mod)

    def partition_default(mod):
        """Partition *mod* with the default BYOC pass sequence.

        Reads ``pattern_table``, ``backend`` and ``merge`` from the current
        loop iteration. This is safe only because the function is called
        right away, within the same iteration.

        Returns:
            The module produced by running the pass sequence on *mod*.
        """
        sequence = [
            relay.transform.MergeComposite(pattern_table),
            relay.transform.AnnotateTarget(backend),
        ]
        # Region merging is optional, mirroring the UMA partitioner's
        # `merge` flag so the two flows are compared like-for-like.
        if merge:
            sequence.append(relay.transform.MergeCompilerRegions())
        sequence.append(relay.transform.PartitionGraph())
        sequential = tvm.transform.Sequential(sequence)
        return sequential(mod)

    default_partitioned_mod = partition_default(mod)

    num_uma = len(partitioned_mod.functions)
    num_default = len(default_partitioned_mod.functions)
    assert num_uma == num_default, (
        f"workload={workload}, backend={backend}, merge={merge}: "
        f"UMA produced {num_uma} functions, default BYOC produced {num_default}"
    )