def test_min():
    class Model(Module):
        def test(self, x: Tensor):
            z0 = op.min(x, axis=[1, 2], keepdims=True)
            return z0

    # fmt: off
    @R.function
    def test(x: R.Tensor((3, 5, 2, 4), dtype="float32"), _io: R.Object) -> R.Tuple(R.Tensor((3, 1, 1, 4), dtype="float32"), R.Tuple(R.Object)):
        R.func_attr({"num_input": 2})
        with R.dataflow():
            min: R.Tensor((3, 1, 1, 4), dtype="float32") = R.min(x, axis=[1, 2], keepdims=True)
            gv1: R.Tuple(R.Tensor((3, 1, 1, 4), dtype="float32"), R.Tuple(R.Object)) = min, (_io,)
            R.output(gv1)
        return gv1
    # fmt: on

    m = Model()
    irmodule, _ = m.export_tvm(spec={"test": {"x": spec.Tensor([3, 5, 2, 4], "float32")}}, debug=True)
    tvm.ir.assert_structural_equal(irmodule["test"], test)


def test_manipulate():
    class Model(Module):
        def test(self, x: Tensor):
            z0 = op.broadcast_to(x, [2, 5, 2])
            z1 = op.permute_dims(x, [2, 1, 0])
            z2 = op.reshape(x, [1, 10])
            z3 = op.repeat(x, repeats=2, axis=1)
            z4 = op.squeeze(x, 0)
            z5 = op.unsqueeze(x, 0)
            z6 = op.concat([x, x], dim=0)
            return (z0, z1, z2, z3, z4, z5, z6)

    # fmt: off
    @R.function
    def test(x: R.Tensor((1, 5, 2), dtype="float32"), _io: R.Object) -> R.Tuple(R.Tuple(R.Tensor((2, 5, 2), dtype="float32"), R.Tensor((2, 5, 1), dtype="float32"), R.Tensor((1, 10), dtype="float32"), R.Tensor((1, 10, 2), dtype="float32"), R.Tensor((5, 2), dtype="float32"), R.Tensor((1, 1, 5, 2), dtype="float32"), R.Tensor((2, 5, 2), dtype="float32")), R.Tuple(R.Object)):
        R.func_attr({"num_input": 2})
        with R.dataflow():
            broadcast_to: R.Tensor((2, 5, 2), dtype="float32") = R.broadcast_to(x, R.shape([2, 5, 2]))
            permute_dims: R.Tensor((2, 5, 1), dtype="float32") = R.permute_dims(x, axes=[2, 1, 0])
            reshape: R.Tensor((1, 10), dtype="float32") = R.reshape(x, R.shape([1, 10]))
            repeat: R.Tensor((1, 10, 2), dtype="float32") = R.repeat(x, repeats=2, axis=1)
            squeeze: R.Tensor((5, 2), dtype="float32") = R.squeeze(x, axis=[0])
            unsqueeze: R.Tensor((1, 1, 5, 2), dtype="float32") = R.expand_dims(x, axis=0)
            concat: R.Tensor((2, 5, 2), dtype="float32") = R.concat([x, x], axis=0)
            gv1: R.Tuple(R.Tuple(R.Tensor((2, 5, 2), dtype="float32"), R.Tensor((2, 5, 1), dtype="float32"), R.Tensor((1, 10), dtype="float32"), R.Tensor((1, 10, 2), dtype="float32"), R.Tensor((5, 2), dtype="float32"), R.Tensor((1, 1, 5, 2), dtype="float32"), R.Tensor((2, 5, 2), dtype="float32")), R.Tuple(R.Object)) = (broadcast_to, permute_dims, reshape, repeat, squeeze, unsqueeze, concat), (_io,)
            R.output(gv1)
        return gv1
    # fmt: on

    m = Model()
    irmodule, _ = m.export_tvm(spec={"test": {"x": spec.Tensor([1, 5, 2], "float32")}}, debug=True)
    tvm.ir.assert_structural_equal(irmodule["test"], test)


def test_index():
    class Model(Module):
        def test(self, x: Tensor, y: Tensor):
            z0 = op.take(x, y, axis=2)
            return z0

    # fmt: off
    @R.function
    def test(x: R.Tensor((2, 1, 10), dtype="float32"), y: R.Tensor((5,), dtype="int32"), _io: R.Object) -> R.Tuple(R.Tensor((2, 1, 5), dtype="float32"), R.Tuple(R.Object)):
        R.func_attr({"num_input": 3})
        with R.dataflow():
            take: R.Tensor((2, 1, 5), dtype="float32") = R.take(x, y, axis=2)
            gv1: R.Tuple(R.Tensor((2, 1, 5), dtype="float32"), R.Tuple(R.Object)) = take, (_io,)
            R.output(gv1)
        return gv1
    # fmt: on

    m = Model()
    irmodule, params = m.export_tvm(
        spec={"test": {"x": spec.Tensor([2, 1, 10], "float32"), "y": spec.Tensor([5], "int32")}},
        debug=True,
    )
    tvm.ir.assert_structural_equal(irmodule["test"], test)


def test_datatype():
    class Model(Module):
        def test(self, x: Tensor):
            z0 = op.astype(x, "float16")
            return z0

    # fmt: off
    @R.function
    def test(x: R.Tensor((2, 1, 10), dtype="float32"), _io: R.Object) -> R.Tuple(R.Tensor((2, 1, 10), dtype="float16"), R.Tuple(R.Object)):
        R.func_attr({"num_input": 2})
        with R.dataflow():
            astype: R.Tensor((2, 1, 10), dtype="float16") = R.astype(x, dtype="float16")
            gv1: R.Tuple(R.Tensor((2, 1, 10), dtype="float16"), R.Tuple(R.Object)) = astype, (_io,)
            R.output(gv1)
        return gv1
    # fmt: on

    m = Model()
    irmodule, _ = m.export_tvm(spec={"test": {"x": spec.Tensor([2, 1, 10], "float32")}}, debug=True)
    tvm.ir.assert_structural_equal(irmodule["test"], test)
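

# The tests in this file share one export-and-compare pattern: build an
# `nn.Module`, export it with `export_tvm`, and structurally compare the result
# against hand-written TVMScript. A minimal standalone sketch of that pattern
# (an unused helper, assuming the frontend imports at the top of this file:
# `Module`, `Tensor`, `op`, `spec`):
def _export_pattern_sketch():
    class _Cast(Module):
        def test(self, x: Tensor):
            return op.astype(x, "float16")

    # `debug=True` threads the `_io` effect parameter through the exported
    # function and adds `_initialize_effect`, which is why the expected
    # functions in these tests take an extra `_io: R.Object` argument.
    irmodule, _ = _Cast().export_tvm(
        spec={"test": {"x": spec.Tensor([2, 2], "float32")}}, debug=True
    )
    return irmodule
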
def test_image():
    class Model(Module):
        def test(self, x: Tensor, weight: Tensor, bias: Tensor):
            padded = op.pad(x, [0, 0, 0, 0, 1, 1, 1, 1])
            conv2d = op.conv2d(padded, weight, bias)
            interpolate = op.interpolate(x, size=[40, 40])  # type: ignore
            return (conv2d, interpolate)

    @R.function
    def test(
        x: R.Tensor((1, 3, 32, 32), dtype="float32"),
        weight: R.Tensor((32, 3, 3, 3), dtype="float32"),
        bias: R.Tensor((32,), dtype="float32"),
        _io: R.Object,
    ) -> R.Tuple(
        R.Tuple(R.Tensor((1, 32, 32, 32), dtype="float32"), R.Tensor((1, 3, 40, 40), dtype="float32")),
        R.Tuple(R.Object),
    ):
        R.func_attr({"num_input": 4})
        with R.dataflow():
            lv0: R.Tensor((1, 3, 34, 34), dtype="float32") = R.nn.pad(x, (0, 0, 0, 0, 1, 1, 1, 1))
            lv1: R.Tensor((1, 32, 32, 32), dtype="float32") = R.nn.conv2d(
                lv0,
                weight,
                strides=[1, 1],
                padding=[0, 0, 0, 0],
                dilation=[1, 1],
                groups=1,
                data_layout="NCHW",
                kernel_layout="OIHW",
                out_layout="NCHW",
                out_dtype="void",
            )
            lv2: R.Tensor((1, 32, 1, 1), dtype="float32") = R.reshape(bias, R.shape([1, 32, 1, 1]))
            conv2d: R.Tensor((1, 32, 32, 32), dtype="float32") = R.add(lv1, lv2)
            interpolate: R.Tensor((1, 3, 40, 40), dtype="float32") = R.image.resize2d(
                x,
                R.shape([40, 40]),
                roi=[T.float32(0), T.float32(0), T.float32(0), T.float32(0)],
                layout="NCHW",
                method="nearest_neighbor",
                coordinate_transformation_mode="asymmetric",
                rounding_method="round",
                cubic_alpha=-0.5,
                cubic_exclude=0,
                extrapolation_value=0,
                out_dtype="void",
            )
            gv1: R.Tuple(
                R.Tuple(
                    R.Tensor((1, 32, 32, 32), dtype="float32"),
                    R.Tensor((1, 3, 40, 40), dtype="float32"),
                ),
                R.Tuple(R.Object),
            ) = (conv2d, interpolate), (_io,)
            R.output(gv1)
        return gv1

    m = Model()
    irmodule, _ = m.export_tvm(
        spec={
            "test": {
                "x": spec.Tensor([1, 3, 32, 32], "float32"),
                "weight": spec.Tensor([32, 3, 3, 3], "float32"),
                "bias": spec.Tensor([32], "float32"),
            }
        },
        debug=True,
    )
    tvm.ir.assert_structural_equal(irmodule["test"], test)


def test_chunk():
    class Model(Module):
        def test(self, x: Tensor):
            chunk = op.chunk(x, chunks=4)
            return chunk

    @R.function
    def test(x: R.Tensor((8,), dtype="float32"), _io: R.Object) -> R.Tuple(
        R.Tuple(
            R.Tensor((2,), dtype="float32"),
            R.Tensor((2,), dtype="float32"),
            R.Tensor((2,), dtype="float32"),
            R.Tensor((2,), dtype="float32"),
        ),
        R.Tuple(R.Object),
    ):
        R.func_attr({"num_input": 2})
        with R.dataflow():
            chunk: R.Tuple(
                R.Tensor((2,), dtype="float32"),
                R.Tensor((2,), dtype="float32"),
                R.Tensor((2,), dtype="float32"),
                R.Tensor((2,), dtype="float32"),
            ) = R.split(x, indices_or_sections=4, axis=0)
            chunk_0: R.Tensor((2,), dtype="float32") = chunk[0]
            chunk_1: R.Tensor((2,), dtype="float32") = chunk[1]
            chunk_2: R.Tensor((2,), dtype="float32") = chunk[2]
            chunk_3: R.Tensor((2,), dtype="float32") = chunk[3]
            gv1: R.Tuple(
                R.Tuple(
                    R.Tensor((2,), dtype="float32"),
                    R.Tensor((2,), dtype="float32"),
                    R.Tensor((2,), dtype="float32"),
                    R.Tensor((2,), dtype="float32"),
                ),
                R.Tuple(R.Object),
            ) = (chunk_0, chunk_1, chunk_2, chunk_3), (_io,)
            R.output(gv1)
        return gv1

    m = Model()
    irmodule, _ = m.export_tvm(spec={"test": {"x": spec.Tensor([8], "float32")}}, debug=True)
    tvm.ir.assert_structural_equal(irmodule["test"], test)


def test_nn():
    class Model(Module):
        def test(self, x: Tensor, weight: Tensor, bias: Tensor):
            relu_out = op.relu(x)
            silu_out = op.silu(x)
            gelu_out = op.gelu(x)
            sigmoid_out = op.sigmoid(x)
            tanh_out = op.tanh(x)
            exp_out = op.exp(x)
            negative_out = op.negative(x)
            softplus_out = op.softplus(x, beta=1.0, threshold=20.0)
            softmax_out = op.softmax(x, axis=2)
            prelu_out = op.prelu(x, alpha=bias)
            rms_norm_out = op.rms_norm(x, weight, axes=[-2, -1])
            rms_norm_with_bias_out = op.rms_norm(x, weight, axes=[-2, -1])
            group_norm_out = op.group_norm(x, num_groups=1, weight=bias, bias=bias)
            return x

    @R.function
    def test(
        x: R.Tensor((2, 3, 4, 5), dtype="float32"),
        weight: R.Tensor((4, 5), dtype="float32"),
        bias: R.Tensor((3,), dtype="float32"),
        _io: R.Object,
    ) -> R.Tuple(R.Tensor((2, 3, 4, 5), dtype="float32"), R.Tuple(R.Object)):
        R.func_attr({"num_input": 4})
        with R.dataflow():
            relu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.relu(x)
            silu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.silu(x)
            gelu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.gelu(x)
            sigmoid: R.Tensor((2, 3, 4, 5), dtype="float32") = R.sigmoid(x)
            tanh: R.Tensor((2, 3, 4, 5), dtype="float32") = R.tanh(x)
            exp: R.Tensor((2, 3, 4, 5), dtype="float32") = R.exp(x)
            negative: R.Tensor((2, 3, 4, 5), dtype="float32") = R.negative(x)
            softplus: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.softplus(x, beta=1.0, threshold=20.0)
            softmax: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.softmax(x, axis=2)
            prelu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.prelu(x, bias)
            rms_norm: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.rms_norm(x, weight, axes=[-2, -1], epsilon=1.0000000000000001e-05)
            rms_norm1: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.rms_norm(x, weight, axes=[-2, -1], epsilon=1.0000000000000001e-05)
            group_norm: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.group_norm(x, bias, bias, num_groups=1, channel_axis=1, axes=[2, 3])
            gv1: R.Tuple(R.Tensor((2, 3, 4, 5), dtype="float32"), R.Tuple(R.Object)) = x, (_io,)
            R.output(gv1)
        return gv1

    m = Model()
    irmodule, params = m.export_tvm(
        spec={
            "test": {
                "x": spec.Tensor([2, 3, 4, 5], "float32"),
                "weight": spec.Tensor([4, 5], "float32"),
                "bias": spec.Tensor([3], "float32"),
            }
        },
        debug=True,
    )
    tvm.ir.assert_structural_equal(irmodule["test"], test)


def test_create():
    class Model(Module):
        def test(self, x: Tensor):
            triu_out = op.triu(x)
            full_with_scalar_out = op.full([10, 10], fill_value=10)  # type: ignore
            full_with_FloatImm_out = op.full(
                [10, 10], fill_value=tir.FloatImm(dtype="float32", value=10)
            )
            full_with_Tensor_out = op.full(
                [10, 10], fill_value=Tensor.from_scalar(10, dtype="float32")
            )
            zeros_out = op.zeros([10, 10])
            zeros_fp16_out = op.zeros([10, 10], dtype="float16")
            return x

    # fmt: off
    @R.function
    def test(x: R.Tensor((10, 10), dtype="float32"), _io: R.Object) -> R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tuple(R.Object)):
        R.func_attr({"num_input": 2})
        with R.dataflow():
            triu: R.Tensor((10, 10), dtype="float32") = R.triu(x, k=0)
            full: R.Tensor((10, 10), dtype="float32") = R.full(R.shape([10, 10]), R.const(10, "float32"), dtype="float32")
            full1: R.Tensor((10, 10), dtype="float32") = R.full(R.shape([10, 10]), R.const(10, "float32"), dtype="float32")
            full2: R.Tensor((10, 10), dtype="float32") = R.full(R.shape([10, 10]), R.const(10, "float32"), dtype="float32")
            zeros: R.Tensor((10, 10), dtype="float32") = R.zeros(R.shape([10, 10]), dtype="float32")
            zeros1: R.Tensor((10, 10), dtype="float16") = R.zeros(R.shape([10, 10]), dtype="float16")
            gv1: R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tuple(R.Object)) = x, (_io,)
            R.output(gv1)
        return gv1
    # fmt: on

    m = Model()
    irmodule, params = m.export_tvm(spec={"test": {"x": spec.Tensor([10, 10], "float32")}}, debug=True)
    tvm.ir.assert_structural_equal(irmodule["test"], test)
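

# `op.get_timestep_embedding` in the test below lowers to the standard sinusoidal
# embedding: with half = dim // 2, freq_j = exp(-ln(10000) * j / (half - 1)) for
# j = 0..half-1 (the R.const(-9.2103404998779297) in the expected function is
# -ln(10000)), followed by concat([sin(t * freq), cos(t * freq)], axis=-1).
# A NumPy reference of the same computation, kept only for illustration (this
# helper name is not part of any TVM API):
def _timestep_embedding_reference(timesteps, dim=10, max_period=10000.0):
    half = dim // 2
    freqs = np.exp(-np.log(max_period) * np.arange(half, dtype=np.float32) / (half - 1))
    args = np.asarray(timesteps, dtype=np.float32)[:, None] * freqs[None, :]
    return np.concatenate([np.sin(args), np.cos(args)], axis=-1)
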
def test_timestep_embedding():
    class Model(Module):
        def test(self, x: Tensor):
            get_timestep_out = op.get_timestep_embedding(x, 10)
            return get_timestep_out

    @R.function
    def test(
        x: R.Tensor((3,), dtype="float32"), _io: R.Object
    ) -> R.Tuple(R.Tensor((3, 10), dtype="float32"), R.Tuple(R.Object)):
        R.func_attr({"num_input": 2})
        with R.dataflow():
            lv1: R.Tensor((3,), dtype="float32") = R.astype(x, dtype="float32")
            lv2: R.Tensor((3, 1), dtype="float32") = R.expand_dims(lv1, axis=[1])
            lv3: R.Tensor((5,), dtype="float32") = R.arange(
                R.prim_value(0), R.prim_value(5), R.prim_value(1), dtype="float32"
            )
            lv4: R.Tensor((5,), dtype="float32") = R.multiply(
                R.const(-9.2103404998779297, "float32"), lv3
            )
            lv5: R.Tensor((5,), dtype="float32") = R.divide(lv4, R.const(4, "float32"))
            lv6: R.Tensor((5,), dtype="float32") = R.exp(lv5)
            lv7: R.Tensor((1, 5), dtype="float32") = R.expand_dims(lv6, axis=[0])
            lv8: R.Tensor((3, 5), dtype="float32") = R.multiply(lv2, lv7)
            lv9: R.Tensor((3, 5), dtype="float32") = R.sin(lv8)
            lv10: R.Tensor((3, 5), dtype="float32") = R.cos(lv8)
            lv11: R.Tensor((3, 10), dtype="float32") = R.concat((lv9, lv10), axis=-1)
            get_timestep_embedding: R.Tensor((3, 10), dtype="float32") = R.astype(lv11, dtype="float32")
            gv1: R.Tuple(R.Tensor((3, 10), dtype="float32"), R.Tuple(R.Object)) = (
                get_timestep_embedding,
                (_io,),
            )
            R.output(gv1)
        return gv1

    m = Model()
    irmodule, _ = m.export_tvm(spec={"test": {"x": spec.Tensor([3], "float32")}}, debug=True)
    tvm.ir.assert_structural_equal(irmodule["test"], test)


def test_scaled_dot_product_attention():
    class Model(Module):
        def test(self, query: Tensor, key: Tensor, value: Tensor):
            scaled_dot_product_attention = op.scaled_dot_product_attention(query, key, value)
            return scaled_dot_product_attention

    @R.function
    def test(
        query: R.Tensor((1, 32, 32, 32), dtype="float32"),
        key: R.Tensor((1, 32, 32, 32), dtype="float32"),
        value: R.Tensor((1, 32, 32, 32), dtype="float32"),
        _io: R.Object,
    ) -> R.Tuple(R.Tensor((1, 32, 32, 32), dtype="float32"), R.Tuple(R.Object)):
        R.func_attr({"num_input": 4})
        with R.dataflow():
            scaled_dot_product_attention: R.Tensor((1, 32, 32, 32), dtype="float32") = R.nn.attention(
                query, key, value, scale=None, causal_mask=None
            )
            gv1: R.Tuple(R.Tensor((1, 32, 32, 32), dtype="float32"), R.Tuple(R.Object)) = (
                scaled_dot_product_attention,
                (_io,),
            )
            R.output(gv1)
        return gv1

    m = Model()
    irmodule, _ = m.export_tvm(
        spec={
            "test": {
                "query": spec.Tensor([1, 32, 32, 32], "float32"),
                "key": spec.Tensor([1, 32, 32, 32], "float32"),
                "value": spec.Tensor([1, 32, 32, 32], "float32"),
            }
        },
        debug=True,
    )
    tvm.ir.assert_structural_equal(irmodule["test"], test)


def test_tensor_expr_op():
    class Model(Module):
        def test(self, x: Tensor):
            tensor_expr_op_out = op.tensor_expr_op(
                tensor_expr_func=lambda x: x + 1, name_hint="add_one", args=[x]
            )
            return tensor_expr_op_out

    # fmt: off
    @I.ir_module
    class Expected:
        @T.prim_func(private=True)
        def add_one(A: T.Buffer((T.int64(10), T.int64(10)), "float32"), T_add: T.Buffer((T.int64(10), T.int64(10)), "float32")):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
T.block("root"):forax0,ax1inT.grid(T.int64(10),T.int64(10)):withT.block("T_add"):v_ax0,v_ax1=T.axis.remap("SS",[ax0,ax1])T.reads(A[v_ax0,v_ax1])T.writes(T_add[v_ax0,v_ax1])T_add[v_ax0,v_ax1]=A[v_ax0,v_ax1]+T.float32(1)@R.functiondef_initialize_effect()->R.Tuple(R.Object):withR.dataflow():_io:R.Object=R.null_value()lv:R.Tuple(R.Object)=(_io,)gv:R.Tuple(R.Object)=lvR.output(gv)returngv@R.functiondeftest(x:R.Tensor((10,10),dtype="float32"),_io:R.Object)->R.Tuple(R.Tensor((10,10),dtype="float32"),R.Tuple(R.Object)):cls=ExpectedR.func_attr({"num_input":2})withR.dataflow():lv1=R.call_tir(cls.add_one,(x,),out_sinfo=R.Tensor((10,10),dtype="float32"))gv1:R.Tuple(R.Tensor((10,10),dtype="float32"),R.Tuple(R.Object))=lv1,(_io,)R.output(gv1)returngv1# fmt: onm=Model()irmodule,_=m.export_tvm(spec={"test":{"x":spec.Tensor([10,10],"float32")}},debug=True)tvm.ir.assert_structural_equal(irmodule,Expected)deftest_tensor_ir_op():num_q_heads,num_kv_heads,head_dim=8,8,16fused_heads=num_q_heads+num_kv_heads*2dtype="float16"@T.prim_func(private=True)deffused_rope(# pylint: disable=too-many-localsvar_qkv:T.handle,var_q:T.handle,var_k:T.handle,var_v:T.handle,# Scalar arguments must be specified after tensor arguments,# including the output tensor arguments## TODO(Lunderberg): Update# `tvm.relax.frontend.nn.op.tensor_ir_op` to use `PrimValue`# instead of `tir_vars`, so that the order can be consistent# between the function definition and the arguments in# `op.tensor_ir_op`.offset:T.int64,):batch_size=T.int64()seq_len=T.int64()qkv=T.match_buffer(var_qkv,(batch_size,seq_len,fused_heads,head_dim),dtype)q=T.match_buffer(var_q,(batch_size,seq_len,num_q_heads,head_dim),dtype)k=T.match_buffer(var_k,(batch_size,seq_len,num_kv_heads,head_dim),dtype)v=T.match_buffer(var_v,(batch_size,seq_len,num_kv_heads,head_dim),dtype)T.evaluate(offset)classModel(Module):deftest(self,qkv:Tensor,offset:tir.Var):tensor_expr_op_out=op.tensor_ir_op(fused_rope,"llama_fused_rope",args=[qkv,offset],out=[Tensor.placeholder((1,1,num_q_heads,head_dim),dtype),Tensor.placeholder((1,1,num_kv_heads,head_dim),dtype),Tensor.placeholder((1,1,num_kv_heads,head_dim),dtype),],)returntensor_expr_op_out# fmt: 
    # fmt: off
    @I.ir_module
    class Expected:
        @T.prim_func(private=True)
        def llama_fused_rope(var_qkv: T.handle, var_q: T.handle, var_k: T.handle, var_v: T.handle, offset: T.int64):
            batch_size, seq_len = T.int64(), T.int64()
            qkv = T.match_buffer(var_qkv, (batch_size, seq_len, 24, 16), "float16")
            q = T.match_buffer(var_q, (batch_size, seq_len, 8, 16), "float16")
            k = T.match_buffer(var_k, (batch_size, seq_len, 8, 16), "float16")
            v = T.match_buffer(var_v, (batch_size, seq_len, 8, 16), "float16")
            T.evaluate(offset)

        @R.function
        def _initialize_effect() -> R.Tuple(R.Object):
            with R.dataflow():
                _io: R.Object = R.null_value()
                lv: R.Tuple(R.Object) = (_io,)
                gv: R.Tuple(R.Object) = lv
                R.output(gv)
            return gv

        @R.function
        def test(qkv: R.Tensor((1, 1, 24, 16), dtype="float16"), offset: R.Shape(["offset_1"]), _io: R.Object) -> R.Tuple(R.Tuple(R.Tensor((1, 1, 8, 16), dtype="float16"), R.Tensor((1, 1, 8, 16), dtype="float16"), R.Tensor((1, 1, 8, 16), dtype="float16")), R.Tuple(R.Object)):
            offset_1 = T.int64()
            R.func_attr({"num_input": 3})
            cls = Expected
            with R.dataflow():
                lv1 = R.call_tir(cls.llama_fused_rope, (qkv,), out_sinfo=[R.Tensor((1, 1, 8, 16), dtype="float16"), R.Tensor((1, 1, 8, 16), dtype="float16"), R.Tensor((1, 1, 8, 16), dtype="float16")], tir_vars=R.shape([offset_1]))
                llama_fused_rope_0: R.Tensor((1, 1, 8, 16), dtype="float16") = lv1[0]
                llama_fused_rope_1: R.Tensor((1, 1, 8, 16), dtype="float16") = lv1[1]
                llama_fused_rope_2: R.Tensor((1, 1, 8, 16), dtype="float16") = lv1[2]
                gv1: R.Tuple(R.Tuple(R.Tensor((1, 1, 8, 16), dtype="float16"), R.Tensor((1, 1, 8, 16), dtype="float16"), R.Tensor((1, 1, 8, 16), dtype="float16")), R.Tuple(R.Object)) = (llama_fused_rope_0, llama_fused_rope_1, llama_fused_rope_2), (_io,)
                R.output(gv1)
            return gv1
    # fmt: on

    m = Model()
    irmodule, _ = m.export_tvm(
        spec={"test": {"qkv": spec.Tensor([1, 1, fused_heads, head_dim], "float16"), "offset": int}},
        debug=True,
    )
    tvm.ir.assert_structural_equal(irmodule, Expected)


def test_tensor_ir_inplace_op():
    hidden_size = 4096
    dtype = "float16"

    @T.prim_func
    def inplace_take(
        var_weight: T.handle, var_pos: T.handle, var_embeddings: T.handle, offset: T.int64
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        vocab_size = T.int64()
        weight = T.match_buffer(var_weight, (vocab_size, hidden_size), dtype)
        seq_len = T.int64()
        total_seq_len = T.int64()
        pos = T.match_buffer(var_pos, (seq_len,), "int32")
        embeddings = T.match_buffer(var_embeddings, (total_seq_len, hidden_size), dtype)
        for ax0, ax1 in T.grid(seq_len, hidden_size):
            with T.block("T_take"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(weight[pos[v0], v1], pos[v0])
                T.writes(embeddings[v0, v1])
                embeddings[v0 + offset, v1] = weight[pos[v0], v1]

    class Model(Module):
        def test(
            self, embedding_table: Tensor, input_ids: Tensor, embedding_dst: Tensor, offset: int
        ):
            tensor_expr_op_out = op.tensor_ir_inplace_op(
                inplace_take,
                "inplace_take",
                args=[embedding_table, input_ids, embedding_dst, offset],
                inplace_indices=[2],
                out=Tensor.placeholder(embedding_dst.shape, embedding_dst.dtype),
            )
            return tensor_expr_op_out

    @I.ir_module
    class Expected:
        @T.prim_func
        def inplace_take(
            var_weight: T.handle, var_pos: T.handle, var_embeddings: T.handle, offset: T.int64
        ):
            T.func_attr({"tir.noalias": T.bool(True)})
            vocab_size = T.int64()
            weight = T.match_buffer(var_weight, (vocab_size, hidden_size), dtype)
            seq_len = T.int64()
            total_seq_len = T.int64()
            pos = T.match_buffer(var_pos, (seq_len,), "int32")
            embeddings = T.match_buffer(var_embeddings, (total_seq_len, hidden_size), dtype)
            for ax0, ax1 in T.grid(seq_len, hidden_size):
                with T.block("T_take"):
                    v0, v1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(weight[pos[v0], v1], pos[v0])
                    T.writes(embeddings[v0, v1])
                    embeddings[v0 + offset, v1] = weight[pos[v0], v1]

        @R.function
        def _initialize_effect() -> R.Tuple(R.Object):
            with R.dataflow():
                _io: R.Object = R.null_value()
                lv: R.Tuple(R.Object) = (_io,)
                gv: R.Tuple(R.Object) = lv
                R.output(gv)
            return gv

        @R.function
        def test(
            embedding_table: R.Tensor(("vocab_size", hidden_size), dtype),
            input_ids: R.Tensor(("seq_len",), "int32"),
            embedding_dst: R.Tensor(("total_seq_len", hidden_size), dtype),
            offset: R.Shape(["offset_1"]),
            packed_params: R.Tuple,
        ) -> R.Tensor(("total_seq_len", hidden_size), dtype):
            total_seq_len = T.int64()
            offset_1 = T.int64()
            R.func_attr({"num_input": 4})
            cls = Expected
            with R.dataflow():
                lv1 = R.call_tir_inplace(
                    cls.inplace_take,
                    (embedding_table, input_ids, embedding_dst),
                    out_sinfo=R.Tensor((total_seq_len, hidden_size), dtype),
                    inplace_indices=[2],
                    tir_vars=R.shape([offset_1]),
                )
                gv1: R.Tensor((total_seq_len, hidden_size), dtype) = lv1
                R.output(gv1)
            return gv1

    m = Model()
    irmodule, _ = m.export_tvm(
        spec={
            "test": {
                "embedding_table": spec.Tensor(["vocab_size", hidden_size], dtype),
                "input_ids": spec.Tensor(["seq_len"], "int32"),
                "embedding_dst": spec.Tensor(["total_seq_len", hidden_size], dtype),
                "offset": int,
                "$": {
                    "param_mode": "packed",
                    "effect_mode": "none",
                },
            },
        },
        debug=True,
    )
    tvm.ir.assert_structural_equal(irmodule, Expected)


def test_tensor_ir_op_no_tir_var():
    @T.prim_func(private=True)
    def tir_func(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")):
        T.evaluate(0)

    class Model(Module):
        def test(self, A: Tensor):
            tensor_expr_op_out = op.tensor_ir_op(
                tir_func,
                "tir_func",
                args=[A],
                out=[Tensor.placeholder((16, 16), "float32")],
            )
            return tensor_expr_op_out

    @I.ir_module
    class Expected:
        @T.prim_func(private=True)
        def tir_func(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")):
            T.evaluate(0)

        @R.function
        def test(A: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtype="float32"):
            R.func_attr({"num_input": 1})
            cls = Expected
            with R.dataflow():
                lv = R.call_tir(cls.tir_func, (A,), out_sinfo=R.Tensor((16, 16), dtype="float32"))
                gv: R.Tensor((16, 16), dtype="float32") = lv
                R.output(gv)
            return gv

    m = Model()
    irmodule, _ = m.export_tvm(spec={"test": {"A": spec.Tensor([16, 16], "float32")}})
    tvm.ir.assert_structural_equal(irmodule, Expected)


def test_extern():
    class Model(Module):
        def test(self, q: Tensor, k: Tensor, v: Tensor):
            b, s, h_q, d = q.shape
            tensor_expr_op_out = op.extern(
                name="flashinfer.single_decode",
                args=[q, k, v, 0, 0, 1.0, 10000.0],
                out=Tensor.placeholder((b, s, h_q * d), dtype="float16"),
            )
            return tensor_expr_op_out

    # fmt: off
    @I.ir_module
    class Expected:
        @R.function
        def _initialize_effect() -> R.Tuple(R.Object):
            with R.dataflow():
                _io: R.Object = R.null_value()
                lv: R.Tuple(R.Object) = (_io,)
                gv: R.Tuple(R.Object) = lv
                R.output(gv)
            return gv

        @R.function
        def test(q: R.Tensor((1, 1, 16, 8), dtype="float32"), k: R.Tensor((64, 16, 8), dtype="float32"), v: R.Tensor((64, 16, 8), dtype="float32"), _io: R.Object) -> R.Tuple(R.Tensor((1, 1, 128), dtype="float16"), R.Tuple(R.Object)):
            R.func_attr({"num_input": 4})
            with R.dataflow():
                flashinfer_single_decode = R.call_dps_packed("flashinfer.single_decode", (q, k, v, R.prim_value(0), R.prim_value(0), R.prim_value(T.float64(1)), R.prim_value(T.float64(10000))), out_sinfo=R.Tensor((1, 1, 128), dtype="float16"))
                gv1: R.Tuple(R.Tensor((1, 1, 128), dtype="float16"), R.Tuple(R.Object)) = flashinfer_single_decode, (_io,)
                R.output(gv1)
            return gv1
    # fmt: on
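
    # `op.extern` lowers to `R.call_dps_packed` on the named packed function:
    # tensor arguments pass through unchanged, the Python scalars 0, 0, 1.0,
    # 10000.0 become `R.prim_value(...)` operands, and the declared output
    # placeholder supplies out_sinfo with shape (b, s, h_q * d) = (1, 1, 16 * 8)
    # = (1, 1, 128).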
    batch, seq, t, d, h_q, h_kv = 1, 1, 64, 8, 16, 16
    m = Model()
    irmodule, _ = m.export_tvm(
        spec={
            "test": {
                "q": spec.Tensor([batch, seq, h_q, d], "float32"),
                "k": spec.Tensor([t, h_kv, d], "float32"),
                "v": spec.Tensor([t, h_kv, d], "float32"),
            }
        },
        debug=True,
    )
    tvm.ir.assert_structural_equal(irmodule, Expected)


def test_empty():
    @tvm.register_func("test_empty_assert", override=True)
    def test_empty_assert(_lineo, x):
        assert x.shape == (10, 10)
        assert x.dtype == "float32"

    class Model(Module):
        def test(self):
            result = op.empty([10, 10], dtype="float32")
            op.debug_func("test_empty_assert", result)
            return result

    irmodule, _ = Model().export_tvm(spec={"test": {}}, debug=True)

    ex = tvm.compile(irmodule, "llvm")
    vm = relax.VirtualMachine(ex, tvm.cpu())
    effects = vm["_initialize_effect"]()
    vm["test"](*effects)


@tvm.testing.requires_cuda
def test_multinomial_from_uniform():
    prob_shape = (3, 5)
    sample_shape = (6, 1)

    class Model(Module):
        def foo(self, prob: Tensor, uniform_sample: Tensor, sample_indices: Tensor):
            z0 = op.multinomial_from_uniform(prob, uniform_sample, sample_indices)
            return z0

    # fmt: off
    @I.ir_module
    class Expected:
        @R.function
        def _initialize_effect() -> R.Tuple(R.Object):
            with R.dataflow():
                _io: R.Object = R.null_value()
                lv: R.Tuple(R.Object) = (_io,)
                gv: R.Tuple(R.Object) = lv
                R.output(gv)
            return gv

        @R.function
        def foo(prob: R.Tensor((3, 5), dtype="float32"), uniform_sample: R.Tensor((6, 1), dtype="float32"), sample_indices: R.Tensor((6, 1), dtype="int64"), _io: R.Object) -> R.Tuple(R.Tensor((6, 1), dtype="int64"), R.Tuple(R.Object)):
            R.func_attr({"num_input": 4})
            with R.dataflow():
                multinomial_from_uniform: R.Tensor((6, 1), dtype="int64") = R.multinomial_from_uniform(prob, uniform_sample, sample_indices, dtype="int64")
                gv1: R.Tuple(R.Tensor((6, 1), dtype="int64"), R.Tuple(R.Object)) = multinomial_from_uniform, (_io,)
                R.output(gv1)
            return gv1
    # fmt: on

    m = Model()
    mod, _ = m.export_tvm(
        spec={
            "foo": {
                "prob": spec.Tensor(prob_shape, "float32"),
                "uniform_sample": spec.Tensor(sample_shape, "float32"),
                "sample_indices": spec.Tensor(sample_shape, "int64"),
            }
        },
        debug=True,
    )
    tvm.ir.assert_structural_equal(mod, Expected)

    target = tvm.target.Target("cuda", host="llvm")
    with target:
        mod = relax.backend.DispatchSampling()(mod)
        mod = tir.transform.DefaultGPUSchedule()(mod)
    ex = tvm.compile(mod, target)
    dev = tvm.device(str(target), 0)
    vm = relax.VirtualMachine(ex, dev)

    effects = vm["_initialize_effect"]()
    np_rand = np.random.rand(*prob_shape).astype(np.float32)
    # normalize it to get the random prob
    np_prob = np_rand / np_rand.sum(axis=1, keepdims=True)
    nd_prob = tvm.nd.array(np_prob, dev)
    # special sample to get deterministic results
    nd_sample = tvm.nd.array(np.array([[1], [0], [1], [1], [0], [1]]).astype(np.float32), dev)
    nd_sample_indices = tvm.nd.array(np.array([[0], [1], [1], [2], [2], [2]]).astype(np.int64), dev)
    inputs = [nd_prob, nd_sample, nd_sample_indices, effects]
    res = vm["foo"](*inputs)
    tvm.testing.assert_allclose(
        res[0].numpy(), np.array([[4], [0], [4], [4], [0], [4]]).astype(np.int64)
    )


@tvm.testing.requires_gpu
def test_sample_top_p_top_k_from_sorted_prob():
    prob_shape = (2, 3)
    sample_shape = (3, 1)

    class Model(Module):
        def foo(
            self,
            prob: Tensor,
            index: Tensor,
            top_p: Tensor,
            top_k: Tensor,
            uniform_sample: Tensor,
            sample_indices: Tensor,
        ):
            z0 = op.sample_top_p_top_k_from_sorted_prob(
                prob, index, top_p, top_k, uniform_sample, sample_indices
            )
            return z0
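
    # Lowering encoded by the Expected module below: take the cumulative sum of
    # the sorted probabilities, use `get_renorm_prob` to find the renormalization
    # denominator implied by the top-p / top-k cutoff, then use
    # `get_index_from_sorted` to perform an inverse-CDF lookup of each uniform
    # sample against the renormalized cumulative sums.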
    # fmt: off
    @I.ir_module
    class Expected:
        @T.prim_func(private=True)
        def get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle, F: T.handle):
            batch, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
            cumsum_sorted = T.match_buffer(A, (batch, vocab_size))
            indices = T.match_buffer(B, (batch, vocab_size), "int64")
            renorm_prob = T.match_buffer(C, (batch, 1))
            out_batch = T.int64(is_size_var=True)
            usample = T.match_buffer(D, (out_batch, 1))
            sample_indices = T.match_buffer(E, (out_batch, 1), "int64")
            output_index = T.match_buffer(F, (out_batch, 1), "int64")
            # with T.block("root"):
            for ax0, ax1 in T.grid(out_batch, vocab_size):
                with T.block("T_get_index_from_sorted"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(usample[v_ax0, T.int64(0)], cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1 - T.int64(1):v_ax1 - T.int64(1) + T.int64(2)], sample_indices[v_ax0, T.int64(0)], renorm_prob[sample_indices[v_ax0, T.int64(0)], 0], indices[sample_indices[v_ax0, T.int64(0)], T.min(T.int64(0), v_ax1):T.min(T.int64(0), v_ax1) + (T.max(T.int64(0), v_ax1) + T.int64(1) - T.min(T.int64(0), v_ax1))])
                    T.writes(output_index[v_ax0, 0])
                    if usample[v_ax0, T.int64(0)] < cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1] / renorm_prob[sample_indices[v_ax0, T.int64(0)], 0] or v_ax1 + T.int64(1) == vocab_size:
                        if v_ax1 == T.int64(0):
                            output_index[v_ax0, 0] = indices[sample_indices[v_ax0, T.int64(0)], 0]
                        else:
                            if usample[v_ax0, T.int64(0)] >= cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1 - T.int64(1)] / renorm_prob[sample_indices[v_ax0, T.int64(0)], 0]:
                                output_index[v_ax0, 0] = indices[sample_indices[v_ax0, T.int64(0)], v_ax1]

        @T.prim_func(private=True)
        def get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle):
            batch, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
            cumsum_sorted = T.match_buffer(A, (batch, vocab_size))
            top_p = T.match_buffer(B, (batch, 1))
            top_k = T.match_buffer(C, (batch, 1), "int64")
            renorm_prob = T.match_buffer(D, (batch, 1))
            # with T.block("root"):
T.block("root"):forax0,ax1inT.grid(batch,vocab_size):withT.block("T_get_renorm_prob"):v_ax0,v_ax1=T.axis.remap("SS",[ax0,ax1])T.reads(cumsum_sorted[v_ax0,T.min(T.min(T.int64(0),v_ax1),v_ax1+T.int64(1)):T.min(T.min(T.int64(0),v_ax1),v_ax1+T.int64(1))+(T.max(T.max(T.int64(0),v_ax1),v_ax1+T.int64(1))+T.int64(1)-T.min(T.min(T.int64(0),v_ax1),v_ax1+T.int64(1)))],top_p[v_ax0,0],top_k[v_ax0,0])T.writes(renorm_prob[v_ax0,0])ifnot(cumsum_sorted[v_ax0,0]<top_p[v_ax0,0]andtop_k[v_ax0,0]>T.int64(1)):renorm_prob[v_ax0,0]=cumsum_sorted[v_ax0,0]else:ifcumsum_sorted[v_ax0,v_ax1]<top_p[v_ax0,0]andv_ax1+T.int64(1)<top_k[v_ax0,0]:ifv_ax1+T.int64(1)==vocab_size:renorm_prob[v_ax0,0]=cumsum_sorted[v_ax0,v_ax1]else:ifnot(cumsum_sorted[v_ax0,v_ax1+T.int64(1)]<top_p[v_ax0,0]andv_ax1+T.int64(1)+T.int64(1)<top_k[v_ax0,0]):renorm_prob[v_ax0,0]=cumsum_sorted[v_ax0,v_ax1+T.int64(1)]@R.functiondef_initialize_effect()->R.Tuple(R.Object):withR.dataflow():_io:R.Object=R.null_value()lv:R.Tuple(R.Object)=(_io,)gv:R.Tuple(R.Object)=lvR.output(gv)returngv@R.functiondeffoo(prob:R.Tensor((2,3),dtype="float32"),index:R.Tensor((2,3),dtype="int64"),top_p:R.Tensor((2,1),dtype="float32"),top_k:R.Tensor((2,1),dtype="int64"),uniform_sample:R.Tensor((3,1),dtype="float32"),sample_indices:R.Tensor((3,1),dtype="int64"),_io:R.Object,)->R.Tuple(R.Tensor((3,1),dtype="int64"),R.Tuple(R.Object)):R.func_attr({"num_input":7})cls=ExpectedwithR.dataflow():cumsum:R.Tensor((2,3),dtype="float32")=R.cumsum(prob,axis=1,dtype="void",exclusive=None)lv1=R.call_tir(cls.get_renorm_prob,(cumsum,top_p,top_k),out_sinfo=R.Tensor((2,1),dtype="float32"))lv2=R.call_tir(cls.get_index_from_sorted,(cumsum,index,lv1,uniform_sample,sample_indices),out_sinfo=R.Tensor((3,1),dtype="int64"))gv1:R.Tuple(R.Tensor((3,1),dtype="int64"),R.Tuple(R.Object))=lv2,(_io,)R.output(gv1)returngv1# fmt: onm=Model()mod,_=m.export_tvm(spec={"foo":{"prob":spec.Tensor(prob_shape,"float32"),"index":spec.Tensor(prob_shape,"int64"),"top_p":spec.Tensor((prob_shape[0],1),"float32"),"top_k":spec.Tensor((prob_shape[0],1),"int64"),"uniform_sample":spec.Tensor(sample_shape,"float32"),"sample_indices":spec.Tensor(sample_shape,"int64"),}},debug=True,)tvm.ir.assert_structural_equal(mod,Expected)target=tvm.target.Target("cuda -libs=thrust",host="llvm")withtarget:mod=tir.transform.DefaultGPUSchedule()(mod)ex=tvm.compile(mod,target)dev=tvm.cuda(0)vm=relax.VirtualMachine(ex,dev)effects=vm["_initialize_effect"]()sorted_prob=tvm.nd.array(np.array([[0.5,0.4,0.1],[0.4,0.3,0.3]]).astype(np.float32),dev)indices=tvm.nd.array(np.array([[2,1,0],[2,0,1]]).astype(np.int64),dev)top_p=tvm.nd.array(np.array([[0.6],[0.9]]).astype(np.float32),dev)top_k=tvm.nd.array(np.array([[3],[2]]).astype(np.int64),dev)usample=tvm.nd.array(np.array([[0.5],[0.6],[0.7]]).astype(np.float32),dev)sample_indices=tvm.nd.array(np.array([[0],[1],[1]]).astype(np.int64),dev)inputs=[sorted_prob,indices,top_p,top_k,usample,sample_indices,effects]res=vm["foo"](*inputs)tvm.testing.assert_allclose(res[0].numpy(),np.array([[2],[0],[0]]).astype(np.int64))@tvm.testing.requires_gpudeftest_renormalize_top_p_top_k_prob():prob_shape=(2,3)sample_shape=(2,1)classModel(Module):deffoo(self,prob:Tensor,sorted_prob:Tensor,top_p:Tensor,top_k:Tensor,):z0=op.renormalize_top_p_top_k_prob(prob,sorted_prob,top_p,top_k)returnz0# fmt: 
    # fmt: off
    @I.ir_module
    class Expected:
        @T.prim_func(private=True)
        def filter_with_top_p_top_k(A: T.Buffer((T.int64(2), T.int64(3)), "float32"), B: T.Buffer((T.int64(2), T.int64(1)), "float32"), filter_with_top_p_top_k: T.Buffer((T.int64(2), T.int64(3)), "float32")):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
            for i, j in T.grid(T.int64(2), T.int64(3)):
                with T.block("filter_with_top_p_top_k"):
                    v_i, v_j = T.axis.remap("SS", [i, j])
                    T.reads(B[v_i, T.int64(0)], A[v_i, v_j])
                    T.writes(filter_with_top_p_top_k[v_i, v_j])
                    filter_with_top_p_top_k[v_i, v_j] = T.Select(B[v_i, T.int64(0)] <= A[v_i, v_j], A[v_i, v_j], T.float32(0))

        @T.prim_func(private=True)
        def get_renorm_cutoff(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle):
            batch, vocab_size = T.int64(), T.int64()
            sorted_prob = T.match_buffer(A, (batch, vocab_size))
            cumsum_sorted = T.match_buffer(B, (batch, vocab_size))
            top_p = T.match_buffer(C, (batch, 1))
            top_k = T.match_buffer(D, (batch, 1), "int64")
            cutoff = T.match_buffer(E, (batch, 1))
            # with T.block("root"):
            for ax0, ax1 in T.grid(batch, vocab_size):
                with T.block("T_get_renorm_prob"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(cumsum_sorted[v_ax0, T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)):T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + (T.max(T.max(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + T.int64(1) - T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)))], top_p[v_ax0, 0], top_k[v_ax0, 0], sorted_prob[v_ax0, T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)):T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + (T.max(T.max(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + T.int64(1) - T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)))])
                    T.writes(cutoff[v_ax0, 0])
                    if (cumsum_sorted[v_ax0, 0] < top_p[v_ax0, 0] and top_k[v_ax0, 0] > T.int64(1)) == T.bool(False):
                        cutoff[v_ax0, 0] = sorted_prob[v_ax0, 0]
                    else:
                        if (cumsum_sorted[v_ax0, v_ax1] < top_p[v_ax0, 0] and v_ax1 + T.int64(1) < top_k[v_ax0, 0]) == T.bool(True):
                            if v_ax1 + T.int64(1) == vocab_size:
                                cutoff[v_ax0, 0] = sorted_prob[v_ax0, v_ax1]
                            else:
                                if (cumsum_sorted[v_ax0, v_ax1 + T.int64(1)] < top_p[v_ax0, 0] and v_ax1 + T.int64(1) + T.int64(1) < top_k[v_ax0, 0]) == T.bool(False):
                                    cutoff[v_ax0, 0] = sorted_prob[v_ax0, v_ax1 + T.int64(1)]

        @R.function
        def _initialize_effect() -> R.Tuple(R.Object):
            with R.dataflow():
                _io: R.Object = R.null_value()
                lv: R.Tuple(R.Object) = (_io,)
                gv: R.Tuple(R.Object) = lv
                R.output(gv)
            return gv

        @R.function
        def foo(prob: R.Tensor((2, 3), dtype="float32"), sorted_prob: R.Tensor((2, 3), dtype="float32"), top_p: R.Tensor((2, 1), dtype="float32"), top_k: R.Tensor((2, 1), dtype="int64"), _io: R.Object) -> R.Tuple(R.Tensor((2, 3), dtype="float32"), R.Tuple(R.Object)):
            R.func_attr({"num_input": 5})
            cls = Expected
            with R.dataflow():
                cumsum: R.Tensor((2, 3), dtype="float32") = R.cumsum(sorted_prob, axis=1, dtype="void", exclusive=None)
                lv1 = R.call_tir(cls.get_renorm_cutoff, (sorted_prob, cumsum, top_p, top_k), out_sinfo=R.Tensor((2, 1), dtype="float32"))
                lv2 = R.call_tir(cls.filter_with_top_p_top_k, (prob, lv1), out_sinfo=R.Tensor((2, 3), dtype="float32"))
                sum: R.Tensor((2, 1), dtype="float32") = R.sum(lv2, axis=[1], keepdims=True)
                divide: R.Tensor((2, 3), dtype="float32") = R.divide(lv2, sum)
                gv1: R.Tuple(R.Tensor((2, 3), dtype="float32"), R.Tuple(R.Object)) = divide, (_io,)
                R.output(gv1)
            return gv1
    # fmt: on

    m = Model()
    mod, _ = m.export_tvm(
        spec={
            "foo": {
                "prob": spec.Tensor(prob_shape, "float32"),
                "sorted_prob": spec.Tensor(prob_shape, "float32"),
                "top_p": spec.Tensor(sample_shape, "float32"),
                "top_k": spec.Tensor(sample_shape, "int64"),
            }
        },
        debug=True,
    )
    tvm.ir.assert_structural_equal(mod, Expected)
-libs=thrust",host="llvm")withtarget:mod=relax.transform.LegalizeOps()(mod)mod=tir.transform.DefaultGPUSchedule()(mod)ex=tvm.compile(mod,target)dev=tvm.cuda(0)vm=relax.VirtualMachine(ex,dev)effects=vm["_initialize_effect"]()prob=tvm.nd.array(np.array([[0.2,0.3,0.5],[0.3,0.3,0.4]]).astype(np.float32),dev)sorted_prob=tvm.nd.array(np.array([[0.5,0.3,0.2],[0.4,0.3,0.3]]).astype(np.float32),dev)top_p=tvm.nd.array(np.array([[0.6],[0.9]]).astype(np.float32),dev)top_k=tvm.nd.array(np.array([[3],[2]]).astype(np.int64),dev)inputs=[prob,sorted_prob,top_p,top_k,effects]res=vm["foo"](*inputs)tvm.testing.assert_allclose(res[0].numpy(),np.array([[0,0.375,0.625],[0.3,0.3,0.4]]).astype(np.float32))deftest_sort_argsort_topk():classModel(Module):deffoo(self,x:Tensor):z0=op.sort(x,axis=-1,descending=True)z1=op.argsort(x,axis=-1,descending=False)z2=op.topk(x,k=2,axis=-1)returnz0,z1,z2@I.ir_moduleclassExpected:@R.functiondeffoo(x:R.Tensor(("seq_len",64),dtype="float16")):R.func_attr({"num_input":1})withR.dataflow():sort=R.sort(x,axis=-1,descending=True)argsort=R.argsort(x,axis=-1,descending=False,dtype="int32")topk=R.topk(x,k=2,axis=-1,ret_type="both",largest=True,dtype="int32")topk_0=topk[0]topk_1=topk[1]gv=sort,argsort,(topk_0,topk_1)R.output(gv)returngvm=Model()mod,_=m.export_tvm({"foo":{"x":spec.Tensor(("seq_len",64),"float16")}})tvm.ir.assert_structural_equal(mod,Expected)