# For every (workload, backend, merge) combination, partition the same module
# once with UMAPartitioner and once with the default BYOC pass sequence, then
# check that both results contain the same number of functions.
for workload, backend, merge in [
    ("resnet", "dnnl", False),
    ("resnet", "dnnl", True),
    ("mlp", "dnnl", False),
    ("mlp", "dnnl", True),
    ("resnet", "cutlass", False),
    ("resnet", "cutlass", True),
    ("mlp", "cutlass", False),
    ("mlp", "cutlass", True),
]:
    partitioner = UMAPartitioner(backend, merge)
    pattern_table = get_pattern_table(backend)

    # Each table entry is unpacked into the positional arguments of add_pattern.
    for entry in pattern_table:
        partitioner.add_pattern(*entry)

    # NOTE(review): the (1, 10) arguments are presumably batch size and number
    # of classes -- confirm against the get_net helpers.
    if workload == "resnet":
        net = resnet.get_net(1, 10)
    elif workload == "mlp":
        net = mlp.get_net(1, 10)
    else:
        # Raise explicitly instead of `assert False`: asserts are stripped
        # under `python -O`, which would leave `net` unbound and surface
        # later as an unrelated NameError.
        raise AssertionError(f"don't know how to find workload for {workload}")

    mod = tvm.ir.IRModule()
    mod["main"] = net

    partitioner.register()
    partitioned_mod = partitioner.partition(mod)

    def partition_default(mod):
        """Partition ``mod`` using the default BYOC flow.

        Runs MergeComposite and AnnotateTarget, optionally
        MergeCompilerRegions (only when ``merge`` is true), and finally
        PartitionGraph, and returns the transformed module.

        Reads ``pattern_table``, ``backend`` and ``merge`` from the enclosing
        loop iteration; it is called immediately below, so it always sees the
        values of the current iteration.
        """
        sequence = [
            relay.transform.MergeComposite(pattern_table),
            relay.transform.AnnotateTarget(backend),
        ]
        if merge:
            sequence.append(relay.transform.MergeCompilerRegions())
        sequence.append(relay.transform.PartitionGraph())

        sequential = tvm.transform.Sequential(sequence)
        return sequential(mod)

    default_partitioned_mod = partition_default(mod)

    # Include the failing case and both counts so a mismatch is diagnosable.
    assert len(partitioned_mod.functions) == len(default_partitioned_mod.functions), (
        f"function count mismatch for workload={workload!r}, "
        f"backend={backend!r}, merge={merge!r}: "
        f"{len(partitioned_mod.functions)} != "
        f"{len(default_partitioned_mod.functions)}"
    )