def verify_model(torch_model, input_info, via_relax=True):
    """Compare torch module results with the MSC-translated torch model.

    Translates `torch_model` to an MSC graph, regenerates a torch model from
    that graph, runs both on the same random inputs and asserts the outputs
    match element-wise.
    """
    # Translate torch -> MSC graph, then codegen the graph back to torch.
    graph, weights = translate.from_torch(torch_model, input_info, via_relax=via_relax)
    model = codegen.to_torch(graph, weights)
    # print(graph)
    # Random inputs shaped/typed per input_info entries ([shape, dtype]).
    torch_datas = [torch.from_numpy(np.random.rand(*i[0]).astype(i[1])) for i in input_info]
    with torch.no_grad():
        golden = torch_model(*torch_datas)
    with torch.no_grad():
        # A graph without inputs (e.g. constant-only) is called with no args.
        if not graph.get_inputs():
            result = model()
        else:
            result = model(*torch_datas)
    # Normalize both sides to sequences for uniform comparison.
    if not isinstance(golden, (list, tuple)):
        golden = [golden]
    if not isinstance(result, (list, tuple)):
        result = [result]
    assert len(golden) == len(result), "golden {} mismatch with result {}".format(
        len(golden), len(result)
    )
    for gol_r, new_r in zip(golden, result):
        if isinstance(gol_r, torch.Tensor):
            tvm.testing.assert_allclose(
                gol_r.detach().numpy(), new_r.detach().numpy(), atol=1e-5, rtol=1e-5
            )
        else:
            # Non-tensor outputs (e.g. sizes) are compared directly.
            assert gol_r == new_r
def test_conv1d():
    """test torch translator for conv1d"""

    # Conv1d with a bias term.
    class Conv1D1(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv1d(3, 6, 7, bias=True)

        def forward(self, data):
            return self.conv(data)

    # Conv1d without a bias term.
    class Conv1D2(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv1d(3, 6, 7, bias=False)

        def forward(self, data):
            return self.conv(data)

    input_info = [([1, 3, 10], "float32")]
    for via_relax in [True, False]:
        verify_model(Conv1D1(), input_info, via_relax)
        verify_model(Conv1D2(), input_info, via_relax)
# Script-style invocation: runs the conv1d translator test immediately.
test_conv1d()
/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/framework/torch/codegen/codegen.py:74: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(folder.relpath(graph.name + ".pth"))
/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/framework/torch/codegen/codegen.py:74: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(folder.relpath(graph.name + ".pth"))
def test_conv2d():
    """test torch translator for conv2d"""

    # Conv2d with a bias term.
    class Conv2D1(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 6, 7, bias=True)

        def forward(self, data):
            return self.conv(data)

    # Conv2d without a bias term.
    class Conv2D2(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 6, 7, bias=False)

        def forward(self, data):
            return self.conv(data)

    input_info = [([1, 3, 10, 10], "float32")]
    for via_relax in [True, False]:
        verify_model(Conv2D1(), input_info, via_relax)
        verify_model(Conv2D2(), input_info, via_relax)
# Script-style invocation: runs the conv2d translator test immediately.
test_conv2d()
def test_linear():
    """test torch translator for linear"""

    class Dense1(Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(10, 7, bias=True)

        def forward(self, data):
            return self.linear(data)

    class Dense2(Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(10, 7, bias=False)

        def forward(self, data):
            return self.linear(data)

    class MatMul1(Module):
        def forward(self, x, y):
            return torch.matmul(x, y)

    input_info = [([1, 3, 10, 10], "float32")]
    for via_relax in [True, False]:
        verify_model(Dense1(), input_info, via_relax)
        verify_model(Dense2(), input_info, via_relax)
        verify_model(MatMul1(), [([10, 10], "float32"), ([10, 10], "float32")], via_relax)


def test_bmm():
    """test torch translator for bmm"""

    class BMM(Module):
        def forward(self, x, y):
            return torch.bmm(x, y)

    input_info = [((4, 128, 256), "float32"), ((4, 256, 512), "float32")]
    for via_relax in [True, False]:
        verify_model(BMM(), input_info, via_relax)


def test_baddbmm():
    """test torch translator for baddbmm"""

    class BAddBMM1(Module):
        def forward(self, c, x, y):
            return torch.baddbmm(c, x, y)

    class BAddBMM2(Module):
        def forward(self, c, x, y):
            return torch.baddbmm(c, x, y, alpha=2, beta=0)

    input_info = [
        ((4, 128, 512), "float32"),
        ((4, 128, 256), "float32"),
        ((4, 256, 512), "float32"),
    ]
    for via_relax in [True, False]:
        verify_model(BAddBMM1(), input_info, via_relax)
        verify_model(BAddBMM2(), input_info, via_relax)


def test_relu():
    """test torch translator for relu"""

    class ReLU(Module):
        def __init__(self):
            super().__init__()
            self.relu = torch.nn.ReLU()

        def forward(self, data):
            return self.relu(data)

    class ReLU1(Module):
        def forward(self, data):
            return torch.nn.functional.relu(data)

    input_info = [([10, 10], "float32")]
    for via_relax in [True, False]:
        verify_model(ReLU(), input_info, via_relax)
        verify_model(ReLU1(), input_info, via_relax)


def test_relu6():
    """test torch translator for relu6"""

    class ReLU6(Module):
        def __init__(self):
            super().__init__()
            self.relu6 = torch.nn.ReLU6()

        def forward(self, data):
            return self.relu6(data)

    input_info = [([10, 10], "float32")]
    for via_relax in [True, False]:
        verify_model(ReLU6(), input_info, via_relax)


def test_maxpool2d():
    """test torch translator for maxpool2d"""

    class MaxPool2d(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.MaxPool2d(kernel_size=[1, 1])

        def forward(self, data):
            return self.pool(data)

    class MaxPool2d2(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.MaxPool2d(kernel_size=[2, 2], dilation=[2, 3])

        def forward(self, data):
            return self.pool(data)

    class MaxPool2d3(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.MaxPool2d(kernel_size=[4, 4], padding=2, stride=2)

        def forward(self, data):
            return self.pool(data)

    input_info = [([1, 3, 10, 10], "float32")]
    for via_relax in [True, False]:
        verify_model(MaxPool2d(), input_info, via_relax)
        verify_model(MaxPool2d2(), input_info, via_relax)
        verify_model(MaxPool2d3(), input_info, via_relax)


def test_avgpool2d():
    """test torch translator for avgpool2d"""

    class AvgPool2d(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.AvgPool2d(kernel_size=[1, 1])

        def forward(self, data):
            return self.pool(data)

    class AvgPool2d2(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.AvgPool2d(kernel_size=[4, 4], stride=2, padding=2, ceil_mode=True)

        def forward(self, data):
            return self.pool(data)

    input_info = [([1, 3, 10, 10], "float32")]
    for via_relax in [True, False]:
        verify_model(AvgPool2d(), input_info, via_relax)
        verify_model(AvgPool2d2(), input_info, via_relax)


def test_adaptive_avgpool2d():
    """test torch translator for adaptive_avgpool2d"""

    class AdaptiveAvgPool2d0(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.AdaptiveAvgPool2d([10, 10])

        def forward(self, data):
            return self.pool(data)

    input_info = [([1, 3, 10, 10], "float32")]
    for via_relax in [True, False]:
        verify_model(AdaptiveAvgPool2d0(), input_info, via_relax)


def test_flatten():
    """test torch translator for flatten"""

    class Flatten(Module):
        def __init__(self):
            super().__init__()
            self.f = torch.nn.Flatten(2, -1)

        def forward(self, data):
            return self.f(data)

    input_info = [([1, 3, 10, 10], "float32")]
    for via_relax in [True, False]:
        verify_model(Flatten(), input_info, via_relax)
        verify_model(torch.nn.Flatten(2, -1), input_info, via_relax)


def test_batchnorm2d():
    """test torch translator for batchnorm2d"""

    class BatchNorm2d(Module):
        def __init__(self):
            super().__init__()
            self.batchnorm = torch.nn.BatchNorm2d(3)

        def forward(self, data):
            return self.batchnorm(data)

    input_info = [([1, 3, 10, 10], "float32")]
    for via_relax in [True, False]:
        verify_model(BatchNorm2d(), input_info, via_relax)


def test_embedding():
    """test torch translator for embedding"""

    class Embedding(Module):
        def __init__(self):
            super().__init__()
            self.embedding = torch.nn.Embedding(10, 3)

        def forward(self, data):
            return self.embedding(data)

    for via_relax in [True, False]:
        verify_model(Embedding(), [([4], "int64")], via_relax)
        verify_model(Embedding(), [([4, 5], "int64")], via_relax)


def test_layernorm():
    """test torch translator for layernorm"""

    class LayerNorm(Module):
        def __init__(self):
            super().__init__()
            self.layernorm = torch.nn.LayerNorm((10, 10))

        def forward(self, data):
            return self.layernorm(data)

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(LayerNorm(), input_info)


def test_cross_entropy():
    """test torch translator for cross_entropy"""

    class CrossEntropy1(Module):
        def __init__(self):
            super().__init__()
            self.loss = torch.nn.CrossEntropyLoss()

        def forward(self, logits, targets):
            return self.loss(logits, targets)

    class CrossEntropy2(Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones((2,)))
            self.loss = torch.nn.CrossEntropyLoss(weight=self.weight)

        def forward(self, logits, targets):
            return self.loss(logits, targets)

    class CrossEntropy3(Module):
        def __init__(self):
            super().__init__()
            self.loss = torch.nn.CrossEntropyLoss(ignore_index=1, reduction="sum")

        def forward(self, logits, targets):
            return self.loss(logits, targets)

    input_info = [([3, 2], "float32"), ([3], "int64")]
    for via_relax in [True, False]:
        verify_model(CrossEntropy1(), input_info, via_relax)
        verify_model(CrossEntropy2(), input_info, via_relax)
        verify_model(CrossEntropy3(), input_info, via_relax)


def test_silu():
    """test torch translator for silu"""

    class SiLU(Module):
        def __init__(self):
            super().__init__()
            self.silu = torch.nn.SiLU()

        def forward(self, data):
            return self.silu(data)

    class SiLU2(Module):
        def forward(self, data):
            return torch.nn.functional.silu(data)

    input_info = [([1, 3, 10, 10], "float32")]
    for via_relax in [True, False]:
        verify_model(SiLU(), input_info, via_relax)
        verify_model(SiLU2(), input_info, via_relax)


def test_groupnorm():
    """test torch translator for groupnorm"""

    class GroupNorm(Module):
        def __init__(self):
            super().__init__()
            self.groupnorm = torch.nn.GroupNorm(3, 3)

        def forward(self, data):
            return self.groupnorm(data)

    input_info = [([1, 3, 10, 10], "float32")]
    for via_relax in [True, False]:
        verify_model(GroupNorm(), input_info, via_relax)


def test_softmax():
    """test torch translator for softmax"""

    class Softmax(Module):
        def __init__(self):
            super().__init__()
            self.softmax = torch.nn.Softmax(dim=1)

        def forward(self, data):
            return self.softmax(data)

    input_info = [([1, 3, 10, 10], "float32")]
    for via_relax in [True, False]:
        verify_model(Softmax(), input_info, via_relax)


def test_binary():
    """test torch translator for binary"""
    input_info1 = [([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32")]
    input_info2 = [([1, 3, 10, 10], "float32")]

    # Add
    class Add1(Module):
        def forward(self, lhs, rhs):
            return lhs + rhs

    class Add2(Module):
        def forward(self, lhs):
            return lhs + 1.0

    for via_relax in [True, False]:
        verify_model(Add1(), input_info1, via_relax)
        verify_model(Add2(), input_info2, via_relax)

    # Sub
    class Sub1(Module):
        def forward(self, lhs, rhs):
            return lhs - rhs

    class Sub2(Module):
        def forward(self, lhs):
            return lhs - 1.0

    for via_relax in [True, False]:
        verify_model(Sub1(), input_info1, via_relax)
        verify_model(Sub2(), input_info2, via_relax)

    # Mul
    class Mul1(Module):
        def forward(self, lhs, rhs):
            return lhs * rhs

    class Mul2(Module):
        def forward(self, lhs):
            return lhs * 1.0

    for via_relax in [True, False]:
        verify_model(Mul1(), input_info1, via_relax)
        verify_model(Mul2(), input_info2, via_relax)

    # True div
    class TrueDiv1(Module):
        def forward(self, lhs, rhs):
            return lhs / rhs

    class TrueDiv2(Module):
        def forward(self, lhs):
            return lhs / 1.0

    for via_relax in [True, False]:
        verify_model(TrueDiv1(), input_info1, via_relax)
        verify_model(TrueDiv2(), input_info2, via_relax)

    # Floor div
    class FloorDiv1(Module):
        def forward(self, lhs, rhs):
            return lhs // rhs

    class FloorDiv2(Module):
        def forward(self, lhs):
            return lhs // 1.0

    for via_relax in [True, False]:
        verify_model(FloorDiv1(), input_info1, via_relax)
        verify_model(FloorDiv2(), input_info2, via_relax)

    # Power
    class Power1(Module):
        def forward(self, lhs, rhs):
            return lhs**rhs

    class Power2(Module):
        def forward(self, lhs):
            return lhs**1.0

    for via_relax in [True, False]:
        verify_model(Power1(), input_info1, via_relax)
        verify_model(Power2(), input_info2, via_relax)

    # LT
    class LT1(Module):
        def forward(self, lhs, rhs):
            return lhs < rhs

    class LT2(Module):
        def forward(self, lhs):
            return lhs < 1.0

    for via_relax in [True, False]:
        verify_model(LT1(), input_info1, via_relax)
        verify_model(LT2(), input_info2, via_relax)


def test_size():
    """test torch translator for size"""

    class Size(Module):
        def forward(self, data):
            return data.size()

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(Size(), input_info)


def test_squeeze():
    """test torch translator for squeeze"""

    class Squeeze1(Module):
        def forward(self, data):
            return data.squeeze(1)

    class Squeeze2(Module):
        def forward(self, data):
            return data.squeeze()

    input_info = [([3, 1, 4, 1], "float32")]
    for via_relax in [True, False]:
        verify_model(Squeeze1(), input_info, via_relax)
        verify_model(Squeeze2(), input_info, via_relax)


def test_unsqueeze():
    """test torch translator for unsqueeze"""

    class Unsqueeze1(Module):
        def forward(self, data):
            return data.unsqueeze(1)

    class Unsqueeze2(Module):
        def forward(self, data):
            return data.unsqueeze(-1)

    input_info = [([1, 3, 10, 10], "float32")]
    for via_relax in [True, False]:
        verify_model(Unsqueeze1(), input_info, via_relax)
        verify_model(Unsqueeze2(), input_info, via_relax)


def test_getattr():
    """test torch translator for getattr"""

    class GetAttr1(Module):
        def forward(self, data):
            return data.shape

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(GetAttr1(), input_info)


@pytest.mark.xfail(reason="MSC does not support Tuple of PrimValue")
def test_getitem():
    """test torch translator for getitem"""

    # TODO(tong.meng): strided_slice reshape bug for x[0, 1::2, :, :3]
    class Slice1(Module):
        def forward(self, x):
            return x[0:1, 1::2, :, :3]

    class Slice2(Module):
        def forward(self, x):
            return x[:, None, None, :, None]

    for via_relax in [True, False]:
        verify_model(Slice1(), [([1, 3, 10, 10], "float32")], via_relax)
        verify_model(Slice2(), [([8, 16], "float32")], via_relax)


def test_unary():
    """test torch translator for unary"""
    input_info = [([1, 3, 10, 10], "float32")]

    # sin
    class Sin(Module):
        def forward(self, data):
            return torch.sin(data)

    for via_relax in [True, False]:
        verify_model(Sin(), input_info, via_relax)

    # cos
    class Cos(Module):
        def forward(self, data):
            return torch.cos(data)

    for via_relax in [True, False]:
        verify_model(Cos(), input_info, via_relax)

    # exp
    class Exp(Module):
        def forward(self, data):
            return torch.exp(data)

    for via_relax in [True, False]:
        verify_model(Exp(), input_info, via_relax)

    # sqrt
    class Sqrt(Module):
        def forward(self, data):
            return torch.sqrt(data)

    for via_relax in [True, False]:
        verify_model(Sqrt(), input_info, via_relax)

    # sigmoid
    class Sigmoid(Module):
        def forward(self, data):
            return torch.sigmoid(data)

    for via_relax in [True, False]:
        verify_model(Sigmoid(), input_info, via_relax)

    # round
    class Round(Module):
        def forward(self, data):
            return torch.round(data)

    for via_relax in [True, False]:
        verify_model(Round(), input_info, via_relax)


def test_gelu():
    """test torch translator for gelu"""

    class Gelu(Module):
        def forward(self, data):
            return torch.nn.functional.gelu(data)

    input_info = [([1, 3, 10, 10], "float32")]
    for via_relax in [True, False]:
        verify_model(Gelu(), input_info, via_relax)


def test_tanh():
    """test torch translator for tanh"""

    class Tanh(Module):
        def forward(self, data):
            return torch.tanh(data)

    input_info = [([1, 3, 10, 10], "float32")]
    for via_relax in [True, False]:
        verify_model(Tanh(), input_info, via_relax)


def test_clamp():
    """test torch translator for clamp"""

    class Clamp(Module):
        def forward(self, data):
            return torch.clamp(data, min=0.1, max=0.5)

    input_info = [([1, 3, 10, 10], "float32")]
    for via_relax in [True, False]:
        verify_model(Clamp(), input_info, via_relax)


def test_interpolate():
    """test torch translator for interpolate"""

    class Interpolate(Module):
        def forward(self, data):
            return torch.nn.functional.interpolate(data, (5, 5))

    input_info = [([1, 3, 10, 10], "float32")]
    for via_relax in [True, False]:
        verify_model(Interpolate(), input_info, via_relax)


def test_addmm():
    """test torch translator for addmm"""

    class Addmm(Module):
        def forward(self, x_1, x_2, x_3):
            return torch.addmm(x_1, x_2, x_3)

    input_info = [
        ([10, 10], "float32"),
        ([10, 10], "float32"),
        ([10, 10], "float32"),
    ]
    for via_relax in [True, False]:
        verify_model(Addmm(), input_info, via_relax)


def test_split():
    """test torch translator for split"""

    class Split(Module):
        def forward(self, data):
            return torch.split(data, 1, dim=1)

    input_info = [([1, 3, 10, 10], "float32")]
    for via_relax in [True, False]:
        verify_model(Split(), input_info, via_relax)


def test_cumsum():
    """test torch translator for cumsum"""

    class Cumsum(Module):
        def forward(self, data):
            return torch.cumsum(data, dim=1, dtype=torch.int32)

    input_info = [([1, 2, 3, 4], "float32")]
    for via_relax in [True, False]:
        verify_model(Cumsum(), input_info, via_relax)


def test_chunk():
    """test torch translator for chunk"""

    class Chunk(Module):
        def forward(self, data):
            return torch.chunk(data, 3, dim=1)

    input_info = [([1, 3, 10, 10], "float32")]
    for via_relax in [True, False]:
        verify_model(Chunk(), input_info, via_relax)


def test_inplace_fill():
    """test torch translator for inplace_fill"""

    class InplaceFill(Module):
        def forward(self, data):
            data.fill_(1.5)
            return data

    for via_relax in [True, False]:
        verify_model(InplaceFill(), [([10, 10], "float32")], via_relax)


def test_arange():
    """test torch translator for arange"""

    # pylint: disable=unused-argument
    class Arange(Module):
        def forward(self, data):
            return torch.arange(0, 20, dtype=torch.int32)

    verify_model(Arange(), [([10, 10], "float32")])


def test_tril():
    """test torch translator for tril"""

    class Tril(Module):
        def forward(self, data):
            return torch.tril(data, 1)

    class InplaceTril(Module):
        def forward(self, data):
            data.tril_(1)
            return data

    input_info = [([10, 10], "float32")]
    for via_relax in [True, False]:
        verify_model(Tril(), input_info, via_relax)
        verify_model(InplaceTril(), input_info, via_relax)


def test_triu():
    """test torch translator for triu"""

    class Triu(Module):
        def forward(self, data):
            return torch.triu(data, 1)

    class InplaceTriu(Module):
        def forward(self, data):
            data.triu_(1)
            return data

    input_info = [([10, 10], "float32")]
    for via_relax in [True, False]:
        verify_model(Triu(), input_info, via_relax)
        verify_model(InplaceTriu(), input_info, via_relax)


def test_new_ones():
    """test torch translator for new_ones"""

    class NewOnes(Module):
        def forward(self, x):
            return x.new_ones(1, 2, 3)

    input_info = [([1, 2, 3], "float32")]
    for via_relax in [True, False]:
        verify_model(NewOnes(), input_info, via_relax)


def test_expand():
    """test torch translator for expand"""

    class Expand(Module):
        def forward(self, x):
            return x.expand(4, 2, 3, 4)

    input_info = [([1, 2, 3, 4], "float32")]
    for via_relax in [True, False]:
        verify_model(Expand(), input_info, via_relax)


def test_reduce():
    """test torch translator for reduce"""

    # sum
    class Sum(Module):
        def forward(self, x):
            return torch.sum(x, (2, 1))

    # max
    class Max(Module):
        def forward(self, x):
            return torch.max(x)

    # min
    class Min(Module):
        def forward(self, x):
            return torch.min(x)

    input_info = [([1, 2, 3, 4], "float32")]
    for via_relax in [True, False]:
        verify_model(Sum(), input_info, via_relax)
        # NOTE(review): Max/Min are always verified with via_relax=False here,
        # even inside the loop — looks intentional; confirm with the translator.
        verify_model(Max(), input_info, False)
        verify_model(Min(), input_info, False)


def test_datatype():
    """test torch translator for datatype"""
    input_info = [([1, 2, 3, 4], "float32")]

    # float
    class ToFloat(Module):
        def forward(self, x):
            return x.float()

    for via_relax in [True, False]:
        verify_model(ToFloat(), input_info, via_relax)

    # half
    class ToHalf(Module):
        def forward(self, x):
            return x.half()

    for via_relax in [True, False]:
        verify_model(ToHalf(), input_info, via_relax)

    # type
    class Type(Module):
        def forward(self, x):
            return x.type(torch.float32)

    for via_relax in [True, False]:
        verify_model(Type(), input_info, via_relax)


def test_permute():
    """test torch translator for permute"""

    class Permute(Module):
        def forward(self, x):
            return x.permute(0, 3, 2, 1)

    input_info = [([1, 2, 3, 4], "float32")]
    for via_relax in [True, False]:
        verify_model(Permute(), input_info, via_relax)


def test_reshape():
    """test torch translator for reshape"""

    class Reshape(Module):
        def forward(self, x):
            return x.reshape(2, 12)

    input_info = [([1, 2, 3, 4], "float32")]
    for via_relax in [True, False]:
        verify_model(Reshape(), input_info, via_relax)


def test_transpose():
    """test torch translator for transpose"""

    class Transpose(Module):
        def forward(self, x):
            return x.transpose(1, 3)

    input_info = [([1, 2, 3, 4], "float32")]
    for via_relax in [True, False]:
        verify_model(Transpose(), input_info, via_relax)


def test_view():
    """test torch translator for view"""

    class View(Module):
        def forward(self, x):
            return x.view(2, 12)

    input_info = [([1, 2, 3, 4], "float32")]
    for via_relax in [True, False]:
        verify_model(View(), input_info, via_relax)


def test_keep_params():
    """test torch translator for keep_params"""

    class Conv2D1(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 6, 7, bias=True)

        def forward(self, data):
            return self.conv(data)

    for via_relax in [True, False]:
        verify_model(Conv2D1(), [([1, 3, 10, 10], "float32")], via_relax)


def test_unwrap_unit_return_tuple():
    """test torch translator for unwrap_unit_return_tuple"""

    class Identity(Module):
        def forward(self, x):
            return (x,)

    for via_relax in [True, False]:
        verify_model(Identity(), [([256, 256], "float32")], via_relax)


def test_no_bind_return_tuple():
    """test torch translator for no_bind_return_tuple"""

    class Identity(Module):
        def forward(self, x, y):
            return (x, y)

    input_info = [([256, 256], "float32"), ([256, 256], "float32")]
    for via_relax in [True, False]:
        verify_model(Identity(), input_info, via_relax)


def test_argmax():
    """test torch translator for argmax"""

    class Argmax1(Module):
        def forward(self, data):
            return torch.argmax(data, dim=-1)

    class Argmax2(Module):
        def forward(self, data):
            return torch.argmax(data, dim=-1, keepdim=True)

    for via_relax in [True, False]:
        verify_model(Argmax1(), [([256, 256], "float32")], via_relax)
        verify_model(Argmax2(), [([256, 256], "float32")], via_relax)


def test_argmin():
    """test torch translator for argmin"""

    class Argmin1(Module):
        def forward(self, data):
            return torch.argmin(data)

    class Argmin2(Module):
        def forward(self, data):
            return torch.argmin(data, keepdim=True)

    verify_model(Argmin1(), [([256, 256], "float32")])
    verify_model(Argmin2(), [([256, 256], "float32")])


def test_to():
    """test torch translator for to"""

    class To1(Module):
        def forward(self, data):
            return data.to(torch.float16)

    class To2(Module):
        def forward(self, data):
            return data.to("cpu")

    for via_relax in [True, False]:
        verify_model(To1(), [([256, 256], "float32")], via_relax)
        verify_model(To2(), [([256, 256], "float32")], via_relax)


def test_mean():
    """test torch translator for mean"""

    class Mean(Module):
        def forward(self, data):
            return data.mean(-1)

    class MeanKeepDim(Module):
        def forward(self, data):
            return data.mean(-1, keepdim=True)

    for via_relax in [True, False]:
        verify_model(Mean(), [([256, 256], "float32")], via_relax)
        verify_model(MeanKeepDim(), [([256, 256], "float32")], via_relax)


def test_rsqrt():
    """test torch translator for rsqrt"""

    class Rsqrt(Module):
        def forward(self, data):
            return torch.rsqrt(data)

    for via_relax in [True, False]:
        verify_model(Rsqrt(), [([256, 256], "float32")], via_relax)


def test_neg():
    """test torch translator for neg"""

    class Neg(Module):
        def forward(self, data):
            return -data

    for via_relax in [True, False]:
        verify_model(Neg(), [([256, 256], "float32")], via_relax)


def test_max():
    """test torch translator for max"""

    class Max(Module):
        def forward(self, x, y):
            return torch.max(x, y)

    for via_relax in [True, False]:
        verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")], via_relax)


def test_attention():
    """test torch translator for attention"""
    # pylint: disable=import-outside-toplevel
    import torch.nn.functional as F

    class Attention1(Module):
        def forward(self, q_data, k_data, v_data):
            return F.scaled_dot_product_attention(q_data, k_data, v_data)

    class Attention2(Module):
        def forward(self, q_data, k_data, v_data):
            return F.scaled_dot_product_attention(q_data, k_data, v_data, is_causal=True)

    input_info = [
        ([32, 8, 128, 64], "float32"),
        ([32, 8, 128, 64], "float32"),
        ([32, 8, 128, 64], "float32"),
    ]
    verify_model(Attention1(), input_info)
    verify_model(Attention2(), input_info)

    class Attention3(Module):
        def forward(self, q_data, k_data, v_data, mask):
            return F.scaled_dot_product_attention(q_data, k_data, v_data, mask)

    verify_model(
        Attention3(),
        [
            ([32, 8, 128, 64], "float32"),
            ([32, 8, 128, 64], "float32"),
            ([32, 8, 128, 64], "float32"),
            ([32, 8, 128, 128], "float32"),
        ],
    )


if __name__ == "__main__":
    tvm.testing.main()