# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

""" Test translate for TensorRT. """

import pytest
import numpy as np

import torch
from torch import fx
from torch.nn import Module

import tvm.testing
from tvm.relax import PyExprVisitor
from tvm.relax.frontend.torch import from_fx
from tvm.contrib.msc.framework.tensorrt.frontend import translate
from tvm.contrib.msc.framework.tensorrt import codegen
from tvm.contrib.msc.core import utils as msc_utils

# Skip every test in this module when the "relax.ext.tensorrt" global
# function is not registered in the current TVM build.
requires_tensorrt = pytest.mark.skipif(
    tvm.get_global_func("relax.ext.tensorrt", True) is None,
    reason="TENSORRT is not enabled",
)


def build_and_run(mod, inputs):
    """Build a relax module for CUDA and run its "main" on the relax VM.

    Parameters
    ----------
    mod:
        The module to build. It is legalized and given the default GPU
        schedule before being built with opt_level=3.
    inputs:
        Positional arguments forwarded to the "main" function.

    Returns
    -------
    list
        The outputs converted with ``asnumpy()``; a single NDArray result is
        wrapped in a one-element list.
    """

    target = tvm.target.Target("cuda")
    mod = tvm.relax.transform.LegalizeOps()(mod)
    with target:
        mod = tvm.tir.transform.DefaultGPUSchedule()(mod)
    with tvm.transform.PassContext(opt_level=3):
        rt_mod = tvm.relax.build(mod, target)
        runnable = tvm.relax.VirtualMachine(rt_mod, tvm.cuda())
    res = runnable["main"](*inputs)
    if isinstance(res, tvm.runtime.NDArray):
        return [res.asnumpy()]
    return [e.asnumpy() for e in res]


def check_names(mod):
    """Check the byoc name and unique_name.

    For every function in ``mod`` whose "Codegen" attribute equals
    "msc_tensorrt", assert that it has a "Unique" attribute, then visit it and
    assert that each visited function with a "Composite" attribute has a
    "Unique" attribute that was not seen earlier in the same visit.

    Raises
    ------
    AssertionError
        If a "Unique" attribute is missing or repeated.
    """

    @tvm.relax.expr_functor.visitor
    class NameChecker(PyExprVisitor):
        """Visitor asserting that composite functions carry distinct "Unique" names"""

        def check(self, expr):
            # The set is reset on every call, so uniqueness is only enforced
            # within the expression passed to this call.
            self._recorded_names = set()
            if isinstance(expr, tvm.relax.Expr):
                self.visit_expr(expr)
            elif isinstance(expr, tvm.relax.BindingBlock):
                self.visit_binding_block(expr)

        def visit_function_(self, op: tvm.relax.Function) -> None:
            if "Composite" in op.attrs:
                assert "Unique" in op.attrs, "Can not find unique_name for func " + str(op)
                name = str(op.attrs["Unique"])
                assert name not in self._recorded_names, "Name {} is already in use".format(name)
                self._recorded_names.add(name)
            super().visit_function_(op)

    def _is_target_func(func):
        if "Codegen" not in func.attrs:
            return False
        return func.attrs["Codegen"] == "msc_tensorrt"

    for _, func in mod.functions.items():
        if not _is_target_func(func):
            continue
        assert "Unique" in func.attrs, "Can not find Unique from function attributes"
        NameChecker().check(func)


def verify_model(torch_model, input_info, allow_incomplete=False):
    """Build model and verify results.

    The torch model is traced with ``fx.symbolic_trace``, imported with
    ``from_fx``, partitioned for TensorRT, checked with ``check_names``,
    translated with ``codegen.to_tensorrt``, built and run on CUDA, and the
    outputs are compared against the torch outputs with atol=rtol=1e-3.

    Parameters
    ----------
    torch_model:
        The torch module under test.
    input_info:
        List of ``(shape, dtype)`` pairs, one per model input. Random data is
        generated from each pair and the list is also passed to ``from_fx``.
    allow_incomplete: bool
        Forwarded as ``trans_config["allow_incomplete"]`` to
        ``translate.partition_for_tensorrt``.
        NOTE(review): presumably lets ops that TensorRT can not take stay
        outside the partition - confirm against the translate module.
    """

    graph_model = fx.symbolic_trace(torch_model)
    # NOTE: np.random.rand samples from [0, 1), so integer dtypes such as
    # "int64" end up with all-zero data after the cast.
    datas = [np.random.rand(*i[0]).astype(i[1]) for i in input_info]
    torch_datas = [torch.from_numpy(i) for i in datas]
    with torch.no_grad():
        golden = torch_model(*torch_datas)
        mod = from_fx(graph_model, input_info)
    if not isinstance(golden, (list, tuple)):
        golden = [golden]
    golden = [g.detach().cpu().numpy() for g in golden]
    # partition module for tensorrt
    mod, graphs, weights = translate.partition_for_tensorrt(
        mod, trans_config={"allow_incomplete": allow_incomplete}
    )
    check_names(mod)
    output_folder = msc_utils.msc_dir()
    try:
        # translate to tensorrt
        mod = codegen.to_tensorrt(mod, graphs, weights, output_folder=output_folder)
        tvm_datas = [tvm.nd.array(i, device=tvm.cuda()) for i in datas]
        results = build_and_run(mod, tvm_datas)
        for gol, res in zip(golden, results):
            tvm.testing.assert_allclose(gol, res, atol=1e-3, rtol=1e-3)
    finally:
        # Clean up the output folder even when codegen, build, run or the
        # comparison fails. The method name `destory` is kept exactly as the
        # original call spells it.
        output_folder.destory()


@requires_tensorrt
def test_conv1d():
    """test tensorrt translator for conv1d"""

    class Conv1D1(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv1d(3, 6, 7, bias=True)

        def forward(self, data):
            return self.conv(data)

    class Conv1D2(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv1d(3, 6, 7, bias=False)

        def forward(self, data):
            return self.conv(data)

    input_info = [([1, 3, 10], "float32")]
    verify_model(Conv1D1(), input_info)
    verify_model(Conv1D2(), input_info)


@requires_tensorrt
def test_conv2d():
    """test tensorrt translator for conv2d"""

    class Conv2D1(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 6, 7, bias=True)

        def forward(self, data):
            return self.conv(data)

    class Conv2D2(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 6, 7, bias=False)

        def forward(self, data):
            return self.conv(data)

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(Conv2D1(), input_info)
    verify_model(Conv2D2(), input_info)


@requires_tensorrt
def test_linear():
    """test tensorrt translator for linear"""

    class Dense1(Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(10, 7, bias=True)

        def forward(self, data):
            return self.linear(data)

    class Dense2(Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(10, 7, bias=False)

        def forward(self, data):
            return self.linear(data)

    class MatMul1(Module):
        def forward(self, x, y):
            return torch.matmul(x, y)

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(Dense1(), input_info)
    verify_model(Dense2(), input_info)
    verify_model(MatMul1(), [([10, 10], "float32"), ([10, 10], "float32")])


@requires_tensorrt
def test_bmm():
    """test tensorrt translator for bmm"""

    class BMM(Module):
        def forward(self, x, y):
            return torch.bmm(x, y)

    input_info = [((4, 128, 256), "float32"), ((4, 256, 512), "float32")]
    verify_model(BMM(), input_info)


@requires_tensorrt
def test_baddbmm():
    """test tensorrt translator for baddbmm"""

    class BAddBMM1(Module):
        def forward(self, c, x, y):
            return torch.baddbmm(c, x, y)

    class BAddBMM2(Module):
        def forward(self, c, x, y):
            return torch.baddbmm(c, x, y, alpha=2, beta=0)

    input_info = [
        ((4, 128, 512), "float32"),
        ((4, 128, 256), "float32"),
        ((4, 256, 512), "float32"),
    ]
    verify_model(BAddBMM1(), input_info)
    verify_model(BAddBMM2(), input_info)


@requires_tensorrt
def test_relu():
    """test tensorrt translator for relu"""

    class ReLU(Module):
        def __init__(self):
            super().__init__()
            self.relu = torch.nn.ReLU()

        def forward(self, data):
            return self.relu(data)

    input_info = [([10, 10], "float32")]
    verify_model(ReLU(), input_info)


@requires_tensorrt
def test_relu6():
    """test tensorrt translator for relu6"""

    class ReLU6(Module):
        def __init__(self):
            super().__init__()
            self.relu6 = torch.nn.ReLU6()

        def forward(self, data):
            return self.relu6(data)

    input_info = [([10, 10], "float32")]
    verify_model(ReLU6(), input_info)


@requires_tensorrt
def test_maxpool2d():
    """test tensorrt translator for maxpool2d"""

    class MaxPool2d(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.MaxPool2d(kernel_size=[1, 1])

        def forward(self, data):
            return self.pool(data)

    class MaxPool2d2(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.MaxPool2d(kernel_size=[4, 4], padding=2, stride=2)

        def forward(self, data):
            return self.pool(data)

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(MaxPool2d(), input_info)
    verify_model(MaxPool2d2(), input_info)


@requires_tensorrt
def test_avgpool2d():
    """test tensorrt translator for avgpool2d"""

    class AvgPool2d(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.AvgPool2d(kernel_size=[1, 1])

        def forward(self, data):
            return self.pool(data)

    class AvgPool2d2(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.AvgPool2d(
                kernel_size=[4, 4], stride=2, padding=2, ceil_mode=True
            )

        def forward(self, data):
            return self.pool(data)

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(AvgPool2d(), input_info)
    verify_model(AvgPool2d2(), input_info)


@requires_tensorrt
def test_adaptive_avgpool2d():
    """test tensorrt translator for adaptive_avgpool2d"""

    class AdaptiveAvgPool2d0(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.AdaptiveAvgPool2d([10, 10])

        def forward(self, data):
            return self.pool(data)

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(AdaptiveAvgPool2d0(), input_info)


@requires_tensorrt
def test_flatten():
    """test tensorrt translator for flatten"""

    class Flatten(Module):
        def __init__(self):
            super().__init__()
            self.f = torch.nn.Flatten(2, -1)

        def forward(self, data):
            return self.f(data)

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(Flatten(), input_info)
    verify_model(torch.nn.Flatten(2, -1), input_info)


@requires_tensorrt
def test_batchnorm2d():
    """test tensorrt translator for batchnorm2d"""

    class BatchNorm2d(Module):
        def __init__(self):
            super().__init__()
            self.batchnorm = torch.nn.BatchNorm2d(3)

        def forward(self, data):
            return self.batchnorm(data)

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(BatchNorm2d().eval(), input_info)


@requires_tensorrt
def test_embedding():
    """test tensorrt translator for embedding"""

    class Embedding(Module):
        def __init__(self):
            super().__init__()
            self.embedding = torch.nn.Embedding(10, 3)

        def forward(self, data):
            return self.embedding(data)

    verify_model(Embedding(), [([4], "int64")], allow_incomplete=True)
    verify_model(Embedding(), [([4, 5], "int64")], allow_incomplete=True)


@requires_tensorrt
def test_layernorm():
    """test tensorrt translator for layernorm"""

    class LayerNorm(Module):
        def __init__(self):
            super().__init__()
            self.layernorm = torch.nn.LayerNorm((10, 10))

        def forward(self, data):
            return self.layernorm(data)

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(LayerNorm(), input_info)


@requires_tensorrt
def test_silu():
    """test tensorrt translator for silu"""

    class SiLU(Module):
        def __init__(self):
            super().__init__()
            self.silu = torch.nn.SiLU()

        def forward(self, data):
            return self.silu(data)

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(SiLU(), input_info)


@requires_tensorrt
def test_groupnorm():
    """test tensorrt translator for groupnorm"""

    class GroupNorm(Module):
        def __init__(self):
            super().__init__()
            self.groupnorm = torch.nn.GroupNorm(3, 3)

        def forward(self, data):
            return self.groupnorm(data)

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(GroupNorm(), input_info)


@requires_tensorrt
def test_softmax():
    """test tensorrt translator for softmax"""

    class Softmax(Module):
        def __init__(self):
            super().__init__()
            self.softmax = torch.nn.Softmax(dim=1)

        def forward(self, data):
            return self.softmax(data)

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(Softmax(), input_info)


@requires_tensorrt
def test_binary():
    """test tensorrt translator for binary"""

    input_info1 = [([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32")]
    input_info2 = [([1, 3, 10, 10], "float32")]

    # Add
    class Add1(Module):
        def forward(self, lhs, rhs):
            return lhs + rhs

    class Add2(Module):
        def forward(self, lhs):
            return lhs + 1.0

    verify_model(Add1(), input_info1)
    verify_model(Add2(), input_info2)

    # Sub
    class Sub1(Module):
        def forward(self, lhs, rhs):
            return lhs - rhs

    class Sub2(Module):
        def forward(self, lhs):
            return lhs - 1.0

    verify_model(Sub1(), input_info1)
    verify_model(Sub2(), input_info2)

    # Mul
    class Mul1(Module):
        def forward(self, lhs, rhs):
            return lhs * rhs

    class Mul2(Module):
        def forward(self, lhs):
            return lhs * 1.0

    verify_model(Mul1(), input_info1)
    verify_model(Mul2(), input_info2)

    # True div
    class TrueDiv1(Module):
        def forward(self, lhs, rhs):
            return lhs / rhs

    class TrueDiv2(Module):
        def forward(self, lhs):
            return lhs / 1.0

    verify_model(TrueDiv1(), input_info1)
    verify_model(TrueDiv2(), input_info2)

    # Floor div
    class FloorDiv1(Module):
        def forward(self, lhs, rhs):
            return lhs // rhs

    class FloorDiv2(Module):
        def forward(self, lhs):
            return lhs // 1.0

    verify_model(FloorDiv1(), input_info1)
    verify_model(FloorDiv2(), input_info2)

    # Power
    class Power1(Module):
        def forward(self, lhs, rhs):
            return lhs**rhs

    class Power2(Module):
        def forward(self, lhs):
            return lhs**1.0

    verify_model(Power1(), input_info1)
    verify_model(Power2(), input_info2)


@requires_tensorrt
def test_squeeze():
    """test tensorrt translator for squeeze"""

    class Squeeze1(Module):
        def forward(self, data):
            return data.squeeze(1)

    class Squeeze2(Module):
        def forward(self, data):
            return data.squeeze()

    input_info = [([3, 1, 4, 1], "float32")]
    verify_model(Squeeze1(), input_info)
    verify_model(Squeeze2(), input_info)


@requires_tensorrt
def test_unsqueeze():
    """test tensorrt translator for unsqueeze"""

    class Unsqueeze1(Module):
        def forward(self, data):
            return data.unsqueeze(1)

    class Unsqueeze2(Module):
        def forward(self, data):
            return data.unsqueeze(-1)

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(Unsqueeze1(), input_info)
    verify_model(Unsqueeze2(), input_info)


@requires_tensorrt
def test_getitem():
    """test tensorrt translator for getitem"""

    class Slice1(Module):
        def forward(self, x):
            return x[0:1, 1::2, :, :3]

    class Slice2(Module):
        def forward(self, x):
            return x[:, None, None, :, None]

    verify_model(Slice1(), [([1, 3, 10, 10], "float32")])
    verify_model(Slice2(), [([8, 16], "float32")])


@requires_tensorrt
def test_unary():
    """test tensorrt translator for unary"""

    input_info = [([1, 3, 10, 10], "float32")]

    # sin
    class Sin(Module):
        def forward(self, data):
            return torch.sin(data)

    verify_model(Sin(), input_info)

    # cos
    class Cos(Module):
        def forward(self, data):
            return torch.cos(data)

    verify_model(Cos(), input_info)

    # exp
    class Exp(Module):
        def forward(self, data):
            return torch.exp(data)

    verify_model(Exp(), input_info)

    # sqrt
    class Sqrt(Module):
        def forward(self, data):
            return torch.sqrt(data)

    verify_model(Sqrt(), input_info)

    # sigmoid
    class Sigmoid(Module):
        def forward(self, data):
            return torch.sigmoid(data)

    verify_model(Sigmoid(), input_info)

    # round
    class Round(Module):
        def forward(self, data):
            return torch.round(data)

    verify_model(Round(), input_info)


@requires_tensorrt
def test_tanh():
    """test tensorrt translator for tanh"""

    class Tanh(Module):
        def forward(self, data):
            return torch.tanh(data)

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(Tanh(), input_info)


@requires_tensorrt
def test_clamp():
    """test tensorrt translator for clamp"""

    class Clamp(Module):
        def forward(self, data):
            return torch.clamp(data, min=0.1, max=0.5)

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(Clamp(), input_info)


@requires_tensorrt
def test_interpolate():
    """test tensorrt translator for interpolate"""

    class Interpolate(Module):
        def forward(self, data):
            return torch.nn.functional.interpolate(data, (5, 5))

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(Interpolate(), input_info)


@requires_tensorrt
def test_addmm():
    """test tensorrt translator for addmm"""

    class Addmm(Module):
        def forward(self, x_1, x_2, x_3):
            return torch.addmm(x_1, x_2, x_3)

    input_info = [
        ([10, 10], "float32"),
        ([10, 10], "float32"),
        ([10, 10], "float32"),
    ]
    verify_model(Addmm(), input_info)


@requires_tensorrt
def test_split():
    """test tensorrt translator for split"""

    class Split(Module):
        def forward(self, data):
            return torch.split(data, 1, dim=1)

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(Split(), input_info)


@requires_tensorrt
def test_chunk():
    """test tensorrt translator for chunk"""

    class Chunk(Module):
        def forward(self, data):
            return torch.chunk(data, 3, dim=1)

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(Chunk(), input_info)


@requires_tensorrt
def test_expand():
    """test tensorrt translator for expand"""

    class Expand(Module):
        def forward(self, x):
            x = x + 1.0
            return x.expand(4, 2, 3, 4)

    input_info = [([1, 2, 3, 4], "float32")]
    verify_model(Expand(), input_info)


@requires_tensorrt
def test_reduce():
    """test tensorrt translator for reduce"""

    # sum
    class Sum(Module):
        def forward(self, x):
            return torch.sum(x, (2, 1))

    input_info = [([1, 2, 3, 4], "float32")]
    verify_model(Sum(), input_info)


@requires_tensorrt
def test_permute():
    """test tensorrt translator for permute"""

    class Permute(Module):
        def forward(self, x):
            return x.permute(0, 3, 2, 1)

    input_info = [([1, 2, 3, 4], "float32")]
    verify_model(Permute(), input_info)


@requires_tensorrt
def test_reshape():
    """test tensorrt translator for reshape"""

    class Reshape(Module):
        def forward(self, x):
            return x.reshape(2, 12)

    input_info = [([1, 2, 3, 4], "float32")]
    verify_model(Reshape(), input_info)


@requires_tensorrt
def test_transpose():
    """test tensorrt translator for transpose"""

    class Transpose(Module):
        def forward(self, x):
            return x.transpose(1, 3)

    input_info = [([1, 2, 3, 4], "float32")]
    verify_model(Transpose(), input_info)


@requires_tensorrt
def test_view():
    """test tensorrt translator for view"""

    class View(Module):
        def forward(self, x):
            return x.view(2, 12)

    input_info = [([1, 2, 3, 4], "float32")]
    verify_model(View(), input_info)


@requires_tensorrt
def test_argmax():
    """test tensorrt translator for argmax"""

    class Argmax1(Module):
        def forward(self, data):
            return torch.argmax(data, dim=-1)

    class Argmax2(Module):
        def forward(self, data):
            return torch.argmax(data, dim=-1, keepdim=True)

    verify_model(Argmax1(), [([256, 256], "float32")], allow_incomplete=True)
    verify_model(Argmax2(), [([256, 256], "float32")], allow_incomplete=True)


@requires_tensorrt
def test_argmin():
    """test tensorrt translator for argmin"""

    class Argmin1(Module):
        def forward(self, data):
            return torch.argmin(data, dim=-1)

    class Argmin2(Module):
        def forward(self, data):
            return torch.argmin(data, dim=-1, keepdim=True)

    verify_model(Argmin1(), [([256, 256], "float32")], allow_incomplete=True)
    verify_model(Argmin2(), [([256, 256], "float32")], allow_incomplete=True)


@requires_tensorrt
def test_mean():
    """test tensorrt translator for mean"""

    class Mean(Module):
        def forward(self, data):
            return data.mean(-1)

    class MeanKeepDim(Module):
        def forward(self, data):
            return data.mean(-1, keepdim=True)

    verify_model(Mean(), [([256, 256], "float32")])
    verify_model(MeanKeepDim(), [([256, 256], "float32")])


@requires_tensorrt
def test_rsqrt():
    """test tensorrt translator for rsqrt"""

    class Rsqrt(Module):
        def forward(self, data):
            return torch.rsqrt(data)

    verify_model(Rsqrt(), [([256, 256], "float32")])


@requires_tensorrt
def test_neg():
    """test tensorrt translator for neg"""

    class Neg(Module):
        def forward(self, data):
            return -data

    verify_model(Neg(), [([256, 256], "float32")])


@requires_tensorrt
def test_max():
    """test tensorrt translator for max"""

    class Max(Module):
        def forward(self, x, y):
            return torch.max(x, y)

    verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")])


if __name__ == "__main__":
    tvm.testing.main()