"""Tests for translating Relay into an MSC graph and generating Relax from it.

Each test builds a small torch module and imports it twice:
once with torch.fx + ``from_fx`` as the reference Relax module, and once with
TorchScript -> Relay -> MSC graph -> Relax codegen. The two results are then
compared either structurally or numerically.
"""

import numpy as np
import torch
from torch import fx
from torch.nn import Module

import tvm.testing
from tvm.relax.frontend.torch import from_fx
from tvm.relay.frontend import from_pytorch
from tvm import relay
from tvm.ir.module import IRModule
from tvm.contrib.msc.core.frontend import translate
from tvm.contrib.msc.framework.tvm import codegen as tvm_codegen
from tvm.contrib.msc.core import utils as msc_utils


def _valid_target(target):
    """Normalize a build target requested by ``verify_model``.

    Parameters
    ----------
    target: str, tvm.target.Target or None
        The requested target. The special string ``"ignore"`` requests that
        no comparison be performed.

    Returns
    -------
    target: tvm.target.Target or None
        A falsy input is returned unchanged. ``None`` is returned for
        ``"ignore"``, and for ``"cuda"`` when no CUDA device exists. Any other
        string is wrapped in ``tvm.target.Target``.
    """
    if not target:
        return target
    if target == "ignore":
        return None
    if target == "cuda" and not tvm.cuda().exist:
        # No CUDA device on this machine: skip rather than fail.
        return None
    if isinstance(target, str):
        target = tvm.target.Target(target)
    return target


def _run_relax(relax_mod, target, datas):
    """Legalize, build and execute the ``main`` function of a Relax module.

    Parameters
    ----------
    relax_mod: tvm.IRModule
        The Relax module to run.
    target: tvm.target.Target
        The target passed to ``tvm.relax.build``.
    datas: list of tvm.nd.NDArray
        Positional arguments for ``main``.

    Returns
    -------
    outputs: list of np.ndarray
        The outputs of ``main`` as numpy arrays. A single tensor result is
        wrapped in a one-element list.
    """
    relax_mod = tvm.relax.transform.LegalizeOps()(relax_mod)
    with tvm.transform.PassContext(opt_level=3):
        relax_exec = tvm.relax.build(relax_mod, target)
        # NOTE(review): the VM is always created on tvm.cpu(), so presumably
        # only CPU targets (e.g. "llvm") work here. Confirm before using "cuda".
        runnable = tvm.relax.VirtualMachine(relax_exec, tvm.cpu())
    res = runnable["main"](*datas)
    if isinstance(res, tvm.runtime.NDArray):
        return [res.asnumpy()]
    return [e.asnumpy() for e in res]


def verify_model(torch_model, input_info, opt_config=None, codegen_config=None, build_target=None):
    """Compare Relax generated through Relay + MSC with Relax imported from torch.fx.

    Two modules are produced and compared:
    - The reference comes from ``fx.symbolic_trace`` + ``from_fx``.
    - The module under test comes from ``torch.jit.trace`` -> ``from_pytorch``
      (Relay) -> ``translate.from_relay`` (MSC graph) -> ``tvm_codegen.to_relax``.

    Parameters
    ----------
    torch_model: torch.nn.Module
        The model to convert.
    input_info: list of (shape, dtype)
        Shape and dtype of every model input.
    opt_config: dict, optional
        Forwarded to ``translate.from_relay``.
    codegen_config: dict, optional
        Extra options for ``tvm_codegen.to_relax``. The caller's dict is not
        modified. ``explicit_name`` and ``from_relay`` are always overridden.
    build_target: str or tvm.target.Target, optional
        Controls how the two modules are compared:
        - unset: the modules are compared with ``assert_structural_equal``;
        - set: both modules are built, run on the same random inputs, and
          their outputs are compared numerically;
        - ``"ignore"``, or ``"cuda"`` without a CUDA device: the function
          returns after codegen without comparing anything.
    """
    graph_model = fx.symbolic_trace(torch_model)
    with torch.no_grad():
        expected = from_fx(graph_model, input_info)
    expected = tvm.relax.transform.CanonicalizeBindings()(expected)

    # graph from relay
    # NOTE: np.random.rand samples from [0, 1), so integer inputs (e.g. index
    # tensors) are all zeros after the cast.
    datas = [np.random.rand(*i[0]).astype(i[1]) for i in input_info]
    torch_datas = [torch.from_numpy(i) for i in datas]
    with torch.no_grad():
        scripted_model = torch.jit.trace(torch_model, tuple(torch_datas)).eval()  # type: ignore
    shape_list = [("input" + str(idx), i) for idx, i in enumerate(input_info)]
    relay_mod, params = from_pytorch(scripted_model, shape_list)
    graph, weights = translate.from_relay(relay_mod, params, opt_config=opt_config)

    # to relax
    # Copy before updating so a dict passed in by the caller is not mutated.
    codegen_config = dict(codegen_config or {})
    codegen_config.update({"explicit_name": False, "from_relay": True})
    mod = tvm_codegen.to_relax(graph, weights, codegen_config)

    if build_target:
        build_target = _valid_target(build_target)
        if not build_target:
            return
        tvm_datas = [tvm.nd.array(i) for i in datas]
        expected_res = _run_relax(expected, build_target, tvm_datas)
        if not graph.get_inputs():
            # The translated graph kept no inputs (e.g. outputs that ignore the data),
            # so its main takes no arguments.
            tvm_datas = []
        res = _run_relax(mod, build_target, tvm_datas)
        for exp_r, new_r in zip(expected_res, res):
            tvm.testing.assert_allclose(exp_r, new_r, atol=1e-5, rtol=1e-5)
    else:
        tvm.ir.assert_structural_equal(mod, expected)
def test_conv1d():
    """test relay to relax for conv1d"""

    class Conv1D1(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv1d(3, 6, 7, bias=True)

        def forward(self, data):
            return self.conv(data)

    class Conv1D2(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv1d(3, 6, 7, bias=False)

        def forward(self, data):
            return self.conv(data)

    input_info = [([1, 3, 10], "float32")]
    verify_model(Conv1D1(), input_info)
    verify_model(Conv1D2(), input_info)


def test_conv2d():
    """test relay to relax for conv2d"""

    class Conv2D1(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 6, 7, bias=True)

        def forward(self, data):
            return self.conv(data)

    class Conv2D2(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 6, 7, bias=False)

        def forward(self, data):
            return self.conv(data)

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(Conv2D1(), input_info)
    verify_model(Conv2D2(), input_info)


def test_linear():
    """test relay to relax for linear"""

    class Dense1(Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(10, 7, bias=True)

        def forward(self, data):
            return self.linear(data)

    class Dense2(Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(10, 7, bias=False)

        def forward(self, data):
            return self.linear(data)

    class MatMul1(Module):
        def forward(self, x, y):
            return torch.matmul(x, y)

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(Dense1(), input_info, build_target="llvm")
    verify_model(Dense2(), input_info, build_target="llvm")
    verify_model(
        MatMul1(), [([10, 10], "float32"), ([10, 10], "float32")], build_target="llvm"
    )


def test_bmm():
    """test relay to relax for bmm"""

    class BMM(Module):
        def forward(self, x, y):
            return torch.bmm(x, y)

    input_info = [((4, 128, 256), "float32"), ((4, 256, 512), "float32")]
    verify_model(BMM(), input_info, opt_config={"opt_level": 3})


def test_baddbmm():
    """test relay to relax for baddbmm"""

    class BAddBMM1(Module):
        def forward(self, c, x, y):
            return torch.baddbmm(c, x, y)

    class BAddBMM2(Module):
        def forward(self, c, x, y):
            return torch.baddbmm(c, x, y, alpha=2, beta=0)

    input_info = [
        ((4, 128, 512), "float32"),
        ((4, 128, 256), "float32"),
        ((4, 256, 512), "float32"),
    ]
    verify_model(BAddBMM1(), input_info, opt_config={"opt_level": 3}, build_target="llvm")
    verify_model(BAddBMM2(), input_info, opt_config={"opt_level": 3}, build_target="llvm")


def test_relu():
    """test relay to relax for relu"""

    class ReLU(Module):
        def __init__(self):
            super().__init__()
            self.relu = torch.nn.ReLU()

        def forward(self, data):
            return self.relu(data)

    class ReLU1(Module):
        def forward(self, data):
            return torch.nn.functional.relu(data)

    input_info = [([10, 10], "float32")]
    verify_model(ReLU(), input_info)
    verify_model(ReLU1(), input_info)


def test_relu6():
    """test relay to relax for relu6"""

    class ReLU6(Module):
        def __init__(self):
            super().__init__()
            self.relu6 = torch.nn.ReLU6()

        def forward(self, data):
            return self.relu6(data)

    input_info = [([10, 10], "float32")]
    verify_model(ReLU6(), input_info)


def test_maxpool2d():
    """test relay to relax for maxpool2d"""

    class MaxPool2d(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.MaxPool2d(kernel_size=[1, 1])

        def forward(self, data):
            return self.pool(data)

    class MaxPool2d2(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.MaxPool2d(kernel_size=[2, 2], dilation=[2, 3])

        def forward(self, data):
            return self.pool(data)

    class MaxPool2d3(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.MaxPool2d(kernel_size=[4, 4], padding=2, stride=2)

        def forward(self, data):
            return self.pool(data)

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(MaxPool2d(), input_info)
    verify_model(MaxPool2d2(), input_info)
    verify_model(MaxPool2d3(), input_info)


def test_avgpool2d():
    """test relay to relax for avgpool2d"""

    class AvgPool2d(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.AvgPool2d(kernel_size=[1, 1])

        def forward(self, data):
            return self.pool(data)

    class AvgPool2d2(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.AvgPool2d(kernel_size=[4, 4], stride=2, padding=2, ceil_mode=True)

        def forward(self, data):
            return self.pool(data)

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(AvgPool2d(), input_info)
    verify_model(AvgPool2d2(), input_info)


def test_adaptive_avgpool2d():
    """test relay to relax for adaptive_avgpool2d"""

    class AdaptiveAvgPool2d0(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.AdaptiveAvgPool2d([10, 10])

        def forward(self, data):
            return self.pool(data)

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(AdaptiveAvgPool2d0(), input_info)


def test_flatten():
    """test relay to relax for flatten"""

    class Flatten(Module):
        def __init__(self):
            super().__init__()
            self.f = torch.nn.Flatten(2, -1)

        def forward(self, data):
            return self.f(data)

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(Flatten(), input_info, opt_config={"opt_level": 3}, build_target="llvm")
    # The bare torch.nn.Flatten module is checked as well, not only the wrapper.
    verify_model(
        torch.nn.Flatten(2, -1), input_info, opt_config={"opt_level": 3}, build_target="llvm"
    )


def test_batchnorm2d():
    """test relay to relax for batchnorm2d"""

    class BatchNorm2d(Module):
        def __init__(self):
            super().__init__()
            self.batchnorm = torch.nn.BatchNorm2d(3)

        def forward(self, data):
            return self.batchnorm(data)

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(BatchNorm2d(), input_info, build_target="llvm")


def test_embedding():
    """test relay to relax for embedding"""

    class Embedding(Module):
        def __init__(self):
            super().__init__()
            self.embedding = torch.nn.Embedding(10, 3)

        def forward(self, data):
            return self.embedding(data)

    verify_model(Embedding(), [([4], "int64")])
    verify_model(Embedding(), [([4, 5], "int64")])


def test_layernorm():
    """test relay to relax for layernorm"""

    class LayerNorm(Module):
        def __init__(self):
            super().__init__()
            self.layernorm = torch.nn.LayerNorm(10)

        def forward(self, data):
            return self.layernorm(data)

    input_info = [([1, 10, 10], "float32")]
    verify_model(LayerNorm(), input_info)


def test_functional_layernorm():
    """test relay to relax for functional_layernorm"""

    class LayerNorm(Module):
        def __init__(self, shape):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones(shape))
            self.bias = torch.nn.Parameter(torch.zeros(shape))

        def forward(self, data):
            return torch.nn.functional.layer_norm(
                data, self.weight.shape, self.weight, self.bias, 1e-5
            )

    input_info = [([1, 10, 10], "float32")]
    # Normalize over the last dimension (size 10).
    verify_model(LayerNorm((10,)), input_info)


def test_cross_entropy():
    """test relay to relax for cross_entropy"""

    class CrossEntropy1(Module):
        def __init__(self):
            super().__init__()
            self.loss = torch.nn.CrossEntropyLoss()

        def forward(self, logits, targets):
            return self.loss(logits, targets)

    class CrossEntropy2(Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones((2,)))
            self.loss = torch.nn.CrossEntropyLoss(weight=self.weight)

        def forward(self, logits, targets):
            return self.loss(logits, targets)

    class CrossEntropy3(Module):
        def __init__(self):
            super().__init__()
            self.loss = torch.nn.CrossEntropyLoss(ignore_index=1, reduction="sum")

        def forward(self, logits, targets):
            return self.loss(logits, targets)

    input_info = [([3, 2], "float32"), ([3], "int64")]
    verify_model(CrossEntropy1(), input_info, opt_config={"opt_level": 3}, build_target="llvm")
    verify_model(CrossEntropy2(), input_info, opt_config={"opt_level": 3}, build_target="llvm")
    verify_model(CrossEntropy3(), input_info, opt_config={"opt_level": 3}, build_target="llvm")


def test_functional_cross_entropy():
    """test relay to relax for functional_cross_entropy"""

    class CrossEntropy(Module):
        def forward(self, logits, targets):
            return torch.nn.functional.cross_entropy(logits, targets)

    input_info = [([3, 10], "float32"), ([3], "int64")]
    verify_model(CrossEntropy(), input_info, opt_config={"opt_level": 3}, build_target="llvm")


def test_silu():
    """test relay to relax for silu"""

    class SiLU(Module):
        def __init__(self):
            super().__init__()
            self.silu = torch.nn.SiLU()

        def forward(self, data):
            return self.silu(data)

    class SiLU2(Module):
        def forward(self, data):
            return torch.nn.functional.silu(data)

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(SiLU(), input_info, build_target="llvm")
    verify_model(SiLU2(), input_info, build_target="llvm")


def test_groupnorm():
    """test relay to relax for groupnorm"""

    class GroupNorm(Module):
        def __init__(self):
            super().__init__()
            self.groupnorm = torch.nn.GroupNorm(3, 3)

        def forward(self, data):
            return self.groupnorm(data)

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(GroupNorm(), input_info)


def test_softmax():
    """test relay to relax for softmax"""

    class Softmax(Module):
        def __init__(self):
            super().__init__()
            self.softmax = torch.nn.Softmax(dim=1)

        def forward(self, data):
            return self.softmax(data)

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(Softmax(), input_info)


def test_binary():
    """test relay to relax for binary"""

    # input_info1 feeds tensor-tensor ops; input_info2 feeds tensor-scalar ops.
    input_info1 = [([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32")]
    input_info2 = [([1, 3, 10, 10], "float32")]

    # Add
    class Add1(Module):
        def forward(self, lhs, rhs):
            return lhs + rhs

    class Add2(Module):
        def forward(self, lhs):
            return lhs + 1.0

    verify_model(Add1(), input_info1, opt_config={"opt_level": 3})
    verify_model(Add2(), input_info2, opt_config={"opt_level": 3})

    # Sub
    class Sub1(Module):
        def forward(self, lhs, rhs):
            return lhs - rhs

    class Sub2(Module):
        def forward(self, lhs):
            return lhs - 1.0

    verify_model(Sub1(), input_info1, opt_config={"opt_level": 3})
    verify_model(Sub2(), input_info2, opt_config={"opt_level": 3})

    # Mul
    class Mul1(Module):
        def forward(self, lhs, rhs):
            return lhs * rhs

    class Mul2(Module):
        def forward(self, lhs):
            return lhs * 1.0

    verify_model(Mul1(), input_info1, opt_config={"opt_level": 3})
    # NOTE(review): no opt_config here, unlike most binary cases. Presumably the
    # optimized Relay would simplify "* 1.0" away and break structural equality.
    verify_model(Mul2(), input_info2)

    # True div
    class TrueDiv1(Module):
        def forward(self, lhs, rhs):
            return lhs / rhs

    class TrueDiv2(Module):
        def forward(self, lhs):
            return lhs / 1.0

    verify_model(TrueDiv1(), input_info1, opt_config={"opt_level": 3})
    # NOTE(review): no opt_config, same as Mul2 above.
    verify_model(TrueDiv2(), input_info2)

    # Floor div
    class FloorDiv1(Module):
        def forward(self, lhs, rhs):
            return lhs // rhs

    class FloorDiv2(Module):
        def forward(self, lhs):
            return lhs // 1.0

    verify_model(FloorDiv1(), input_info1, opt_config={"opt_level": 3})
    verify_model(FloorDiv2(), input_info2, opt_config={"opt_level": 3})

    # Power
    class Power1(Module):
        def forward(self, lhs, rhs):
            return lhs**rhs

    class Power2(Module):
        def forward(self, lhs):
            return lhs**1.0

    verify_model(Power1(), input_info1, opt_config={"opt_level": 3})
    verify_model(Power2(), input_info2, opt_config={"opt_level": 3})

    # LT
    class LT1(Module):
        def forward(self, lhs, rhs):
            return lhs < rhs

    class LT2(Module):
        def forward(self, lhs):
            return lhs < 1.0

    verify_model(LT1(), input_info1, opt_config={"opt_level": 3})
    verify_model(LT2(), input_info2, opt_config={"opt_level": 3})


def test_squeeze():
    """test relay to relax for squeeze"""

    class Squeeze1(Module):
        def forward(self, data):
            return data.squeeze(1)

    class Squeeze2(Module):
        def forward(self, data):
            return data.squeeze()

    input_info = [([3, 1, 4, 1], "float32")]
    verify_model(Squeeze1(), input_info)
    verify_model(Squeeze2(), input_info)


def test_unsqueeze():
    """test relay to relax for unsqueeze"""

    class Unsqueeze1(Module):
        def forward(self, data):
            return data.unsqueeze(1)

    class Unsqueeze2(Module):
        def forward(self, data):
            return data.unsqueeze(-1)

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(Unsqueeze1(), input_info)
    verify_model(Unsqueeze2(), input_info)


def test_getitem():
    """test relay to relax for getitem"""

    class Slice1(Module):
        def forward(self, x):
            return x[0, 1::2, :, :3]

    class Slice2(Module):
        def forward(self, x):
            return x[:, None, None, :, None]

    # "ignore": only translation and codegen must succeed; results are not compared.
    verify_model(Slice1(), [([1, 3, 10, 10], "float32")], build_target="ignore")
    verify_model(Slice2(), [([8, 16], "float32")], build_target="llvm")


def test_unary():
    """test relay to relax for unary"""

    input_info = [([1, 3, 10, 10], "float32")]

    # sin
    class Sin(Module):
        def forward(self, data):
            return torch.sin(data)

    verify_model(Sin(), input_info)

    # cos
    class Cos(Module):
        def forward(self, data):
            return torch.cos(data)

    verify_model(Cos(), input_info)

    # exp
    class Exp(Module):
        def forward(self, data):
            return torch.exp(data)

    verify_model(Exp(), input_info)

    # sqrt
    class Sqrt(Module):
        def forward(self, data):
            return torch.sqrt(data)

    verify_model(Sqrt(), input_info)

    # sigmoid
    class Sigmoid(Module):
        def forward(self, data):
            return torch.sigmoid(data)

    verify_model(Sigmoid(), input_info)

    # round
    class Round(Module):
        def forward(self, data):
            return torch.round(data)

    verify_model(Round(), input_info)


def test_gelu():
    """test relay to relax for gelu"""

    class Gelu(Module):
        def forward(self, data):
            return torch.nn.functional.gelu(data)

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(Gelu(), input_info)


def test_tanh():
    """test relay to relax for tanh"""

    class Tanh(Module):
        def forward(self, data):
            return torch.tanh(data)

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(Tanh(), input_info)


def test_clamp():
    """test relay to relax for clamp"""

    class Clamp(Module):
        def forward(self, data):
            return torch.clamp(data, min=0.1, max=0.5)

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(Clamp(), input_info)


def test_interpolate():
    """test relay to relax for interpolate"""

    class Interpolate(Module):
        def forward(self, data):
            return torch.nn.functional.interpolate(data, (5, 5))

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(Interpolate(), input_info, build_target="llvm")


def test_addmm():
    """test relay to relax for addmm"""

    class Addmm(Module):
        def forward(self, x_1, x_2, x_3):
            return torch.addmm(x_1, x_2, x_3)

    input_info = [
        ([10, 10], "float32"),
        ([10, 10], "float32"),
        ([10, 10], "float32"),
    ]
    verify_model(Addmm(), input_info, build_target="llvm")


def test_split():
    """test relay to relax for split"""

    class Split(Module):
        def forward(self, data):
            return torch.split(data, 1, dim=1)

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(Split(), input_info, build_target="llvm")


def test_cumsum():
    """test relay to relax for cumsum"""

    class Cumsum(Module):
        def forward(self, data):
            return torch.cumsum(data, dim=1, dtype=torch.int32)

    input_info = [([1, 2, 3, 4], "float32")]
    verify_model(Cumsum(), input_info)


def test_chunk():
    """test relay to relax for chunk"""

    class Chunk(Module):
        def forward(self, data):
            return torch.chunk(data, 3, dim=1)

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(Chunk(), input_info, build_target="llvm")


def test_inplace_fill():
    """test relay to relax for inplace_fill"""

    class InplaceFill(Module):
        def forward(self, data):
            data.fill_(1.5)
            return data

    verify_model(InplaceFill(), [([10, 10], "float32")], build_target="llvm")


def test_arange():
    """test relay to relax for arange"""

    class Arange(Module):
        def forward(self, data):
            return torch.arange(0, 20, dtype=torch.int32)

    verify_model(
        Arange(), [([10, 10], "float32")], opt_config={"opt_level": 3}, build_target="llvm"
    )


def test_empty():
    """test relay to relax for empty"""

    class Empty(Module):
        def forward(self, data):
            return torch.empty((10, 10), dtype=torch.float32)

    # "ignore": torch.empty contents are unspecified, so results are not compared.
    verify_model(
        Empty(), [([10, 10], "float32")], opt_config={"opt_level": 3}, build_target="ignore"
    )


def test_tensor():
    """test relay to relax for tensor"""

    class Tensor1(Module):
        def forward(self, data):
            return torch.tensor(3, dtype=torch.float32)

    class Tensor2(Module):
        def forward(self, data):
            return torch.tensor(3)

    verify_model(Tensor1(), [([10, 10], "float32")], build_target="llvm")
    verify_model(Tensor2(), [([10, 10], "float32")], build_target="llvm")


def test_tril():
    """test relay to relax for tril"""

    class Tril(Module):
        def forward(self, data):
            return torch.tril(data, 1)

    class InplaceTril(Module):
        def forward(self, data):
            data.tril_(1)
            return data

    input_info = [([10, 10], "float32")]
    verify_model(Tril(), input_info)
    verify_model(InplaceTril(), input_info)


def test_triu():
    """test relay to relax for triu"""

    class Triu(Module):
        def forward(self, data):
            return torch.triu(data, 1)

    class InplaceTriu(Module):
        def forward(self, data):
            data.triu_(1)
            return data

    input_info = [([10, 10], "float32")]
    verify_model(Triu(), input_info)
    verify_model(InplaceTriu(), input_info)


def test_new_ones():
    """test relay to relax for new_ones"""

    class NewOnes(Module):
        def forward(self, x):
            return x.new_ones(1, 2, 3)

    input_info = [([1, 2, 3], "float32")]
    verify_model(NewOnes(), input_info, build_target="llvm")


def test_expand():
    """test relay to relax for expand"""

    class Expand(Module):
        def forward(self, x):
            return x.expand(4, 2, 3, 4)

    input_info = [([1, 2, 3, 4], "float32")]
    verify_model(Expand(), input_info, build_target="llvm")


def test_reduce():
    """test relay to relax for reduce"""

    # sum
    class Sum(Module):
        def forward(self, x):
            return torch.sum(x, (2, 1))

    input_info = [([1, 2, 3, 4], "float32")]
    verify_model(Sum(), input_info)


def test_datatype():
    """test relay to relax for datatype"""

    input_info = [([1, 2, 3, 4], "float32")]

    # float
    class ToFloat(Module):
        def forward(self, x):
            return x.float()

    verify_model(ToFloat(), input_info, build_target="llvm")

    # half
    class ToHalf(Module):
        def forward(self, x):
            return x.half()

    verify_model(ToHalf(), input_info)

    # type
    class Type(Module):
        def forward(self, x):
            return x.type(torch.float32)

    verify_model(Type(), input_info, build_target="llvm")


def test_permute():
    """test relay to relax for permute"""

    class Permute(Module):
        def forward(self, x):
            return x.permute(0, 3, 2, 1)

    input_info = [([1, 2, 3, 4], "float32")]
    verify_model(Permute(), input_info)


def test_reshape():
    """test relay to relax for reshape"""

    class Reshape(Module):
        def forward(self, x):
            return x.reshape(2, 12)

    input_info = [([1, 2, 3, 4], "float32")]
    verify_model(Reshape(), input_info)


def test_transpose():
    """test relay to relax for transpose"""

    class Transpose(Module):
        def forward(self, x):
            return x.transpose(1, 3)

    input_info = [([1, 2, 3, 4], "float32")]
    verify_model(Transpose(), input_info)


def test_view():
    """test relay to relax for view"""

    class View(Module):
        def forward(self, x):
            return x.view(2, 12)

    input_info = [([1, 2, 3, 4], "float32")]
    verify_model(View(), input_info)


def test_keep_params():
    """test relay to relax for keep_params"""

    class Conv2D1(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 6, 7, bias=True)

        def forward(self, data):
            return self.conv(data)

    verify_model(Conv2D1(), [([1, 3, 10, 10], "float32")])


def test_unwrap_unit_return_tuple():
    """test relay to relax for unwrap_unit_return_tuple"""

    class Identity(Module):
        def forward(self, x):
            return (x,)

    verify_model(Identity(), [([256, 256], "float32")], build_target="llvm")


def test_no_bind_return_tuple():
    """test relay to relax for no_bind_return_tuple"""

    class Identity(Module):
        def forward(self, x, y):
            return (x, y)

    input_info = [([256, 256], "float32"), ([256, 256], "float32")]
    verify_model(Identity(), input_info)


def test_argmax():
    """test relay to relax for argmax"""

    class Argmax1(Module):
        def forward(self, data):
            return torch.argmax(data, dim=-1)

    class Argmax2(Module):
        def forward(self, data):
            return torch.argmax(data, dim=-1, keepdim=True)

    verify_model(Argmax1(), [([256, 256], "float32")])
    verify_model(Argmax2(), [([256, 256], "float32")])


def test_to():
    """test relay to relax for to"""

    class To1(Module):
        def forward(self, data):
            return data.to(torch.float16)

    class To2(Module):
        def forward(self, data):
            return data.to("cpu")

    verify_model(To1(), [([256, 256], "float32")])
    verify_model(To2(), [([256, 256], "float32")])


def test_mean():
    """test relay to relax for mean"""

    class Mean(Module):
        def forward(self, data):
            return data.mean(-1)

    class MeanKeepDim(Module):
        def forward(self, data):
            return data.mean(-1, keepdim=True)

    verify_model(Mean(), [([256, 256], "float32")])
    verify_model(MeanKeepDim(), [([256, 256], "float32")])


def test_rsqrt():
    """test relay to relax for rsqrt"""

    class Rsqrt(Module):
        def forward(self, data):
            return torch.rsqrt(data)

    verify_model(Rsqrt(), [([256, 256], "float32")])


def test_neg():
    """test relay to relax for neg"""

    class Neg(Module):
        def forward(self, data):
            return -data

    verify_model(Neg(), [([256, 256], "float32")])


def test_max():
    """test relay to relax for max"""

    class Max(Module):
        def forward(self, x, y):
            return torch.max(x, y)

    verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")])


def test_name_string_with_colon():
    """test name string with colons,
    e.g., TFLite default input name 'serving_default_input:0'
    """

    dtype = "float32"
    x_var = relay.var("input_0:0", shape=(3, 5), dtype=dtype)
    y_var = relay.var("input_1:0", shape=(3, 5), dtype=dtype)
    z_add = relay.add(x_var, y_var)
    func = relay.Function([x_var, y_var], z_add)
    mod = IRModule()
    mod["main"] = func

    try:
        graph, _ = translate.from_relay(mod)
    except Exception as err:
        # Chain the original error so its traceback is kept for debugging.
        raise RuntimeError(f"Translation from relay to graph failed: {err}") from err
    inspect = graph.inspect()

    # Input names must keep their colons after translation.
    expected = {
        "inputs": [
            {"name": "input_0:0", "shape": [3, 5], "dtype": dtype, "layout": ""},
            {"name": "input_1:0", "shape": [3, 5], "dtype": dtype, "layout": ""},
        ],
        "outputs": [{"name": "add", "shape": [3, 5], "dtype": dtype, "layout": ""}],
        "nodes": {"total": 3, "input": 2, "add": 1},
    }

    assert msc_utils.dict_equal(inspect, expected), "Inspect {} mismatch with expected {}".format(
        inspect, expected
    )