def gen_ground_truth(mod, target, dev, inputs):
    # Lower and run tuning
    # Since there is no default schedule for GPU in MS yet, this is necessary
    with target:
        seq = tvm.transform.Sequential(
            [relax.transform.LegalizeOps(), tir.transform.DefaultGPUSchedule()]
        )
        new_mod = seq(mod)
    assert relax.analysis.well_formed(new_mod)
    exec = tvm.compile(new_mod, target, params={})
    vm = relax.VirtualMachine(exec, dev)
    return vm["main"](*inputs)

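# NOTE: `target`, `dev`, `check_roundtrip`, and `requires_tensorrt_runtime` are
# module-level helpers assumed to be defined earlier in this file.  A minimal
# sketch of the target/device setup these tests expect (illustrative only, not
# necessarily how the file actually selects its target):
#
#     target = tvm.target.Target("cuda")
#     dev = tvm.cuda()
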
@tvm.script.ir_module
class InputModule:
    @R.function
    def main(
        x: R.Tensor((16, 16), "float32"), y: R.Tensor((16, 16), "float32")
    ) -> R.Tensor((16, 16), "float32"):
        with R.dataflow():
            z1 = R.multiply(x, y)
            z2 = R.add(z1, x)
            z3 = R.add(z1, z2)
            z4 = R.multiply(z3, z2)
            z5 = R.add(z4, z1)
            R.output(z5)
        return z5


def setup_test():
    # Prepare IRModule and its input
    mod = InputModule
    assert isinstance(mod, tvm.IRModule)

    np0 = np.random.rand(16, 16).astype(np.float32)
    np1 = np.random.rand(16, 16).astype(np.float32)
    data0 = tvm.nd.array(np0, dev)
    data1 = tvm.nd.array(np1, dev)
    inputs = [data0, data1]

    # Ground truth should be generated before annotation
    # due to the conflict with MS task extraction
    # TODO(@sunggg): Sort this out
    expected = gen_ground_truth(mod, target, dev, inputs)
    return mod, inputs, expected


entry_func_name = tvm.testing.parameter("main", "func")


@tvm.testing.requires_gpu
@requires_tensorrt_runtime
def test_tensorrt_only(entry_func_name):
    mod, inputs, expected = setup_test()

    if entry_func_name != "main":
        mod[entry_func_name] = mod["main"]
        del mod["main"]

    # Define patterns that we want to offload to byoc
    # This test will offload entire model
    # Thus, define patterns for both `multiply` and `add` ops
    patterns = [
        ("tensorrt.multiply", is_op("relax.multiply")(wildcard(), wildcard())),
        ("tensorrt.add", is_op("relax.add")(wildcard(), wildcard())),
    ]

    new_mod = tvm.transform.Sequential(
        [
            relax.transform.FuseOpsByPattern(patterns),
            relax.transform.MergeCompositeFunctions(),
            relax.transform.RunCodegen(),
        ]
    )(mod)

    ex0 = tvm.compile(new_mod, target, params={})

    # Sanity check for the correctness and roundtrip
    check_roundtrip(ex0, dev, inputs, expected, entry_func_name)


@tvm.testing.requires_gpu
@requires_tensorrt_runtime
def test_mix_use_tensorrt_and_tvm():
    mod, inputs, expected = setup_test()

    # Define patterns that we want to offload to byoc
    # This test will only offload `add` op to tensorrt
    # and tune `multiply` op with MetaSchedule
    patterns = [
        ("tensorrt.add", is_op("relax.add")(wildcard(), wildcard())),
    ]

    # Run Codegen pass
    with tempfile.TemporaryDirectory() as work_dir:
        with target, tvm.transform.PassContext(trace=Trace(mod), opt_level=0):
            new_mod = tvm.transform.Sequential(
                [
                    relax.transform.FuseOpsByPattern(patterns),
                    relax.transform.MergeCompositeFunctions(),
                    relax.transform.RunCodegen(),
                    relax.transform.LegalizeOps(),
                    relax.transform.MetaScheduleTuneIRMod(
                        params={}, work_dir=work_dir, max_trials_global=8
                    ),
                    relax.transform.MetaScheduleApplyDatabase(work_dir),
                ]
            )(mod)

    assert relax.analysis.well_formed(new_mod)
    with tvm.transform.PassContext(opt_level=0):
        ex0 = tvm.compile(new_mod, target, params={})

    # Sanity check for the correctness and roundtrip
    check_roundtrip(ex0, dev, inputs, expected)


@tvm.script.ir_module
class Conv2dx2:
    @R.function
    def main(
        data: R.Tensor((16, 32, 32, 16), dtype="float16"),
        weight1: R.Tensor((16, 3, 3, 16), dtype="float16"),
        weight2: R.Tensor((16, 3, 3, 16), dtype="float16"),
    ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
        cls = Conv2dx2
        with R.dataflow():
            lv: R.Tensor((16, 32, 32, 16), dtype="float16") = cls.fused_relax_nn_conv2d_tensorrt(
                data, weight1
            )
            gv: R.Tensor((16, 32, 32, 16), dtype="float16") = cls.fused_relax_nn_conv2d_tensorrt(
                lv, weight2
            )
            R.output(gv)
        return gv

    @R.function
    def fused_relax_nn_conv2d_tensorrt(
        data: R.Tensor((16, 32, 32, 16), dtype="float16"),
        weight1: R.Tensor((16, 3, 3, 16), dtype="float16"),
    ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
        R.func_attr({"Codegen": "tensorrt", "global_symbol": "fused_relax_nn_conv2d_tensorrt"})

        @R.function
        def gv(
            data_1: R.Tensor((16, 32, 32, 16), dtype="float16"),
            weight1_1: R.Tensor((16, 3, 3, 16), dtype="float16"),
        ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
            R.func_attr({"Composite": "tensorrt.conv2d", "Primitive": 1})
            with R.dataflow():
                gv_1: R.Tensor((16, 32, 32, 16), dtype="float16") = R.nn.conv2d(
                    data_1,
                    weight1_1,
                    padding=[1, 1, 1, 1],
                    data_layout="NHWC",
                    kernel_layout="OHWI",
                    out_layout="NHWC",
                )
                R.output(gv_1)
            return gv_1

        gv1: R.Tensor((16, 32, 32, 16), dtype="float16") = gv(data, weight1)
        return gv1


@tvm.script.ir_module
class Conv2dx2_after:
    @R.function
    def main(
        data: R.Tensor((16, 32, 32, 16), dtype="float16"),
        weight1: R.Tensor((16, 3, 3, 16), dtype="float16"),
        weight2: R.Tensor((16, 3, 3, 16), dtype="float16"),
    ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
        with R.dataflow():
            lv = R.call_dps_packed(
                "fused_relax_nn_conv2d_tensorrt",
                (data, weight1),
                out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"),
            )
            gv = R.call_dps_packed(
                "fused_relax_nn_conv2d_tensorrt",
                (lv, weight2),
                out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"),
            )
            R.output(gv)
        return gv


def test_multiple_calls_same_extern():
    mod = relax.transform.RunCodegen()(Conv2dx2)
    tvm.ir.assert_structural_equal(mod["main"], Conv2dx2_after["main"])


def test_default_entry_func():
    """The entry function is not necessarily named "main"

    Like `test_multiple_calls_same_extern`, but the main function
    is named "func".
"""before_with_main=Conv2dx2after_with_main=relax.transform.RunCodegen()(before_with_main)defrename_main(mod):mod=mod.clone()mod["func"]=mod["main"].with_attr("global_symbol","func")delmod["main"]returnmodbefore_with_func=rename_main(before_with_main)expected_with_func=rename_main(after_with_main)after_with_func=relax.transform.RunCodegen()(before_with_func)tvm.ir.assert_structural_equal(expected_with_func["func"],after_with_func["func"])deftest_dynamic_shape():importtvm.relax.backend.cuda.cublas@I.ir_moduleclassBefore:@R.functiondefmain(x:R.Tensor((1,4096),dtype="float16"),w1:R.Tensor((4096,"r1"),dtype="float16"),w2:R.Tensor((4096,"r2"),dtype="float16"),)->R.Tuple(R.Tensor((1,"r1"),dtype="float16"),R.Tensor((1,"r2"),dtype="float16")):r1=T.int64()r2=T.int64()cls=BeforewithR.dataflow():lv:R.Tensor((1,r1),dtype="float16")=cls.fused_relax_matmul_cublas(x,w1)lv1:R.Tensor((1,r2),dtype="float16")=cls.fused_relax_matmul_cublas(x,w2)gv:R.Tuple(R.Tensor((1,r1),dtype="float16"),R.Tensor((1,r2),dtype="float16"))=(lv,lv1)R.output(gv)returngv@R.functiondeffused_relax_matmul_cublas(x:R.Tensor((1,4096),dtype="float16"),w1:R.Tensor((4096,"r1"),dtype="float16"))->R.Tensor((1,"r1"),dtype="float16"):r1=T.int64()R.func_attr({"Codegen":"cublas"})@R.functiondefgv(x_1:R.Tensor((1,4096),dtype="float16"),w1_1:R.Tensor((4096,r1),dtype="float16"),)->R.Tensor((1,r1),dtype="float16"):R.func_attr({"Composite":"cublas.matmul"})withR.dataflow():gv_1:R.Tensor((1,r1),dtype="float16")=R.matmul(x_1,w1_1,out_dtype="void")R.output(gv_1)returngv_1gv1:R.Tensor((1,r1),dtype="float16")=gv(x,w1)returngv1@I.ir_moduleclassExpected:@R.functiondefmain(x:R.Tensor((1,4096),dtype="float16"),w1:R.Tensor((4096,"r1"),dtype="float16"),w2:R.Tensor((4096,"r2"),dtype="float16"),)->R.Tuple(R.Tensor((1,"r1"),dtype="float16"),R.Tensor((1,"r2"),dtype="float16")):r1=T.int64()r2=T.int64()withR.dataflow():lv=R.call_dps_packed("fused_relax_matmul_cublas",(x,w1),out_sinfo=R.Tensor((1,r1),dtype="float16"),)lv1=R.call_dps_packed("fused_relax_matmul_cublas",(x,w2),out_sinfo=R.Tensor((1,r2),dtype="float16"),)gv:R.Tuple(R.Tensor((1,r1),dtype="float16"),R.Tensor((1,r2),dtype="float16"))=(lv,lv1)R.output(gv)returngvafter=relax.transform.RunCodegen()(Before)tvm.ir.assert_structural_equal(after["main"],Expected["main"])deftest_no_op_for_call_to_tir():"""Calls to PrimFunc are ignored RunCodegen should only update calls to Relax functions annotated with the `"Codegen"` attribute. Calls to any other function type should be ignored. This is a regression test. Previous implementations performed an unconditional cast from `tvm::BaseFunc` to `tvm::relax::Function`, which produced an error. """@tvm.script.ir_moduleclassBefore:@R.functiondefmain(x:R.Tensor([4],"int64")):R.func_attr({"relax.force_pure":True})_=Before.shape_func(x)returnx@T.prim_func(private=True)defshape_func(H:T.Buffer(T.int64(4),"int64")):H[T.int64(0)]=H[T.int64(0)]+T.int64(1)Expected=BeforeAfter=relax.transform.RunCodegen()(Before)tvm.ir.assert_structural_equal(Expected,After)# TODO(@sunggg): test with more complex patterns (e.g., multiple annots, mixed codegens, different ops, const binding)if__name__=="__main__":pytest.main([__file__])