def test_stack_vm_basic():
    """Check that a packed callback receives the buffer's first shape extent.

    The generated function evaluates a packed call to
    ``tvm_call_back_get_shape`` with ``Ab.shape[0]``; the callback asserts
    that the value equals the length of the array passed in.
    """
    a = tvm.nd.array(np.zeros(10, dtype="float32"))

    @tvm.register_func
    def tvm_call_back_get_shape(shape0):
        print(shape0)
        assert shape0 == a.shape[0]

    n = te.size_var("n")
    Ab = tvm.tir.decl_buffer((n,), "float32")
    stmt = tvm.tir.Evaluate(tvm.tir.call_packed("tvm_call_back_get_shape", Ab.shape[0]))

    mod = tvm.IRModule.from_expr(
        tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "print_shape")
    )

    # run_jit is defined elsewhere in this file; it is given the module and a
    # callable that invokes the resulting function on `a`.
    run_jit(mod, lambda f: f(a))


@tvm.register_func
def tvm_stack_vm_print(*x):
    """Packed function that prints the tuple of arguments it receives."""
    print(x)


def test_stack_vm_loop():
    """Check a serial loop that writes ``A[i + 1] = A[i] + 1``.

    Each iteration also emits a packed call to ``tvm_stack_vm_print`` with the
    loop variable. Starting from zeros, the array is expected to equal
    ``np.arange(a.shape[0])`` afterwards.
    """
    dtype = "int64"
    n = te.size_var("n")
    Ab = tvm.tir.decl_buffer((n,), dtype)

    ib = tvm.tir.ir_builder.create()
    A = ib.buffer_ptr(Ab)
    # The loop variable is provided by for_range; no separate declaration of
    # `i` is needed beforehand.
    with ib.for_range(0, n - 1, "i") as i:
        A[i + 1] = A[i] + 1
        ib.emit(tvm.tir.call_packed("tvm_stack_vm_print", i))

    stmt = ib.get()
    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "ramp"))
    a = tvm.nd.array(np.zeros(10, dtype=dtype))

    def check(f):
        f(a)
        np.testing.assert_equal(a.numpy(), np.arange(a.shape[0]))

    run_jit(mod, check)


def test_stack_vm_cond():
    """Check an if/else inside a loop.

    The increment written to ``A[i + 1]`` is 1 when ``i == 4`` and 2
    otherwise, so the expected result is ``arange * 2`` with every element
    from index 5 onwards reduced by one.
    """
    dtype = "int64"
    n = te.size_var("n")
    Ab = tvm.tir.decl_buffer((n,), dtype)

    ib = tvm.tir.ir_builder.create()
    A = ib.buffer_ptr(Ab)
    with ib.for_range(0, n - 1, "i") as i:
        with ib.if_scope(tvm.tir.EQ(i, 4)):
            A[i + 1] = A[i] + 1
        with ib.else_scope():
            A[i + 1] = A[i] + 2

    stmt = ib.get()
    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "test"))

    def check(f):
        a = tvm.nd.array(np.zeros(10, dtype=dtype))
        f(a)
        y = np.arange(a.shape[0]) * 2
        # The single +1 step at i == 4 shifts A[5] and everything after it.
        y[5:] -= 1
        np.testing.assert_equal(a.numpy(), y)

    run_jit(mod, check)


def test_vm_parallel():
    """Check a loop declared with ``kind="parallel"`` that adds 1 to each element.

    Starting from zeros, the array is expected to be all ones afterwards.
    """
    dtype = "int64"
    n = te.size_var("n")
    Ab = tvm.tir.decl_buffer((n,), dtype)

    ib = tvm.tir.ir_builder.create()
    A = ib.buffer_ptr(Ab)
    # As in test_stack_vm_loop, `i` comes from for_range itself.
    with ib.for_range(0, n, "i", kind="parallel") as i:
        A[i] = A[i] + 1

    stmt = ib.get()
    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "test"))

    def check(f):
        a = tvm.nd.array(np.zeros(10, dtype=dtype))
        f(a)
        np.testing.assert_equal(a.numpy(), np.ones(a.shape[0]))

    run_jit(mod, check)


def test_codegen_decl_buffer():
    """The codegen should accept DeclBuffer nodes in its input"""

    # The kernel body is TVMScript, which is parsed from this source text, so
    # it is kept exactly as written.
    @I.ir_module
    class mod:
        @T.prim_func
        def kernel(A_data: T.handle("float32")):
            T.func_attr({"global_symbol": "kernel"})
            A_buf = T.decl_buffer([256], dtype="float32", scope="global", data=A_data)

    target = tvm.target.Target("stackvm")
    # Call the "target.build.stackvm" global function directly on the module;
    # the test passes as long as this call does not raise.
    stackvm_codegen = tvm.get_global_func("target.build.stackvm")
    stackvm_codegen(mod, target)