# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

""" Test translate from relax. """

import pytest

import torch
from torch import fx
from torch.nn import Module

import numpy as np

import tvm.testing
from tvm.relax.frontend.torch import from_fx
from tvm.contrib.msc.core.frontend import translate
from tvm.contrib.msc.framework.tvm import codegen as tvm_codegen


def _verify_model(torch_model, input_info, opt_config=None):
    """Check that an MSC relax -> MSC graph -> relax round trip preserves results.

    The torch model is traced with ``torch.fx`` and imported into relax with
    ``from_fx``. That module is translated by MSC (``translate.from_relax``)
    and regenerated as relax (``tvm_codegen.to_relax``). Both modules are
    legalized, built for ``llvm`` and run on the CPU with identical random
    inputs, and their outputs are compared with ``tvm.testing.assert_allclose``.

    Parameters
    ----------
    torch_model: torch.nn.Module
        The model to trace.
    input_info: list of (shape, dtype)
        Input description passed to ``from_fx``; also used to build the inputs.
    opt_config: dict, optional
        Forwarded unchanged to ``translate.from_relax``.
    """

    graph_model = fx.symbolic_trace(torch_model)
    with torch.no_grad():
        orig_mod = from_fx(graph_model, input_info)

    target = "llvm"
    dev = tvm.cpu()
    # Fixed, local seed: failures are reproducible and the global numpy RNG
    # state is left untouched.
    rng = np.random.default_rng(0)
    args = [tvm.nd.array(rng.random(size=shape).astype(dtype)) for shape, dtype in input_info]

    def _tvm_runtime_to_np(obj):
        """Recursively convert VM outputs into numpy arrays / python containers."""
        if isinstance(obj, tvm.runtime.NDArray):
            return obj.numpy()
        elif isinstance(obj, tvm.runtime.ShapeTuple):
            return np.array(obj, dtype="int64")
        elif isinstance(obj, (list, tvm.ir.container.Array)):
            return [_tvm_runtime_to_np(item) for item in obj]
        elif isinstance(obj, tuple):
            return tuple(_tvm_runtime_to_np(item) for item in obj)
        else:
            return obj

    def _run_relax(relax_mod):
        """Legalize, build and run ``main`` of a relax module on the shared args."""
        relax_mod = tvm.relax.transform.LegalizeOps()(relax_mod)
        relax_exec = tvm.relax.build(relax_mod, target)
        vm_runner = tvm.relax.VirtualMachine(relax_exec, dev)
        res = vm_runner["main"](*args)
        return _tvm_runtime_to_np(res)

    rt_mod = tvm_codegen.to_relax(
        *translate.from_relax(orig_mod, opt_config=opt_config),
        codegen_config={"explicit_name": False},
    )

    orig_output = _run_relax(orig_mod)
    rt_output = _run_relax(rt_mod)
    tvm.testing.assert_allclose(orig_output, rt_output)


def test_conv1d():
    """test relax translator for conv1d"""

    class Conv1D1(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv1d(3, 6, 7, bias=True)

        def forward(self, data):
            return self.conv(data)

    class Conv1D2(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv1d(3, 6, 7, bias=False)

        def forward(self, data):
            return self.conv(data)

    input_info = [([1, 3, 10], "float32")]
    _verify_model(Conv1D1(), input_info)
    _verify_model(Conv1D2(), input_info)


def test_conv2d():
    """test relax translator for conv2d"""

    class Conv2D1(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 6, 7, bias=True)

        def forward(self, data):
            return self.conv(data)

    class Conv2D2(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 6, 7, bias=False)

        def forward(self, data):
            return self.conv(data)

    input_info = [([1, 3, 10, 10], "float32")]
    _verify_model(Conv2D1(), input_info)
    _verify_model(Conv2D2(), input_info)


def test_linear():
    """test relax translator for linear"""

    class Dense1(Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(10, 7, bias=True)

        def forward(self, data):
            return self.linear(data)

    class Dense2(Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(10, 7, bias=False)

        def forward(self, data):
            return self.linear(data)

    class MatMul1(Module):
        def forward(self, x, y):
            return torch.matmul(x, y)

    input_info = [([1, 3, 10, 10], "float32")]
    _verify_model(Dense1(), input_info)
    _verify_model(Dense2(), input_info)
    _verify_model(MatMul1(), [([10, 10], "float32"), ([10, 10], "float32")])


def test_bmm():
    """test relax translator for bmm"""

    class BMM(Module):
        def forward(self, x, y):
            return torch.bmm(x, y)

    input_info = [((4, 128, 256), "float32"), ((4, 256, 512), "float32")]
    _verify_model(BMM(), input_info)


def test_baddbmm():
    """test relax translator for baddbmm"""

    class BAddBMM1(Module):
        def forward(self, c, x, y):
            return torch.baddbmm(c, x, y)

    class BAddBMM2(Module):
        def forward(self, c, x, y):
            return torch.baddbmm(c, x, y, alpha=2, beta=0)

    input_info = [
        ((4, 128, 512), "float32"),
        ((4, 128, 256), "float32"),
        ((4, 256, 512), "float32"),
    ]
    _verify_model(BAddBMM1(), input_info)
    _verify_model(BAddBMM2(), input_info)


def test_relu():
    """test relax translator for relu"""

    class ReLU(Module):
        def __init__(self):
            super().__init__()
            self.relu = torch.nn.ReLU()

        def forward(self, data):
            return self.relu(data)

    class ReLU1(Module):
        def forward(self, data):
            return torch.nn.functional.relu(data)

    input_info = [([10, 10], "float32")]
    _verify_model(ReLU(), input_info)
    _verify_model(ReLU1(), input_info)


def test_relu6():
    """test relax translator for relu6"""

    class ReLU6(Module):
        def __init__(self):
            super().__init__()
            self.relu6 = torch.nn.ReLU6()

        def forward(self, data):
            return self.relu6(data)

    input_info = [([10, 10], "float32")]
    _verify_model(ReLU6(), input_info)


def test_maxpool2d():
    """test relax translator for maxpool2d"""

    class MaxPool2d(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.MaxPool2d(kernel_size=[1, 1])

        def forward(self, data):
            return self.pool(data)

    class MaxPool2d2(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.MaxPool2d(kernel_size=[2, 2], dilation=[2, 3])

        def forward(self, data):
            return self.pool(data)

    class MaxPool2d3(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.MaxPool2d(kernel_size=[4, 4], padding=2, stride=2)

        def forward(self, data):
            return self.pool(data)

    input_info = [([1, 3, 10, 10], "float32")]
    _verify_model(MaxPool2d(), input_info)
    _verify_model(MaxPool2d2(), input_info)
    _verify_model(MaxPool2d3(), input_info)


def test_avgpool2d():
    """test relax translator for avgpool2d"""

    class AvgPool2d(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.AvgPool2d(kernel_size=[1, 1])

        def forward(self, data):
            return self.pool(data)

    class AvgPool2d2(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.AvgPool2d(kernel_size=[4, 4], stride=2, padding=2, ceil_mode=True)

        def forward(self, data):
            return self.pool(data)

    input_info = [([1, 3, 10, 10], "float32")]
    _verify_model(AvgPool2d(), input_info)
    _verify_model(AvgPool2d2(), input_info)


def test_adaptive_avgpool2d():
    """test relax translator for adaptive_avgpool2d"""

    class AdaptiveAvgPool2d0(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.AdaptiveAvgPool2d([10, 10])

        def forward(self, data):
            return self.pool(data)

    input_info = [([1, 3, 10, 10], "float32")]
    _verify_model(AdaptiveAvgPool2d0(), input_info)


def test_flatten():
    """test relax translator for flatten"""

    class Flatten(Module):
        def __init__(self):
            super().__init__()
            self.f = torch.nn.Flatten(2, -1)

        def forward(self, data):
            return self.f(data)

    input_info = [([1, 3, 10, 10], "float32")]
    _verify_model(Flatten(), input_info)
    _verify_model(torch.nn.Flatten(2, -1), input_info)


def test_batchnorm2d():
    """test relax translator for batchnorm2d"""

    class BatchNorm2d(Module):
        def __init__(self):
            super().__init__()
            self.batchnorm = torch.nn.BatchNorm2d(3)

        def forward(self, data):
            return self.batchnorm(data)

    input_info = [([1, 3, 10, 10], "float32")]
    _verify_model(BatchNorm2d(), input_info)


def test_embedding():
    """test relax translator for embedding"""

    class Embedding(Module):
        def __init__(self):
            super().__init__()
            self.embedding = torch.nn.Embedding(10, 3)

        def forward(self, data):
            return self.embedding(data)

    _verify_model(Embedding(), [([4], "int64")])
    _verify_model(Embedding(), [([4, 5], "int64")])


def test_dropout():
    """test relax translator for dropout"""

    class Dropout1(Module):
        def __init__(self):
            super().__init__()
            self.dropout = torch.nn.Dropout(0.5)

        def forward(self, data):
            return self.dropout(data)

    class Dropout2(Module):
        def forward(self, data):
            return torch.dropout(data, 0.5, train=True)

    input_info = [([1, 3, 10, 10], "float32")]
    _verify_model(Dropout1(), input_info)
    _verify_model(Dropout2(), input_info)


def test_layernorm():
    """test relax translator for layernorm"""

    class LayerNorm(Module):
        def __init__(self):
            super().__init__()
            self.layernorm = torch.nn.LayerNorm((10, 10))

        def forward(self, data):
            return self.layernorm(data)

    input_info = [([1, 3, 10, 10], "float32")]
    _verify_model(LayerNorm(), input_info)


def test_functional_layernorm():
    """test relax translator for functional_layernorm"""

    class LayerNorm(Module):
        def __init__(self, shape):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones(shape))
            self.bias = torch.nn.Parameter(torch.zeros(shape))

        def forward(self, data):
            return torch.nn.functional.layer_norm(
                data, self.weight.shape, self.weight, self.bias, 1e-5
            )

    input_info = [([1, 3, 10, 10], "float32")]
    _verify_model(LayerNorm((10, 10)), input_info)


def test_cross_entropy():
    """test relax translator for cross_entropy"""

    class CrossEntropy1(Module):
        def __init__(self):
            super().__init__()
            self.loss = torch.nn.CrossEntropyLoss()

        def forward(self, logits, targets):
            return self.loss(logits, targets)

    class CrossEntropy2(Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones((2,)))
            self.loss = torch.nn.CrossEntropyLoss(weight=self.weight)

        def forward(self, logits, targets):
            return self.loss(logits, targets)

    class CrossEntropy3(Module):
        def __init__(self):
            super().__init__()
            self.loss = torch.nn.CrossEntropyLoss(ignore_index=1, reduction="sum")

        def forward(self, logits, targets):
            return self.loss(logits, targets)

    input_info = [([3, 2], "float32"), ([3], "int32")]
    _verify_model(CrossEntropy1(), input_info)
    _verify_model(CrossEntropy2(), input_info)
    _verify_model(CrossEntropy3(), input_info)


def test_functional_cross_entropy():
    """test relax translator for functional_cross_entropy"""

    class CrossEntropy(Module):
        def forward(self, logits, targets):
            return torch.nn.functional.cross_entropy(logits, targets)

    input_info = [([3, 10], "float32"), ([3], "int32")]
    _verify_model(CrossEntropy(), input_info)


def test_silu():
    """test relax translator for silu"""

    class SiLU(Module):
        def __init__(self):
            super().__init__()
            self.silu = torch.nn.SiLU()

        def forward(self, data):
            return self.silu(data)

    class SiLU2(Module):
        def forward(self, data):
            return torch.nn.functional.silu(data)

    input_info = [([1, 3, 10, 10], "float32")]
    _verify_model(SiLU(), input_info)
    _verify_model(SiLU2(), input_info)


def test_groupnorm():
    """test relax translator for groupnorm"""

    class GroupNorm(Module):
        def __init__(self):
            super().__init__()
            self.groupnorm = torch.nn.GroupNorm(3, 3)

        def forward(self, data):
            return self.groupnorm(data)

    input_info = [([1, 3, 10, 10], "float32")]
    _verify_model(GroupNorm(), input_info)


def test_softmax():
    """test relax translator for softmax"""

    class Softmax(Module):
        def __init__(self):
            super().__init__()
            self.softmax = torch.nn.Softmax(dim=1)

        def forward(self, data):
            return self.softmax(data)

    input_info = [([1, 3, 10, 10], "float32")]
    _verify_model(Softmax(), input_info)


def test_binary():
    """test relax translator for binary"""

    input_info1 = [([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32")]
    input_info2 = [([1, 3, 10, 10], "float32")]

    # Add
    class Add1(Module):
        def forward(self, lhs, rhs):
            return lhs + rhs

    class Add2(Module):
        def forward(self, lhs):
            return lhs + 1.0

    _verify_model(Add1(), input_info1)
    _verify_model(Add2(), input_info2)

    # Sub
    class Sub1(Module):
        def forward(self, lhs, rhs):
            return lhs - rhs

    class Sub2(Module):
        def forward(self, lhs):
            return lhs - 1.0

    _verify_model(Sub1(), input_info1)
    _verify_model(Sub2(), input_info2)

    # Mul
    class Mul1(Module):
        def forward(self, lhs, rhs):
            return lhs * rhs

    class Mul2(Module):
        def forward(self, lhs):
            return lhs * 1.0

    _verify_model(Mul1(), input_info1)
    _verify_model(Mul2(), input_info2)

    # True div
    class TrueDiv1(Module):
        def forward(self, lhs, rhs):
            return lhs / rhs

    class TrueDiv2(Module):
        def forward(self, lhs):
            return lhs / 1.0

    _verify_model(TrueDiv1(), input_info1)
    _verify_model(TrueDiv2(), input_info2)

    # Floor div
    class FloorDiv1(Module):
        def forward(self, lhs, rhs):
            return lhs // rhs

    class FloorDiv2(Module):
        def forward(self, lhs):
            return lhs // 1.0

    _verify_model(FloorDiv1(), input_info1)
    _verify_model(FloorDiv2(), input_info2)

    # Power
    class Power1(Module):
        def forward(self, lhs, rhs):
            return lhs**rhs

    class Power2(Module):
        def forward(self, lhs):
            return lhs**1.0

    _verify_model(Power1(), input_info1)
    _verify_model(Power2(), input_info2)

    # LT
    class LT1(Module):
        def forward(self, lhs, rhs):
            return lhs < rhs

    class LT2(Module):
        def forward(self, lhs):
            return lhs < 1.0

    _verify_model(LT1(), input_info1)
    _verify_model(LT2(), input_info2)


def test_size():
    """test relax translator for size"""

    class Size(Module):
        def forward(self, data):
            return data.size()

    input_info = [([1, 3, 10, 10], "float32")]
    _verify_model(Size(), input_info)


def test_squeeze():
    """test relax translator for squeeze"""

    class Squeeze1(Module):
        def forward(self, data):
            return data.squeeze(1)

    class Squeeze2(Module):
        def forward(self, data):
            return data.squeeze()

    input_info = [([3, 1, 4, 1], "float32")]
    _verify_model(Squeeze1(), input_info)
    _verify_model(Squeeze2(), input_info)


def test_unsqueeze():
    """test relax translator for unsqueeze"""

    class Unsqueeze1(Module):
        def forward(self, data):
            return data.unsqueeze(1)

    class Unsqueeze2(Module):
        def forward(self, data):
            return data.unsqueeze(-1)

    input_info = [([1, 3, 10, 10], "float32")]
    _verify_model(Unsqueeze1(), input_info)
    _verify_model(Unsqueeze2(), input_info)


def test_getattr():
    """test relax translator for getattr"""

    class GetAttr1(Module):
        def forward(self, data):
            return data.shape

    input_info = [([1, 3, 10, 10], "float32")]
    _verify_model(GetAttr1(), input_info)


@pytest.mark.xfail(reason="MSC does not support Tuple of PrimValue")
def test_getitem():
    """test relax translator for getitem"""

    class Slice1(Module):
        def forward(self, x):
            return x[0, 1::2, :, :3]

    class Slice2(Module):
        def forward(self, x):
            return x[:, None, None, :, None]

    _verify_model(Slice1(), [([1, 3, 10, 10], "float32")])
    _verify_model(Slice2(), [([8, 16], "float32")])


def test_unary():
    """test relax translator for unary"""

    input_info = [([1, 3, 10, 10], "float32")]

    # sin
    class Sin(Module):
        def forward(self, data):
            return torch.sin(data)

    _verify_model(Sin(), input_info)

    # cos
    class Cos(Module):
        def forward(self, data):
            return torch.cos(data)

    _verify_model(Cos(), input_info)

    # exp
    class Exp(Module):
        def forward(self, data):
            return torch.exp(data)

    _verify_model(Exp(), input_info)

    # sqrt
    class Sqrt(Module):
        def forward(self, data):
            return torch.sqrt(data)

    _verify_model(Sqrt(), input_info)

    # sigmoid
    class Sigmoid(Module):
        def forward(self, data):
            return torch.sigmoid(data)

    _verify_model(Sigmoid(), input_info)

    # round
    class Round(Module):
        def forward(self, data):
            return torch.round(data)

    _verify_model(Round(), input_info)


def test_gelu():
    """test relax translator for gelu"""

    class Gelu(Module):
        def forward(self, data):
            return torch.nn.functional.gelu(data)

    input_info = [([1, 3, 10, 10], "float32")]
    _verify_model(Gelu(), input_info)


def test_tanh():
    """test relax translator for tanh"""

    class Tanh(Module):
        def forward(self, data):
            return torch.tanh(data)

    input_info = [([1, 3, 10, 10], "float32")]
    _verify_model(Tanh(), input_info)


def test_clamp():
    """test relax translator for clamp"""

    class Clamp(Module):
        def forward(self, data):
            return torch.clamp(data, min=0.1, max=0.5)

    input_info = [([1, 3, 10, 10], "float32")]
    _verify_model(Clamp(), input_info)


def test_interpolate():
    """test relax translator for interpolate"""

    class Interpolate(Module):
        def forward(self, data):
            return torch.nn.functional.interpolate(data, (5, 5))

    input_info = [([1, 3, 10, 10], "float32")]
    _verify_model(Interpolate(), input_info)


def test_addmm():
    """test relax translator for addmm"""

    class Addmm(Module):
        def forward(self, x_1, x_2, x_3):
            return torch.addmm(x_1, x_2, x_3)

    input_info = [
        ([10, 10], "float32"),
        ([10, 10], "float32"),
        ([10, 10], "float32"),
    ]
    _verify_model(Addmm(), input_info)


def test_split():
    """test relax translator for split"""

    class Split(Module):
        def forward(self, data):
            return torch.split(data, 1, dim=1)

    input_info = [([1, 3, 10, 10], "float32")]
    _verify_model(Split(), input_info)


def test_cumsum():
    """test relax translator for cumsum"""

    class Cumsum(Module):
        def forward(self, data):
            return torch.cumsum(data, dim=1, dtype=torch.int32)

    input_info = [([1, 2, 3, 4], "float32")]
    _verify_model(Cumsum(), input_info)


def test_chunk():
    """test relax translator for chunk"""

    class Chunk(Module):
        def forward(self, data):
            return torch.chunk(data, 3, dim=1)

    input_info = [([1, 3, 10, 10], "float32")]
    _verify_model(Chunk(), input_info)


def test_inplace_fill():
    """test relax translator for inplace_fill"""

    class InplaceFill(Module):
        def forward(self, data):
            data.fill_(1.5)
            return data

    _verify_model(InplaceFill(), [([10, 10], "float32")], opt_config={"opt_level": 0})


def test_arange():
    """test relax translator for arange"""

    class Arange(Module):
        def forward(self):
            return torch.arange(0, 20, dtype=torch.int32)

    _verify_model(Arange(), [([10, 10], "float32")])


def test_empty():
    """test relax translator for empty"""

    class Empty(Module):
        def forward(self):
            return torch.empty((10, 10), dtype=torch.float32)

    _verify_model(Empty(), [([10, 10], "float32")])


def test_tensor():
    """test relax translator for tensor"""

    class Empty1(Module):
        def forward(self):
            return torch.tensor(3, dtype=torch.float32)

    class Empty2(Module):
        def forward(self):
            return torch.tensor(3)

    _verify_model(Empty1(), [([10, 10], "float32")])
    _verify_model(Empty2(), [([10, 10], "float32")])


def test_tril():
    """test relax translator for tril"""

    class Tril(Module):
        def forward(self, data):
            return torch.tril(data, 1)

    class InplaceTril(Module):
        def forward(self, data):
            data.tril_(1)
            return data

    input_info = [([10, 10], "float32")]
    _verify_model(Tril(), input_info)
    _verify_model(InplaceTril(), input_info)


def test_triu():
    """test relax translator for triu"""

    class Triu(Module):
        def forward(self, data):
            return torch.triu(data, 1)

    class InplaceTriu(Module):
        def forward(self, data):
            data.triu_(1)
            return data

    input_info = [([10, 10], "float32")]
    _verify_model(Triu(), input_info)
    _verify_model(InplaceTriu(), input_info)


def test_new_ones():
    """test relax translator for new_ones"""

    class NewOnes(Module):
        def forward(self, x):
            return x.new_ones(1, 2, 3)

    input_info = [([1, 2, 3], "float32")]
    _verify_model(NewOnes(), input_info, opt_config={"opt_level": 0})


def test_expand():
    """test relax translator for expand"""

    class Expand(Module):
        def forward(self, x):
            return x.expand(4, 2, 3, 4)

    input_info = [([1, 2, 3, 4], "float32")]
    _verify_model(Expand(), input_info)


def test_reduce():
    """test relax translator for reduce"""

    # sum
    class Sum(Module):
        def forward(self, x):
            return torch.sum(x, (2, 1))

    input_info = [([1, 2, 3, 4], "float32")]
    _verify_model(Sum(), input_info)


def test_datatype():
    """test relax translator for datatype"""

    input_info = [([1, 2, 3, 4], "float32")]

    # float
    class ToFloat(Module):
        def forward(self, x):
            return x.float()

    _verify_model(ToFloat(), input_info)

    # half
    class ToHalf(Module):
        def forward(self, x):
            return x.half()

    _verify_model(ToHalf(), input_info)

    # type
    class Type(Module):
        def forward(self, x):
            return x.type(torch.float32)

    # type from attr: the target dtype is read from the traced tensor itself
    class TypeFromAttr(Module):
        def forward(self, x):
            return x.type(x.getattr("dtype"))

    # astype
    class AsType(Module):
        def forward(self, x):
            return x.astype(torch.float32)

    _verify_model(Type(), input_info)
    _verify_model(TypeFromAttr(), input_info)
    _verify_model(AsType(), input_info)


def test_permute():
    """test relax translator for permute"""

    class Permute(Module):
        def forward(self, x):
            return x.permute(0, 3, 2, 1)

    input_info = [([1, 2, 3, 4], "float32")]
    _verify_model(Permute(), input_info)


def test_reshape():
    """test relax translator for reshape"""

    class Reshape(Module):
        def forward(self, x):
            return x.reshape(2, 12)

    input_info = [([1, 2, 3, 4], "float32")]
    _verify_model(Reshape(), input_info)


def test_transpose():
    """test relax translator for transpose"""

    class Transpose(Module):
        def forward(self, x):
            return x.transpose(1, 3)

    input_info = [([1, 2, 3, 4], "float32")]
    _verify_model(Transpose(), input_info)


def test_view():
    """test relax translator for view"""

    class View(Module):
        def forward(self, x):
            return x.view(2, 12)

    input_info = [([1, 2, 3, 4], "float32")]
    _verify_model(View(), input_info)


def test_keep_params():
    """test relax translator for keep_params"""

    # NOTE(review): this model and call mirror test_conv2d; no params-related
    # option is passed to _verify_model here — confirm what is meant to be covered.
    class Conv2D1(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 6, 7, bias=True)

        def forward(self, data):
            return self.conv(data)

    _verify_model(Conv2D1(), [([1, 3, 10, 10], "float32")])


def test_unwrap_unit_return_tuple():
    """test relax translator for unwrap_unit_return_tuple"""

    class Identity(Module):
        def forward(self, x):
            return (x,)

    _verify_model(Identity(), [([256, 256], "float32")])


def test_no_bind_return_tuple():
    """test relax translator for no_bind_return_tuple"""

    class Identity(Module):
        def forward(self, x, y):
            return (x, y)

    input_info = [([256, 256], "float32"), ([256, 256], "float32")]
    _verify_model(Identity(), input_info)


def test_argmax():
    """test relax translator for argmax"""

    class Argmax1(Module):
        def forward(self, data):
            return torch.argmax(data, dim=-1)

    class Argmax2(Module):
        def forward(self, data):
            return torch.argmax(data, dim=-1, keepdim=True)

    _verify_model(Argmax1(), [([256, 256], "float32")])
    _verify_model(Argmax2(), [([256, 256], "float32")])


def test_argmin():
    """test relax translator for argmin"""

    class Argmin1(Module):
        def forward(self, data):
            return torch.argmin(data)

    class Argmin2(Module):
        def forward(self, data):
            return torch.argmin(data, keepdim=True)

    _verify_model(Argmin1(), [([256, 256], "float32")])
    _verify_model(Argmin2(), [([256, 256], "float32")])


def test_to():
    """test relax translator for to"""

    class To1(Module):
        def forward(self, data):
            return data.to(torch.float16)

    class To2(Module):
        def forward(self, data):
            return data.to("cpu")

    _verify_model(To1(), [([256, 256], "float32")])
    _verify_model(To2(), [([256, 256], "float32")])


def test_mean():
    """test relax translator for mean"""

    class Mean(Module):
        def forward(self, data):
            return data.mean(-1)

    class MeanKeepDim(Module):
        def forward(self, data):
            return data.mean(-1, keepdim=True)

    _verify_model(Mean(), [([256, 256], "float32")])
    _verify_model(MeanKeepDim(), [([256, 256], "float32")])


def test_rsqrt():
    """test relax translator for rsqrt"""

    class Rsqrt(Module):
        def forward(self, data):
            return torch.rsqrt(data)

    _verify_model(Rsqrt(), [([256, 256], "float32")])


def test_neg():
    """test relax translator for neg"""

    class Neg(Module):
        def forward(self, data):
            return -data

    _verify_model(Neg(), [([256, 256], "float32")])


def test_max():
    """test relax translator for max"""

    class Max(Module):
        def forward(self, x, y):
            return torch.max(x, y)

    _verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")])


def test_attention():
    """test relax translator for attention"""

    # pylint: disable=import-outside-toplevel
    import torch.nn.functional as F

    class Attention1(Module):
        def forward(self, q_data, k_data, v_data):
            return F.scaled_dot_product_attention(q_data, k_data, v_data)

    class Attention2(Module):
        def forward(self, q_data, k_data, v_data):
            return F.scaled_dot_product_attention(q_data, k_data, v_data, is_causal=True)

    input_info = [
        ([32, 8, 128, 64], "float32"),
        ([32, 8, 128, 64], "float32"),
        ([32, 8, 128, 64], "float32"),
    ]
    _verify_model(Attention1(), input_info)
    _verify_model(Attention2(), input_info)

    class Attention3(Module):
        def forward(self, q_data, k_data, v_data, mask):
            return F.scaled_dot_product_attention(q_data, k_data, v_data, mask)

    _verify_model(
        Attention3(),
        [
            ([32, 8, 128, 64], "float32"),
            ([32, 8, 128, 64], "float32"),
            ([32, 8, 128, 64], "float32"),
            ([32, 8, 128, 128], "float32"),
        ],
    )


if __name__ == "__main__":
    tvm.testing.main()