# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=redefined-builtin, invalid-name, too-many-arguments
"""Operators used in TIR expression."""
from typing import Any, Optional, Union

import tvm._ffi
from tvm import tir
from tvm.ir import Array, Op, PrimExpr
from tvm.ir.base import Span
from tvm.runtime import const

from . import _ffi_api
from .buffer import Buffer
from .expr import Call, CommReducer, IntImm, PrimExprWithOp, Var


def _pack_buffer(buf, span=None):
    """Build the intrinsic call that packs a buffer into an array handle.

    The result is a ``tir.tvm_stack_make_array`` call whose arguments are the
    buffer's data, shape handle, strides handle (or ``0`` when the buffer has no
    strides), number of dimensions, a dtype marker and element offset.

    Parameters
    ----------
    buf : Buffer
        The buffer to pack.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        A ``handle``-typed call expression.
    """
    shape_handle = Call("handle", "tir.tvm_stack_make_shape", buf.shape, span)
    if buf.strides:
        strides_handle = Call("handle", "tir.tvm_stack_make_shape", buf.strides, span)
    else:
        # No explicit strides: pass 0 in the strides slot.
        strides_handle = 0
    make_array_args = [
        buf.data,
        shape_handle,
        strides_handle,
        len(buf.shape),
        # Occupies the ``arr_dtype`` slot: a zero constant typed like the buffer.
        const(0, dtype=buf.dtype),
        buf.elem_offset,
    ]
    make_array_op = Op.get("tir.tvm_stack_make_array")
    return Call("handle", make_array_op, make_array_args, span)
def call_packed_lowered(*args, span=None):
    """Lowered version of call packed.

    Every argument is either an Expr or a Buffer. An Expr is forwarded as the
    matching POD value. A Buffer is handed to the PackedFunc as a
    TVMArrayHandle whose content stays valid only during the callback. If the
    PackedFunc is a Python callback, that argument is seen as an NDArray.

    Parameters
    ----------
    args : list of Expr or Buffer.
        Positional arguments.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.

    See Also
    --------
    te.extern : Create tensor with extern function call.
    """
    packed_args = [arg if not isinstance(arg, Buffer) else _pack_buffer(arg) for arg in args]
    lowered_op = Op.get("tir.tvm_call_packed_lowered")
    return Call("int32", lowered_op, packed_args, span)
def call_cpacked_lowered(*args, span=None):
    """Lowered version of call c-packed.

    Behaves like ``call_packed``, except the first argument is the function
    name (as in ``call_extern``) and the last argument is the resource handle.

    Parameters
    ----------
    args : list of Expr or Buffer.
        Positional arguments.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.

    See Also
    --------
    te.extern : Create tensor with extern function call.
    """
    packed_args = [arg if not isinstance(arg, Buffer) else _pack_buffer(arg) for arg in args]
    lowered_op = Op.get("tir.tvm_call_cpacked_lowered")
    return Call("int32", lowered_op, packed_args, span)
def call_packed(*args, span=None):
    """Build an expression that calls an external packed function.

    Every argument is either an Expr or a Buffer. An Expr is forwarded as the
    matching POD value. A Buffer is handed to the PackedFunc as a
    TVMArrayHandle whose content stays valid only during the callback. If the
    PackedFunc is a Python callback, that argument is seen as an NDArray.

    Parameters
    ----------
    args : list of Expr or Buffer.
        Positional arguments.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.

    See Also
    --------
    te.extern : Create tensor with extern function call.
    """
    packed_args = [arg if not isinstance(arg, Buffer) else _pack_buffer(arg) for arg in args]
    packed_op = Op.get("tir.tvm_call_packed")
    return Call("int32", packed_op, packed_args, span)
def call_cpacked(*args, span=None):
    """Build an expression that calls an external c-packed function.

    Behaves like ``call_packed``, except the first argument is the function
    name (as in ``call_extern``) and the last argument is the resource handle.

    Parameters
    ----------
    args : list of Expr or Buffer.
        Positional arguments.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.

    See Also
    --------
    te.extern : Create tensor with extern function call.
    """
    packed_args = [arg if not isinstance(arg, Buffer) else _pack_buffer(arg) for arg in args]
    cpacked_op = Op.get("tir.tvm_call_cpacked")
    return Call("int32", cpacked_op, packed_args, span)
def call_intrin(dtype, func_name, *args, span=None):
    """Build an expression that calls an intrinsic function.

    Intrinsic translation rules allow one intrinsic to be overloaded for
    several data types.

    Parameters
    ----------
    dtype : str
        The data type of the result.

    func_name: str
        The intrinsic function name (an ``Op`` object is also passed through
        unchanged by callers in this module).

    args : list
        Positional arguments.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    intrinsic_call = Call(dtype, func_name, args, span)
    return intrinsic_call
def call_pure_extern(dtype, func_name, *args, span=None):
    """Build an expression that calls a pure extern function.

    Parameters
    ----------
    dtype : str
        The data type of the result.

    func_name: str
        The extern function name; it becomes the first call argument.

    args : list
        Positional arguments.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    extern_args = [func_name] + list(args)
    return Call(dtype, Op.get("tir.call_pure_extern"), extern_args, span)
def call_extern(dtype, func_name, *args, span=None):
    """Build an expression that calls an extern function.

    Parameters
    ----------
    dtype : str
        The data type of the result.

    func_name: str
        The extern function name; it becomes the first call argument.

    args : list
        Positional arguments.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    extern_args = [func_name] + list(args)
    return Call(dtype, Op.get("tir.call_extern"), extern_args, span=span)
def call_llvm_intrin(dtype, name, *args, span=None):
    """Build an expression that calls an LLVM intrinsic function.

    Parameters
    ----------
    dtype : str
        The data type of the result.

    name : str
        The name of the LLVM intrinsic function. An ``IntImm`` or a raw
        intrinsic id is also accepted.

    args : list
        Positional arguments.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.

    Raises
    ------
    ValueError
        If the resolved intrinsic id is 0 (unknown intrinsic).
    """
    # pylint: disable=import-outside-toplevel
    from tvm.target import codegen

    # Resolve the intrinsic id from whichever form the caller supplied.
    if isinstance(name, IntImm):
        intrinsic_id = name.value
    elif isinstance(name, str):
        intrinsic_id = codegen.llvm_lookup_intrinsic_id(name)
    else:
        intrinsic_id = name

    if intrinsic_id == 0:
        raise ValueError(f"Unknown llvm intrinsic function {name}")

    llvm_op = Op.get("tir.call_llvm_intrin")
    id_const = tvm.tir.const(intrinsic_id, "uint32")
    return call_intrin(dtype, llvm_op, id_const, *args, span=span)
def call_llvm_pure_intrin(dtype, name, *args, span=None):
    """Build an expression that calls a pure LLVM intrinsic function.

    Parameters
    ----------
    dtype : str
        The data type of the result.

    name : str
        The name of the LLVM intrinsic function. An ``IntImm`` or a raw
        intrinsic id is also accepted.

    args : list
        Positional arguments.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.

    Raises
    ------
    ValueError
        If the resolved intrinsic id is 0 (unknown intrinsic).
    """
    # pylint: disable=import-outside-toplevel
    from tvm.target import codegen

    # Resolve the intrinsic id from whichever form the caller supplied.
    if isinstance(name, IntImm):
        intrinsic_id = name.value
    elif isinstance(name, str):
        intrinsic_id = codegen.llvm_lookup_intrinsic_id(name)
    else:
        intrinsic_id = name

    if intrinsic_id == 0:
        raise ValueError(f"Unknown llvm intrinsic function {name}")

    llvm_op = Op.get("tir.call_llvm_pure_intrin")
    id_const = tvm.tir.const(intrinsic_id, "uint32")
    return call_intrin(dtype, llvm_op, id_const, *args, span=span)
def tvm_check_return(expected, return_unexpected, nested_call):
    """Check the return code of a nested call against an expected value.

    Builds a ``tir.tvm_check_return`` intrinsic that wraps ``nested_call``.
    NOTE(review): presumably the lowered code returns ``return_unexpected``
    when the nested call's result differs from ``expected`` — confirm against
    the intrinsic's lowering.

    Parameters
    ----------
    expected : int
        The expected return code.

    return_unexpected : int
        The unexpected return code.

    nested_call : PrimExpr
        The call expression to check return.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    return call_intrin("int32", "tir.tvm_check_return", expected, return_unexpected, nested_call)
def tvm_stack_alloca(dtype_str, num):
    """Allocate ``dtype_str[num]`` on the stack and return its handle.

    Parameters
    ----------
    dtype_str : str
        The data type of array.

    num : int
        The size of array.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    op_name = "tir.tvm_stack_alloca"
    return call_intrin("handle", op_name, dtype_str, num)
def tvm_stack_make_shape(*args):
    """Allocate a shape tuple on the stack and return its handle.

    Parameters
    ----------
    args : int
        The tuple shape.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    op_name = "tir.tvm_stack_make_shape"
    return call_intrin("handle", op_name, *args)
def tvm_stack_make_array(data, shape, strides, ndim, arr_dtype, elem_offset):
    """Allocate an NDArray (DLTensor) on the stack and return its handle.

    Parameters
    ----------
    data : Expr
        The data of array.

    shape : Expr
        The shape of array.

    strides : Expr
        The strides of array.

    ndim : Expr
        The dimensions of array.

    arr_dtype : Expr
        The data type of array.

    elem_offset : Expr
        The element offset of array.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    fields = (data, shape, strides, ndim, arr_dtype, elem_offset)
    return call_intrin("handle", "tir.tvm_stack_make_array", *fields)
def assume(cond=None):
    """State a condition that is true and may be used for simplification.

    Parameters
    ----------
    cond : Expr
        The constraint condition.

    Returns
    -------
    call : PrimExpr
        A ``bool``-typed call expression.
    """
    op_name = "tir.assume"
    return call_intrin("bool", op_name, cond)
def undef():
    """Return an initialized but arbitrary ``int32`` value.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    op_name = "tir.undef"
    return call_intrin("int32", op_name)
def call_tir(global_var: tvm.ir.GlobalVar, *args):
    """Performs a call into another PrimFunc in the same IRModule.

    Parameters
    ----------
    global_var : tvm.ir.GlobalVar
        The callee.

    args : list of PrimExpr
        Positional arguments forwarded to the callee.

    Returns
    -------
    call : PrimExpr
        The call expression. Its dtype is taken from the callee's checked
        return type when that type exposes a ``dtype``; otherwise ``"void"``.
    """
    assert isinstance(global_var, tvm.ir.GlobalVar)

    dtype = "void"
    # NOTE(review): TVM's ``checked_type`` property may raise ValueError
    # (instead of returning None) when the type has not been populated — treat
    # that the same as a missing type so we fall back to "void".
    try:
        checked_type = global_var.checked_type
    except ValueError:
        checked_type = None
    if checked_type is not None:
        ret_type = checked_type.ret_type
        if hasattr(ret_type, "dtype"):
            dtype = ret_type.dtype
    return Call(dtype=dtype, op=global_var, args=args)
def start_profile_intrinsic(id):
    """Mark the start of a profiled intrinsic region.

    Parameters
    ----------
    id : int
        The intrinsic id.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    op_name = "tir.start_profile_intrinsic"
    return call_intrin("handle", op_name, id)
def end_profile_intrinsic(id):
    """Mark the end of a profiled intrinsic region.

    Parameters
    ----------
    id : int
        The intrinsic id.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    op_name = "tir.end_profile_intrinsic"
    return call_intrin("handle", op_name, id)
def tvm_tuple(*value):
    """Build a tuple structure for the value field of an AttrStmt.

    Parameters
    ----------
    value : Expr
        The value in tuple.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    op_name = "tir.tvm_tuple"
    return call_intrin("handle", op_name, *value)
def tvm_struct_get(arr, index, field, dtype):
    """Read a struct field value from an array of structs.

    Parameters
    ----------
    arr : StructType*
        The array of struct.

    index : int
        The index of struct.

    field : int
        The field of struct.

    dtype : str
        The data type of the result.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    location = (arr, index, field)
    return call_intrin(dtype, "tir.tvm_struct_get", *location)
def tvm_struct_set(arr, index, field, value):
    """Write a value into a struct field of an array of structs.

    Parameters
    ----------
    arr : StructType*
        The array of struct.

    index : int
        The index of struct.

    field : int
        The field of struct.

    value : Expr
        The value to be set in field.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    location = (arr, index, field)
    return call_intrin("int32", "tir.tvm_struct_set", *location, value)
def address_of(buffer_load, span=None):
    """Return the address of a buffer element.

    Parameters
    ----------
    buffer_load: BufferLoad
        The buffer load.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    op_name = "tir.address_of"
    return call_intrin("handle", op_name, buffer_load, span=span)
def lookup_param(param_name, span=None):
    """Look up a parameter by name.

    Parameters
    ----------
    param_name : str
        The name of param.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    op_name = "tir.lookup_param"
    return call_intrin("handle", op_name, param_name, span=span)
def tvm_thread_allreduce(*freduce_args):
    """Perform an allreduce inside a thread block.

    Parameters
    ----------
    freduce_args : Expr
        The args.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    op_name = "tir.tvm_thread_allreduce"
    return call_intrin("handle", op_name, *freduce_args)
def tvm_thread_invariant(cond):
    """Mark a condition as thread invariant.

    Parameters
    ----------
    cond : Expr
        The condition.

    Returns
    -------
    call : PrimExpr
        The call expression, typed like ``cond``.
    """
    assert isinstance(cond, PrimExpr)
    result_dtype = cond.dtype
    return call_intrin(result_dtype, "tir.tvm_thread_invariant", cond)


def tvm_storage_sync(storage_scope):
    """Synchronize within the given storage scope.

    Parameters
    ----------
    storage_scope : str
        The storage scope to perform synchronization.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    op_name = "tir.tvm_storage_sync"
    return call_intrin("int32", op_name, storage_scope)


def tvm_warp_shuffle(mask, value, warp_id, width, warp_size):
    """Exchange a value between threads of a warp.

    Parameters
    ----------
    mask : PrimExpr
        The warp mask indicates active threads inside warp.
    value : PrimExpr
        The value to exchange.
    warp_id : PrimExpr
        The source lane index to fetch value.
    width : PrimExpr
        The width of sub-sections to perform warp shuffle.
    warp_size : PrimExpr
        The warp size.

    Returns
    -------
    call : PrimExpr
        The call expression, typed like ``value``.
    """
    shuffle_args = (mask, value, warp_id, width, warp_size)
    return call_intrin(value.dtype, "tir.tvm_warp_shuffle", *shuffle_args)


def tvm_warp_shuffle_up(mask, value, offset, width, warp_size):
    """Copy a value from the lane ``offset`` positions below the caller.

    Parameters
    ----------
    mask : PrimExpr
        The warp mask indicates active threads inside warp.
    value : PrimExpr
        The value to exchange.
    offset : PrimExpr
        The difference between source lane index and destination lane index:
        `offset = dst_lane_idx - src_lane_idx`
    width : PrimExpr
        The width of sub-sections to perform warp shuffle.
    warp_size : PrimExpr
        The warp size.

    Returns
    -------
    call : PrimExpr
        The call expression, typed like ``value``.
    """
    shuffle_args = (mask, value, offset, width, warp_size)
    return call_intrin(value.dtype, "tir.tvm_warp_shuffle_up", *shuffle_args)


def tvm_warp_shuffle_down(mask, value, offset, width, warp_size):
    """Copy a value from the lane ``offset`` positions above the caller.

    Parameters
    ----------
    mask : PrimExpr
        The warp mask indicates active threads inside warp.
    value : PrimExpr
        The value to exchange.
    offset : PrimExpr
        The difference between source lane index and destination lane index:
        `offset = src_lane_idx - dst_lane_idx`
    width : PrimExpr
        The width of sub-sections to perform warp shuffle.
    warp_size : PrimExpr
        The warp size.

    Returns
    -------
    call : PrimExpr
        The call expression, typed like ``value``.
    """
    shuffle_args = (mask, value, offset, width, warp_size)
    return call_intrin(value.dtype, "tir.tvm_warp_shuffle_down", *shuffle_args)


def tvm_warp_activemask():
    """Return a 32-bit mask of the threads currently active in the calling warp.

    Returns
    -------
    call : PrimExpr
        A ``uint32``-typed call expression.
    """
    op_name = "tir.tvm_warp_activemask"
    return call_intrin("uint32", op_name)
def type_annotation(dtype):
    """Build a type annotation expression.

    Parameters
    ----------
    dtype : Expr
        The data type.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    op_name = "tir.type_annotation"
    return call_intrin(dtype, op_name)
def tvm_access_ptr(ptype, data, offset, extent, rw_mask):
    """Get the head access address annotated with memory access pattern info.

    Parameters
    ----------
    ptype : Expr
        The data type of pointer.

    data : DType*
        The data of pointer.

    offset : int
        The offset of pointer.

    extent : int
        The extent of pointer.

    rw_mask : int
        The read write mask.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    access_args = (ptype, data, offset, extent, rw_mask)
    return call_intrin("handle", "tir.tvm_access_ptr", *access_args)
def tvm_throw_last_error():
    """Throw ``TVMGetLastError()``.

    Returns
    -------
    ret : PrimExpr
        The return expression
    """
    op_name = "tir.tvm_throw_last_error"
    return call_intrin("handle", op_name)
def tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout):
    """TVM intrinsic for tensor core load operators.

    Parameters
    ----------
    fragment : Var
        The wmma fragment.

    m : UIntImm
        The shape of wmma fragment.

    n : UIntImm
        The shape of wmma fragment.

    k : UIntImm
        The shape of wmma fragment.

    index : Expr
        The fragment index.

    buffer_ptr : Expr
        The fragment buffer pointer.

    stride : Expr
        The fragment stride.

    layout : Literal["row_major", "column_major"]
        The fragment layout.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    fragment_shape = (m, n, k)
    return call_intrin(
        "handle",
        "tir.tvm_load_matrix_sync",
        fragment,
        *fragment_shape,
        index,
        buffer_ptr,
        stride,
        layout,
    )
def tvm_mma_sync(
    fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c
):
    """TVM intrinsic for tensor core mma_sync operators.

    Parameters
    ----------
    fragment_d : Var
        The wmma fragment_d.

    index_d : Expr
        The fragment_d index.

    fragment_a : Var
        The wmma fragment_a.

    index_a : Expr
        The fragment_a index.

    fragment_b : Var
        The wmma fragment_b.

    index_b : Expr
        The fragment_b index.

    fragment_c : Var
        The wmma fragment_c.

    index_c : Expr
        The fragment_c index.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    # (fragment, index) pairs in D, A, B, C order.
    operands = (
        fragment_d, index_d,
        fragment_a, index_a,
        fragment_b, index_b,
        fragment_c, index_c,
    )
    return call_intrin("handle", "tir.tvm_mma_sync", *operands)
def tvm_bmma_sync(
    fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c
):
    """TVM intrinsic for tensor core bmma_sync operators.

    Parameters
    ----------
    fragment_d : Var
        The bwmma fragment_d.

    index_d : Expr
        The fragment_d index.

    fragment_a : Var
        The bwmma fragment_a.

    index_a : Expr
        The fragment_a index.

    fragment_b : Var
        The bwmma fragment_b.

    index_b : Expr
        The fragment_b index.

    fragment_c : Var
        The bwmma fragment_c.

    index_c : Expr
        The fragment_c index.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    # (fragment, index) pairs in D, A, B, C order.
    operands = (
        fragment_d, index_d,
        fragment_a, index_a,
        fragment_b, index_b,
        fragment_c, index_c,
    )
    return call_intrin("handle", "tir.tvm_bmma_sync", *operands)
def tvm_fill_fragment(fragment, m, n, k, index, value):
    """TVM intrinsic for tensor core fill_fragment operators.

    Parameters
    ----------
    fragment : Var
        The wmma fragment

    m : UIntImm
        The shape of wmma fragment.

    n : UIntImm
        The shape of wmma fragment.

    k : UIntImm
        The shape of wmma fragment.

    index : Expr
        The fragment index.

    value : Expr
        The value to be filled in fragment.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    fragment_shape = (m, n, k)
    return call_intrin("handle", "tir.tvm_fill_fragment", fragment, *fragment_shape, index, value)
def tvm_store_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout):
    """TVM intrinsic for tensor core store operators.

    Parameters
    ----------
    fragment : Var
        The wmma fragment.

    m : UIntImm
        The shape of wmma fragment.

    n : UIntImm
        The shape of wmma fragment.

    k : UIntImm
        The shape of wmma fragment.

    index : Expr
        The fragment index.

    buffer_ptr : Expr
        The fragment buffer pointer.

    stride : Expr
        The fragment stride.

    layout : Literal["row_major", "column_major"]
        The fragment layout.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    fragment_shape = (m, n, k)
    return call_intrin(
        "handle",
        "tir.tvm_store_matrix_sync",
        fragment,
        *fragment_shape,
        index,
        buffer_ptr,
        stride,
        layout,
    )
def ptx_mma(
    dtype,
    shape,
    A_layout,
    B_layout,
    A_dtype,
    B_dtype,
    C_dtype,
    multiplicand_a,
    a_index,
    multiplicand_b,
    b_index,
    accumulator,
    c_index,
    saturate,
    operator=None,
):
    """TVM intrinsic for ptx tensor core mma instructions
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma

    Parameters
    ----------
    dtype : str
        The data type of the result.

    shape : str
        The shape of mma fragment.

    A_layout : Literal["row", "col"]
        The layout of multiplicand fragment A.

    B_layout : Literal["row", "col"]
        The layout of multiplicand fragment B.

    A_dtype : str
        The data type of multiplicand fragment A.

    B_dtype : str
        The data type of multiplicand fragment B.

    C_dtype : str
        The data type of accumulator fragment C.

    multiplicand_a : Var
        The multiplicand fragment A variable.

    a_index : Expr
        The index of multiplicand fragment A.

    multiplicand_b : Var
        The multiplicand fragment B variable.

    b_index : Expr
        The index of multiplicand fragment B.

    accumulator : Var
        The accumulator fragment C variable.

    c_index : Expr
        The index of accumulator fragment C.

    saturate : bool
        The optional saturation at the output.

    operator : Optional[Literal["xor", "and"]]
        The 1-bit operator. When None, no operator argument is emitted.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    intrin_args = [
        shape,
        A_layout,
        B_layout,
        A_dtype,
        B_dtype,
        C_dtype,
        multiplicand_a,
        a_index,
        multiplicand_b,
        b_index,
        accumulator,
        c_index,
        saturate,
    ]
    # The 1-bit operator is an optional trailing argument; omit it entirely
    # rather than passing None.
    if operator is not None:
        intrin_args.append(operator)
    return call_intrin(dtype, "tir.ptx_mma", *intrin_args)
def ptx_mma_sp(
    dtype,
    shape,
    A_layout,
    B_layout,
    A_dtype,
    B_dtype,
    C_dtype,
    multiplicand_a,
    a_index,
    multiplicand_b,
    b_index,
    accumulator,
    c_index,
    metadata,
    meta_index,
    sparse_selector,
    saturate,
):
    """TVM intrinsic for sparse tensor core ptx instructions
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-sparse-mma

    Parameters
    ----------
    dtype : str
        The data type of the result.

    shape : str
        The shape of mma fragment.

    A_layout : Literal["row", "col"]
        The layout of multiplicand fragment A.

    B_layout : Literal["row", "col"]
        The layout of multiplicand fragment B.

    A_dtype : str
        The data type of multiplicand fragment A.

    B_dtype : str
        The data type of multiplicand fragment B.

    C_dtype : str
        The data type of accumulator fragment C.

    multiplicand_a : Var
        The multiplicand fragment A variable.

    a_index : Expr
        The index of multiplicand fragment A.

    multiplicand_b : Var
        The multiplicand fragment B variable.

    b_index : Expr
        The index of multiplicand fragment B.

    accumulator : Var
        The accumulator fragment C variable.

    c_index : Expr
        The index of accumulator fragment C.

    metadata : Expr
        The metadata of operand.

    meta_index : Expr
        The metadata index of operand.

    sparse_selector : Expr
        The sparse selector indicating the thread that stores the metadata.

    saturate : bool
        The optional saturation at the output.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    return call_intrin(
        dtype,
        "tir.ptx_mma_sp",
        shape,
        A_layout,
        B_layout,
        A_dtype,
        B_dtype,
        C_dtype,
        multiplicand_a,
        a_index,
        multiplicand_b,
        b_index,
        accumulator,
        c_index,
        metadata,
        meta_index,
        sparse_selector,
        saturate,
    )
def mma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride):
    """TVM intrinsic that stores a PTX MMA result through a destination pointer.

    Parameters
    ----------
    dtype : str
        The data type of the result.

    m : IntImm
        The shape of mma fragment.

    n : IntImm
        The shape of mma fragment.

    dst_ptr : Var
        The destination pointer variable.

    src_ptr : Var
        The source pointer variable.

    src_offset : Expr
        The source offset.

    dst_stride : Var
        The destination stride.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    store_args = (m, n, dst_ptr, src_ptr, src_offset, dst_stride)
    return call_intrin(dtype, "tir.mma_store", *store_args)
def mma_fill(dtype, local_size, local_ptr, offset):
    """TVM intrinsic that zero-initializes an MMA accumulation register.

    Parameters
    ----------
    dtype : str
        The data type of the result.

    local_size : IntImm
        The number of elements.

    local_ptr : Var
        The destination pointer variable.

    offset : Expr
        The destination offset.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    fill_args = (local_size, local_ptr, offset)
    return call_intrin(dtype, "tir.mma_fill", *fill_args)
def ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, smem_offset):
    """TVM intrinsic for ptx load matrix from shared memory
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix

    Parameters
    ----------
    dtype : str
        The data type of the result.

    trans : bool
        The matrix is loaded in column-major format.

    num : IntImm
        The number of matrices.

    type : Literal[".b16"]
        The data type of the matrices.

    local_ptr : Var
        The local pointer variable.

    local_offset : Expr
        The offset of local pointer.

    smem_ptr : Var
        The shared memory pointer variable.

    smem_offset : Expr
        The offset of shared memory pointer.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    destination = (local_ptr, local_offset)
    source = (smem_ptr, smem_offset)
    return call_intrin(dtype, "tir.ptx_ldmatrix", trans, num, type, *destination, *source)
def ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes):
    """TVM intrinsic for ptx async copy from global to shared memory using cp.async
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async

    Parameters
    ----------
    dtype : str
        The data type of the result.

    shared_ptr : Var
        The shared memory pointer variable.

    shared_offset : Expr
        The offset of shared memory pointer.

    global_ptr : Var
        The global memory pointer variable.

    global_offset : Expr
        The offset of global memory pointer.

    bytes : int
        The data size to copy.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    destination = (shared_ptr, shared_offset)
    source = (global_ptr, global_offset)
    return call_intrin(dtype, "tir.ptx_cp_async", *destination, *source, bytes)
def ptx_cp_async_bulk(
    dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes, barrier_id
):
    """TVM intrinsic for ptx async copy from global to shared memory using cp.async.bulk
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk

    Parameters
    ----------
    dtype : str
        The data type of the result.

    shared_ptr : Var
        The shared memory pointer variable.

    shared_offset : Expr
        The offset of shared memory pointer.

    global_ptr : Var
        The global memory pointer variable.

    global_offset : Expr
        The offset of global memory pointer.

    bytes : int
        The data size to copy.

    barrier_id : int
        The ID of the barrier shared memory pointer.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    destination = (shared_ptr, shared_offset)
    source = (global_ptr, global_offset)
    return call_intrin(
        dtype, "tir.ptx_cp_async_bulk", *destination, *source, bytes, barrier_id
    )
def ptx_commit_group():
    """TVM intrinsic for ptx async copy commit
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-commit-group

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    op_name = "tir.ptx_commit_group"
    return call_intrin("", op_name)
def ptx_wait_group(num):
    """TVM intrinsic for ptx async copy wait
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-wait-group

    Parameters
    ----------
    num : int
        The number of the most recent uncommitted pending cp.async groups to wait.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    op_name = "tir.ptx_wait_group"
    return call_intrin("", op_name, num)
def ptx_cp_async_barrier(barrier_id):
    """TVM intrinsic for ptx async copy barrier using cp.async.mbarrier.arrive
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-cp-async-mbarrier-arrive

    Parameters
    ----------
    barrier_id : int
        The ID of the barrier shared memory pointer.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    op_name = "tir.ptx_cp_async_barrier"
    return call_intrin("", op_name, barrier_id)
def ptx_init_barrier_thread_count(barrier_id, thread_count):
    """TVM intrinsic for ptx barrier initialization of thread count using mbarrier.init
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init

    Parameters
    ----------
    barrier_id : int
        The ID of the barrier shared memory pointer.

    thread_count : int
        Number of threads expected to arrive at the barrier.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    op_name = "tir.ptx_init_barrier_thread_count"
    return call_intrin("", op_name, barrier_id, thread_count)
def ptx_arrive_barrier(barrier_id):
    """TVM intrinsic for ptx barrier arrival using mbarrier.arrive
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive

    Parameters
    ----------
    barrier_id : int
        The ID of the barrier shared memory pointer.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    op_name = "tir.ptx_arrive_barrier"
    return call_intrin("", op_name, barrier_id)
def ptx_arrive_barrier_expect_tx(barrier_id, byte_count):
    """TVM intrinsic for ptx barrier arrival with expect tx using mbarrier.arrive.expect_tx
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-expect-tx-operation

    Parameters
    ----------
    barrier_id : int
        The ID of the barrier shared memory pointer.

    byte_count : int
        Increases the tx count of the mbarrier object to track completion of
        additional async transactions.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    op_name = "tir.ptx_arrive_barrier_expect_tx"
    return call_intrin("", op_name, barrier_id, byte_count)
def ptx_wait_barrier(barrier_id):
    """TVM intrinsic for ptx barrier wait using mbarrier.try_wait
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-mbarrier-try-wait

    Parameters
    ----------
    barrier_id : int
        The ID of the barrier shared memory pointer.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    op_name = "tir.ptx_wait_barrier"
    return call_intrin("", op_name, barrier_id)
def create_barriers(barrier_count):
    """TVM intrinsic that creates ``barrier_count`` barriers.

    Parameters
    ----------
    barrier_count : int
        The number of barriers to create.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    op_name = "tir.create_barriers"
    return call_intrin("", op_name, barrier_count)
def make_filled_simdgroup_matrix(
    d: Var,
    index: PrimExpr,
    value: PrimExpr,
    col: int = 8,
    row: int = 8,
):
    """Create a SIMDGroup matrix filled with a single value.

    Parameters
    ----------
    d : var
        The simdgroup var

    index : PrimExpr
        The index of the matrix.

    value : PrimExpr
        The value to fill.

    col : int
        The number of columns.

    row : int
        The number of rows.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    dims = (col, row)
    return call_intrin("handle", "tir.make_filled_simdgroup_matrix", d, index, value, *dims)
def simdgroup_load(
    d: Var,
    index: PrimExpr,
    ptr: PrimExpr,
    stride: PrimExpr,
    col: int = 8,
    row: int = 8,
    transpose_matrix: bool = False,
):
    """Load data from device or threadgroup memory into a simdgroup matrix.

    Parameters
    ----------
    d : var
        The simdgroup var

    index : PrimExpr
        The index of the matrix.

    ptr : PrimExpr
        The pointer.

    stride : PrimExpr
        The stride.

    col : int
        The number of columns.

    row : int
        The number of rows.

    transpose_matrix : bool
        Whether to transpose the matrix.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    dims = (col, row)
    return call_intrin(
        "handle", "tir.simdgroup_load", d, index, ptr, stride, *dims, transpose_matrix
    )
def simdgroup_store(
    d: PrimExpr,
    index: PrimExpr,
    ptr: PrimExpr,
    stride: PrimExpr,
    col: int = 8,
    row: int = 8,
    transpose_matrix: bool = False,
):
    """Store a simdgroup matrix to device or threadgroup memory.

    Parameters
    ----------
    d : PrimExpr
        The SIMDGroup.

    index : PrimExpr
        The index of the matrix.

    ptr : PrimExpr
        The pointer.

    stride : PrimExpr
        The stride.

    col : int
        The number of columns.

    row : int
        The number of rows.

    transpose_matrix : bool
        Whether to transpose the matrix.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    dims = (col, row)
    return call_intrin(
        "handle", "tir.simdgroup_store", d, index, ptr, stride, *dims, transpose_matrix
    )
def simdgroup_multiply_accumulate(
    d: Var,
    index_d: PrimExpr,
    a: Var,
    index_a: PrimExpr,
    b: Var,
    index_b: PrimExpr,
    c: Var,
    index_c: PrimExpr,
):
    """Multiply and accumulate simdgroup matrices, i.e. ``d = a * b + c``.

    Parameters
    ----------
    d : Var
        The destination matrix.

    index_d : PrimExpr
        The index of the destination matrix.

    a : Var
        The first matrix.

    index_a : PrimExpr
        The index of the first matrix.

    b : Var
        The second matrix.

    index_b : PrimExpr
        The index of the second matrix.

    c : Var
        The third matrix.

    index_c : PrimExpr
        The index of the third matrix.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    # (matrix, index) pairs in D, A, B, C order.
    operands = (d, index_d, a, index_a, b, index_b, c, index_c)
    return call_intrin("handle", "tir.simdgroup_multiply_accumulate", *operands)
def vectorlow(dtype, vec):
    """Take the low half of a vector.

    Parameters
    ----------
    dtype : str
        The data type of the result.

    vec : list
        The input vector.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    op_name = "tir.vectorlow"
    return call_intrin(dtype, op_name, vec)
def vectorhigh(dtype, vec):
    """Take the high half of a vector.

    Parameters
    ----------
    dtype : str
        The data type of the result.

    vec : list
        The input vector.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    op_name = "tir.vectorhigh"
    return call_intrin(dtype, op_name, vec)
def vectorcombine(dtype, vec1, vec2):
    """Concatenate two vectors.

    Parameters
    ----------
    dtype : str
        The data type of the result.

    vec1 : list
        The first input vector.

    vec2 : list
        The second input vector.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    return call_intrin(dtype, "tir.vectorcombine", vec1, vec2)
def dp4a(vec1, vec2, acc=0):
    """Dot product of two int8x4 vectors, plus an optional accumulator.

    Parameters
    ----------
    vec1 : int8x4
        The first input vector.
    vec2 : int8x4
        The second input vector.
    acc : int32
        The accumulator added to the dot product (default 0).

    Returns
    -------
    call : PrimExpr
        An ``int32``-typed call to ``tir.dp4a``.
    """
    operands = (vec1, vec2, acc)
    return call_intrin("int32", "tir.dp4a", *operands)
def ret(val, span=None):
    """Build a TIR return expression.

    Parameters
    ----------
    val : Expr
        The value to return; its data type is int, float or void pointer.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    ret : PrimExpr
        The return expression.
    """
    ret_expr = _ffi_api.ret(val, span)
    return ret_expr
def any(*args, span=None):
    """Combine boolean conditions with logical OR.

    Parameters
    ----------
    args : list
        Symbolic boolean expressions to combine.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    expr: Expr
        The disjunction of all arguments; a single argument is returned as is.

    Raises
    ------
    ValueError
        If no argument is given.
    """
    if len(args) == 0:
        raise ValueError("Any must take at least 1 argument")
    combined = args[0]
    # Left fold: ((a0 | a1) | a2) | ...
    for cond in args[1:]:
        combined = _ffi_api._OpOr(combined, cond, span)  # type: ignore
    return combined
def all(*args, span=None):
    """Combine boolean conditions with logical AND.

    Parameters
    ----------
    args : list
        Symbolic boolean expressions to combine.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    expr: Expr
        The conjunction of all arguments; a single argument is returned as is.

    Raises
    ------
    ValueError
        If no argument is given.
    """
    if not args:
        # The message previously said "Any", copy-pasted from `any`.
        raise ValueError("All must take at least 1 argument")
    if len(args) == 1:
        return args[0]
    val = _ffi_api._OpAnd(args[0], args[1], span)  # type: ignore
    for arg in args[2:]:
        val = _ffi_api._OpAnd(val, arg, span)  # type: ignore
    return val
def trace(args, trace_action="tvm.default_trace_action"):
    """Trace tensor data at runtime.

    The trace function traces specific tensors at runtime. The value being
    traced must be the last element of ``args``. A trace action can be
    specified; ``tvm.default_trace_action`` is used by default.

    Parameters
    ----------
    args : list of Expr or Buffers.
        Positional arguments. Buffers are packed into array handles.
    trace_action : str.
        The name of the trace action.

    Returns
    -------
    call : PrimExpr
        The call expression, typed like the last element of ``args``.

    Raises
    ------
    TypeError
        If ``args`` is not a list.
    ValueError
        If ``args`` is empty.

    See Also
    --------
    tvm.tir.call_packed : Creates packed function.
    """
    if not isinstance(args, list):
        # TypeError subclasses Exception, so existing `except Exception` handlers still catch it.
        raise TypeError("tvm.tir.trace consumes the args as list type")
    if not args:
        # The result dtype comes from args[-1]; fail clearly instead of with an IndexError.
        raise ValueError("tvm.tir.trace requires at least one argument to trace")
    call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args]
    call_args.insert(0, trace_action)
    return tvm.tir.Call(args[-1].dtype, Op.get("tir.tvm_call_trace_packed"), call_args)
def min_value(dtype, span=None):
    """Return the minimum representable value of a dtype.

    Parameters
    ----------
    dtype : str
        The data type.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    value : tvm.Expr
        The minimum value of dtype.
    """
    lowest = _ffi_api.min_value(dtype, span)  # type: ignore
    return lowest
def max_value(dtype: str, span: Optional[Span] = None) -> Any:
    """Return the maximum representable value of a dtype.

    Parameters
    ----------
    dtype : str
        The data type.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    value : tvm.Expr
        The maximum value of dtype.
    """
    highest = _ffi_api.max_value(dtype, span)  # type: ignore
    return highest
def infinity(dtype: str, span: Optional[Span] = None) -> Any:
    """Return the infinity value of a dtype.

    Parameters
    ----------
    dtype : str
        The data type.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    value : tvm.Expr
        The infinity value of dtype.
    """
    inf_value = _ffi_api.infinity(dtype, span)  # type: ignore
    return inf_value
def reinterpret(dtype, value, span: Optional[Span] = None) -> Any:
    """Reinterpret-cast a value to the given dtype.

    Parameters
    ----------
    dtype : str
        The target data type.
    value : PrimExpr
        The input value.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    value : tvm.Expr
        The reinterpret cast value of dtype.
    """
    return _ffi_api.reinterpret(dtype, value, span)  # type: ignore
def exp(x):
    """Compute the exponential of the input.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.exp", operand)
def exp2(x):
    """Compute 2**x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.exp2", operand)
def exp10(x):
    """Compute 10**x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.exp10", operand)
def erf(x):
    """Compute the Gauss error function of the input.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.erf", operand)
def tanh(x):
    """Compute the hyperbolic tangent of the input.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.tanh", operand)
def sigmoid(x):
    """Compute the sigmoid of the input.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.sigmoid", operand)
def log(x):
    """Compute the natural logarithm of the input.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.log", operand)
def log2(x):
    """Compute the base-2 logarithm of the input.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.log2", operand)
def log10(x):
    """Compute the base-10 logarithm of the input.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.log10", operand)
def log1p(x):
    """Compute log(x + 1).

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.log1p", operand)
def tan(x):
    """Compute the tangent of the input.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.tan", operand)
def cos(x):
    """Compute the cosine of the input.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.cos", operand)
def cosh(x):
    """Compute the hyperbolic cosine of the input.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.cosh", operand)
def acos(x):
    """Compute the arc cosine of the input.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.acos", operand)
def acosh(x):
    """Take acosh (inverse hyperbolic cosine) of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    # The docstring previously said "acos"; the intrinsic is tir.acosh.
    x = tir.convert(x)
    return call_intrin(x.dtype, "tir.acosh", x)
def sin(x):
    """Compute the sine of the input.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.sin", operand)
def sinh(x):
    """Compute the hyperbolic sine of the input.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.sinh", operand)
def asin(x):
    """Compute the arc sine of the input.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.asin", operand)
def asinh(x):
    """Compute the inverse hyperbolic sine of the input.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.asinh", operand)
def atan(x):
    """Compute the arc tangent of the input.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.atan", operand)
def atanh(x):
    """Compute the inverse hyperbolic tangent of the input.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.atanh", operand)
def sqrt(x):
    """Compute the square root of the input.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.sqrt", operand)
def rsqrt(x):
    """Compute the reciprocal of the square root of the input.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.rsqrt", operand)
def clz(x):
    """Count the leading zero bits of an integer.

    Parameters
    ----------
    x : PrimExpr
        Input 32 or 64 bit integer. The result is undefined if the input is 0.

    Returns
    -------
    y : PrimExpr
        An ``int32``-typed call to ``tir.clz``.
    """
    result_dtype = "int32"
    return call_intrin(result_dtype, "tir.clz", x)
def floor(x: PrimExprWithOp, span=None):
    """Round a floating-point input down to the nearest integer value.

    Parameters
    ----------
    x : PrimExpr
        Input argument.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    y : PrimExpr
        The result.
    """
    floored = _ffi_api.floor(x, span)  # type: ignore
    return floored
def ceil(x, span=None):
    """Round a floating-point input up to the nearest integer value.

    Parameters
    ----------
    x : PrimExpr
        Input argument.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    y : PrimExpr
        The result.
    """
    ceiled = _ffi_api.ceil(x, span)  # type: ignore
    return ceiled
def trunc(x, span=None):
    """Truncate the input toward zero.

    For a scalar x, the result is the nearest integer i that is closer to
    zero than x is.

    Parameters
    ----------
    x : PrimExpr
        Input argument.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    y : PrimExpr
        The result.
    """
    truncated = _ffi_api.trunc(x, span)  # type: ignore
    return truncated
def abs(x, span=None):
    """Compute the element-wise absolute value of the input.

    Parameters
    ----------
    x : PrimExpr
        Input argument.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    y : PrimExpr
        The result.
    """
    magnitude = _ffi_api.abs(x, span)  # type: ignore
    return magnitude
def bitwise_and(x, y, span=None):
    """Compute the bitwise AND of two values.

    Parameters
    ----------
    x : PrimExpr
        Left operand.
    y : PrimExpr
        Right operand.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    res : PrimExpr
        The result.
    """
    res = _ffi_api.bitwise_and(x, y, span)
    return res
def bitwise_not(x, span=None):
    """Compute the bitwise NOT of a value.

    Parameters
    ----------
    x : PrimExpr
        Input operand.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    res : PrimExpr
        The result.
    """
    res = _ffi_api.bitwise_not(x, span)
    return res
def bitwise_or(x, y, span=None):
    """Compute the bitwise OR of two values.

    Parameters
    ----------
    x : PrimExpr
        Left operand.
    y : PrimExpr
        Right operand.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    res : PrimExpr
        The result.
    """
    res = _ffi_api.bitwise_or(x, y, span)
    return res
def bitwise_xor(x, y, span=None):
    """Compute the bitwise XOR of two values.

    Parameters
    ----------
    x : PrimExpr
        Left operand.
    y : PrimExpr
        Right operand.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    res : PrimExpr
        The result.
    """
    res = _ffi_api.bitwise_xor(x, y, span)
    return res
def round(x, span=None):
    """Round each element to the nearest integer.

    Parameters
    ----------
    x : PrimExpr
        Input argument.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    y : PrimExpr
        The result.
    """
    rounded = _ffi_api.round(x, span)  # type: ignore
    return rounded
def nearbyint(x, span=None):
    """Round each element to the nearest integer using the current rounding mode.

    This intrinsic uses llvm.nearbyint instead of llvm.round. It is faster,
    but its results can differ from te.round: nearbyint honours the current
    rounding mode, whereas te.round (llvm.round) ignores it. For the
    differences between the two, see:
    https://en.cppreference.com/w/cpp/numeric/math/round
    https://en.cppreference.com/w/cpp/numeric/math/nearbyint

    Parameters
    ----------
    x : PrimExpr
        Input argument.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    y : PrimExpr
        The result.
    """
    rounded = _ffi_api.nearbyint(x, span)  # type: ignore
    return rounded
def nextafter(x1, x2):
    """Return the next representable floating-point value after x1 towards x2.

    Parameters
    ----------
    x1 : PrimExpr
        Starting value.
    x2 : PrimExpr
        Direction value.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted ``x1``.
    """
    lhs, rhs = tir.convert(x1), tir.convert(x2)
    return call_intrin(lhs.dtype, "tir.nextafter", lhs, rhs)  # type: ignore
def hypot(x1, x2):
    """Compute sqrt(x1**2 + x2**2) element-wise.

    Parameters
    ----------
    x1 : PrimExpr
        Input argument.
    x2 : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted ``x1``.
    """
    lhs, rhs = tir.convert(x1), tir.convert(x2)
    return call_intrin(lhs.dtype, "tir.hypot", lhs, rhs)  # type: ignore
def copysign(x1, x2):
    """Give x1 the sign of x2, element-wise.

    Parameters
    ----------
    x1 : PrimExpr
        Value providing the magnitude.
    x2 : PrimExpr
        Value providing the sign.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted ``x1``.
    """
    lhs, rhs = tir.convert(x1), tir.convert(x2)
    return call_intrin(lhs.dtype, "tir.copysign", lhs, rhs)  # type: ignore
def likely(cond, span=None):
    """Mark a condition as likely to be true.

    Parameters
    ----------
    cond : PrimExpr
        Input argument.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    y : PrimExpr
        The marked expression.
    """
    marked = _ffi_api.likely(cond, span)  # type: ignore
    return marked
def isnan(x, span=None):
    """Check whether the input value is NaN.

    Parameters
    ----------
    x : PrimExpr
        Input argument.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    y : PrimExpr
        The result.
    """
    check = _ffi_api.isnan(x, span)  # type: ignore
    return check
def isnullptr(x, span=None):
    """Check whether the input value is a null pointer.

    Parameters
    ----------
    x : PrimExpr
        Input argument.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    y : PrimExpr
        A ``bool``-typed call to ``tir.isnullptr``.
    """
    check = call_intrin("bool", "tir.isnullptr", x, span=span)  # type: ignore
    return check
def isfinite(x, span=None):
    """Check whether the input value is finite.

    Parameters
    ----------
    x : PrimExpr
        Input argument.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    y : PrimExpr
        The result.
    """
    check = _ffi_api.isfinite(x, span)  # type: ignore
    return check
def isinf(x, span=None):
    """Check whether the input value is infinite.

    Parameters
    ----------
    x : PrimExpr
        Input argument.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    y : PrimExpr
        The result.
    """
    check = _ffi_api.isinf(x, span)  # type: ignore
    return check
def power(x, y, span=None):
    """Raise x to the power y.

    Parameters
    ----------
    x : PrimExpr
        The base.
    y : PrimExpr
        The exponent.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    z : PrimExpr
        The result.
    """
    raised = _ffi_api._OpPow(x, y, span)  # type: ignore
    return raised
def pow(x, y, span=None):
    """Raise x to the power y (same operation as ``power``).

    Parameters
    ----------
    x : PrimExpr
        The base.
    y : PrimExpr
        The exponent.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    z : PrimExpr
        The result.
    """
    raised = _ffi_api._OpPow(x, y, span)  # type: ignore
    return raised
def popcount(x):
    """Count the set bits of the input.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.popcount", operand)
def q_multiply_shift(x, y, q, s):
    """Multiply two Q-numbers x and y, then right-shift the product by s.

    The mathematical expression is:

        out = round(x*y*2^-s)

    More about Q-numbers: https://en.wikipedia.org/wiki/Q_(number_format)
    Rounding is to the nearest value, with halves rounded up
    (i.e., round(x.1) = x and round(x.5) = x+1).

    Parameters
    ----------
    x : PrimExpr
        First Q-number.
    y : PrimExpr
        Second Q-number.
    q : PrimExpr
        Number of fractional bits in x and y. Needs to be > 0.
    s : PrimExpr
        Integer shift.

    Returns
    -------
    y : PrimExpr
        An ``int32``-typed call to ``tir.q_multiply_shift``.
    """
    operands = (x, y, q, s)
    return call_intrin("int32", "tir.q_multiply_shift", *operands)
def q_multiply_shift_per_axis(
    x: PrimExpr,
    y: PrimExpr,
    ls: PrimExpr,
    rs: PrimExpr,
    q: IntImm,
    is_lshift_required: IntImm,
    is_rshift_required: IntImm,
):
    """Multiply two Q-numbers x and y, with per-axis left/right shifts.

    Parameters
    ----------
    x : PrimExpr
        First Q-number.
    y : PrimExpr
        Second Q-number.
    ls : PrimExpr
        Integer left shift.
    rs : PrimExpr
        Integer right shift.
    q : IntImm
        Number of fractional bits in x and y. Needs to be > 0.
    is_lshift_required : IntImm
        Whether the left shift is applied.
    is_rshift_required : IntImm
        Whether the right shift is applied.

    Returns
    -------
    z : PrimExpr
        An ``int32``-typed call to ``tir.q_multiply_shift_per_axis``.
    """
    operands = [x, y, ls, rs, q, is_lshift_required, is_rshift_required]
    return call_intrin("int32", "tir.q_multiply_shift_per_axis", *operands)
def shift_left(x, y, span=None):
    """Shift x left by y bits.

    Parameters
    ----------
    x : PrimExpr
        Value to shift.
    y : PrimExpr
        Number of bits to shift by.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    z : PrimExpr
        The result.
    """
    shifted = _ffi_api.left_shift(x, y, span)
    return shifted
def shift_right(x, y, span=None):
    """Shift x right by y bits.

    Parameters
    ----------
    x : PrimExpr
        Value to shift.
    y : PrimExpr
        Number of bits to shift by.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    z : PrimExpr
        The result.
    """
    shifted = _ffi_api.right_shift(x, y, span)
    return shifted
def fmod(x, y):
    """Compute the remainder of x divided by y, carrying the sign of x.

    Parameters
    ----------
    x : PrimExpr
        The dividend.
    y : PrimExpr
        The divisor.

    Returns
    -------
    z : PrimExpr
        The result, with the dtype of the converted ``x``.
    """
    dividend, divisor = tir.convert(x), tir.convert(y)
    return call_intrin(dividend.dtype, "tir.fmod", dividend, divisor)
def if_then_else(cond, t, f, span=None):
    """Build a conditional selection expression.

    Parameters
    ----------
    cond : PrimExpr
        The condition.
    t : PrimExpr
        The result expression if cond is true.
    f : PrimExpr
        The result expression if cond is false.
    span : Optional[Span]
        The location of this operator in the source.

    Returns
    -------
    result : Node
        The result of the conditional expression.

    Note
    ----
    Unlike Select, if_then_else does not execute the branch whose condition
    is not satisfied, so it can guard against out-of-bounds access.
    Also unlike Select, if_then_else cannot be vectorized when lanes of the
    vector have different conditions.
    """
    selected = _ffi_api._OpIfThenElse(cond, t, f, span)  # type: ignore
    return selected
def div(a, b, span=None):
    """Compute a / b as in C/C++ semantics.

    Parameters
    ----------
    a : PrimExpr
        The left hand operand.
    b : PrimExpr
        The right hand operand.
    span : Optional[Span]
        The location of this operator in the source.

    Returns
    -------
    res : PrimExpr
        The result expression.

    Note
    ----
    When operands are integers, returns truncdiv(a, b, span), which rounds
    toward zero. Unlike indexdiv, the operands need not be non-negative.
    """
    # The operand docs previously said "known to be non-negative", copied from indexdiv.
    return _ffi_api._OpDiv(a, b, span)  # type: ignore
def indexdiv(a, b, span=None):
    """Compute floor(a / b) for non-negative a and b.

    Parameters
    ----------
    a : PrimExpr
        The left hand operand, known to be non-negative.
    b : PrimExpr
        The right hand operand, known to be non-negative.
    span : Optional[Span]
        The location of this operator in the source.

    Returns
    -------
    res : PrimExpr
        The result expression.

    Note
    ----
    Use this function to split non-negative indices. It may exploit the
    non-negativity of its operands.
    """
    quotient = _ffi_api._OpIndexDiv(a, b, span)  # type: ignore
    return quotient
def indexmod(a, b, span=None):
    """Compute the remainder matching indexdiv, for non-negative a and b.

    Parameters
    ----------
    a : PrimExpr
        The left hand operand, known to be non-negative.
    b : PrimExpr
        The right hand operand, known to be non-negative.
    span : Optional[Span]
        The location of this operator in the source.

    Returns
    -------
    res : PrimExpr
        The result expression.

    Note
    ----
    Use this function to split non-negative indices. It may exploit the
    non-negativity of its operands.
    """
    remainder = _ffi_api._OpIndexMod(a, b, span)  # type: ignore
    return remainder
def truncdiv(a, b, span=None):
    """Compute the truncating division of two expressions.

    Parameters
    ----------
    a : PrimExpr
        The left hand operand.
    b : PrimExpr
        The right hand operand.
    span : Optional[Span]
        The location of this operator in the source.

    Returns
    -------
    res : PrimExpr
        The result expression.

    Note
    ----
    This is the default integer division behavior in C.
    """
    quotient = _ffi_api._OpTruncDiv(a, b, span)  # type: ignore
    return quotient
def truncmod(a, b, span=None):
    """Compute the remainder matching truncdiv of two expressions.

    Parameters
    ----------
    a : PrimExpr
        The left hand operand.
    b : PrimExpr
        The right hand operand.
    span : Optional[Span]
        The location of this operator in the source.

    Returns
    -------
    res : PrimExpr
        The result expression.

    Note
    ----
    This is the default integer remainder (``%``) behavior in C.
    """
    remainder = _ffi_api._OpTruncMod(a, b, span)  # type: ignore
    return remainder
def floordiv(a, b, span=None):
    """Compute the floor division of two expressions.

    Parameters
    ----------
    a : PrimExpr
        The left hand operand.
    b : PrimExpr
        The right hand operand.
    span : Optional[Span]
        The location of this operator in the source.

    Returns
    -------
    res : PrimExpr
        The result expression.
    """
    quotient = _ffi_api._OpFloorDiv(a, b, span)  # type: ignore
    return quotient
def floormod(a, b, span=None):
    """Compute the remainder matching floordiv of two expressions.

    Parameters
    ----------
    a : PrimExpr
        The left hand operand.
    b : PrimExpr
        The right hand operand.
    span : Optional[Span]
        The location of this operator in the source.

    Returns
    -------
    res : PrimExpr
        The result expression.
    """
    remainder = _ffi_api._OpFloorMod(a, b, span)  # type: ignore
    return remainder
def ceildiv(lhs, rhs, span=None):
    """Generic ceiling-division operator.

    Parameters
    ----------
    lhs : object
        The left operand.
    rhs : object
        The right operand.
    span : Optional[Span]
        The location of this operator in the source.

    Returns
    -------
    op : tvm.Expr
        The result Expr of the ceildiv operation.
    """
    quotient = _ffi_api._OpCeilDiv(lhs, rhs, span)  # type: ignore
    return quotient
def comm_reducer(fcombine, fidentity, name="reduce"):
    """Create a commutative reducer for reduction.

    Parameters
    ----------
    fcombine : function(Expr -> Expr -> Expr)
        A binary function which takes two Expr as input to return a Expr.
        Its two parameter names are reused to name the reducer's variables.

    fidentity : function(str -> Expr)
        A function which takes a type string as input to return a const Expr.

    name : str
        Name substituted into the docstring of the returned reducer.

    Returns
    -------
    reducer : function
        A function which creates a reduce expression over axis.
        It accepts ``(expr, axis, where=None, init=None, *args)``.
        There are two ways to use it:

        1. accept (expr, axis, where) to produce an Reduce Expr on
           specified axis;
        2. simply use it with multiple Exprs.

    Example
    -------
    .. code-block:: python

        n = te.var("n")
        m = te.var("m")
        mysum = te.comm_reducer(lambda x, y: x+y,
            lambda t: tvm.tir.const(0, dtype=t), name="mysum")
        A = te.placeholder((n, m), name="A")
        k = te.reduce_axis((0, m), name="k")
        B = te.compute((n,), lambda i: mysum(A[i, k], axis=k), name="B")
    """

    def _reduce_directly(*args):
        # Left-fold fcombine over the positional values (mode 2).
        num = len(args)
        # process `where` is None: a trailing None third argument is dropped
        if num == 3 and args[2] is None:
            num = 2
        res = args[0]
        for i in range(num - 1):
            res = fcombine(res, args[i + 1])
        return res

    def _make_reduce(expr, axis, where=None, init=None):
        # Build a CommReducer from fcombine/fidentity and emit one Reduce node
        # per output value (mode 1).
        code = fcombine.__code__
        assert fcombine.__code__.co_argcount == 2
        expr = tir.convert(expr)
        if init is not None:
            init = tir.convert(init)
        if isinstance(expr, Array):
            # Tuple reduction: one (lhs, rhs) Var pair per element, named
            # "<param>_<i>" after fcombine's parameter names.
            size = len(expr)
            lhs = []
            rhs = []
            dtypes = []
            for i in range(size):
                dtype = expr[i].dtype
                dtypes.append(dtype)
                lname = code.co_varnames[0] + "_" + str(i)
                lhs.append(Var(lname, dtype))
                rname = code.co_varnames[1] + "_" + str(i)
                rhs.append(Var(rname, dtype))
            if init is None:
                init = []
            result = fcombine(lhs, rhs)
            id_elem = fidentity(*dtypes)
        else:
            # Scalar reduction: wrap everything in single-element lists so the
            # CommReducer/Reduce construction below is shared with the tuple case.
            assert isinstance(expr, tvm.ir.PrimExpr)
            size = 1
            dtype = expr.dtype
            lvar = Var(code.co_varnames[0], dtype)
            rvar = Var(code.co_varnames[1], dtype)
            result = [fcombine(lvar, rvar)]
            id_elem = [fidentity(dtype)]
            lhs = [lvar]
            rhs = [rvar]
            expr = [expr]
            if init is not None:
                init = [init]
        combiner = CommReducer(lhs, rhs, result, id_elem)
        if not isinstance(axis, (list, tuple, tvm.ir.Array)):
            axis = [axis]
        if where is None:
            # No predicate: reduce over every point of the axes.
            where = tir.convert(True)
        if init is None:
            outputs = tuple(tvm.tir.Reduce(combiner, expr, axis, where, i, []) for i in range(size))
        else:
            outputs = tuple(tvm.tir.Reduce(combiner, expr, axis, where, i, init) for i in range(size))
        # A scalar reduction returns its single Reduce node, not a 1-tuple.
        return outputs[0] if size == 1 else outputs

    # pylint: disable=keyword-arg-before-vararg
    def reducer(expr, axis, where=None, init=None, *args):
        # An IterVar/list/tuple axis selects mode 1 (build a Reduce expression);
        # anything else treats all arguments as values to combine directly.
        # NOTE(review): a tvm.ir.Array axis is not matched here (though
        # _make_reduce normalizes it) and falls through to _reduce_directly;
        # confirm this is intended.
        if isinstance(axis, (tvm.tir.IterVar, list, tuple)):
            assert not args
            return _make_reduce(expr, axis, where, init)
        if where is None:
            assert not args
            assert init is None
            return _reduce_directly(expr, axis)
        elif init is None:
            assert not args
            return _reduce_directly(expr, axis, where)
        else:
            return _reduce_directly(expr, axis, where, init, *args)

    doc_str = """Create a {0} expression over axis.

              Parameters
              ----------
              expr : PrimExpr
                  The source expression.
              axis : IterVar
                  The reduction IterVar axis
              where : optional, Expr
                  Filtering predicate of the reduction.
              Returns
              -------
              value : PrimExpr
                  The result value.

              Example
              -------
              .. code-block:: python

                m = te.var("m")
                n = te.var("n")
                A = te.placeholder((m, n), name="A")
                k = te.reduce_axis((0, n), name="k")

                # there are two way to use this {0} reducer:
                # mode 1, accept (expr, axis, where) to produce an Reduce Expr
                # tvm.{0} represents tvm.te.{0} or tvm.tir.{0}.
                B = te.compute((m,), lambda i: tvm.{0}(A[i, k], axis=k), name="B")

                # mode 2, simply use it with multiple Exprs:
                {0}_res = tvm.{0}(m, n)
              """
    reducer.__doc__ = doc_str.format(name)
    return reducer
def TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, dtype_bits_hint):
    """Backend function to allocate a temporary workspace.

    Parameters
    ----------
    device_type : int
        The device type on which the space will be allocated.
    device_id : int
        The device id on which the space will be allocated.
    nbytes : int
        The size of the requested space, in bytes.
    dtype_code_hint : int
        The type code of the array elements. Only used in certain backends
        such as OpenGL.
    dtype_bits_hint : int
        The type bits of the array elements. Only used in certain backends
        such as OpenGL.

    Returns
    -------
    call : PrimExpr
        A ``handle``-typed call to ``tir.TVMBackendAllocWorkspace``.
    """
    operands = (device_type, device_id, nbytes, dtype_code_hint, dtype_bits_hint)
    return call_intrin("handle", "tir.TVMBackendAllocWorkspace", *operands)
def TVMBackendFreeWorkspace(device_type, device_id, ptr):
    """Backend function to free a temporary workspace.

    Parameters
    ----------
    device_type : int
        The device type on which the space was allocated.
    device_id : int
        The device id on which the space was allocated.
    ptr : Var
        Pointer to the previously allocated space.

    Returns
    -------
    call : PrimExpr
        An ``int32``-typed call to ``tir.TVMBackendFreeWorkspace``.
    """
    operands = (device_type, device_id, ptr)
    return call_intrin("int32", "tir.TVMBackendFreeWorkspace", *operands)
def anylist_getitem(list_handle, index):
    """Return an item from an anylist.

    Parameters
    ----------
    list_handle : Var
        The handle to the anylist.
    index : int
        The index of the item.

    Returns
    -------
    call : PrimExpr
        A ``handle``-typed call to ``tir.anylist_getitem``.
    """
    return call_intrin("handle", "tir.anylist_getitem", list_handle, index)


def anylist_resetitem(list_handle, index):
    """Reset an item of an anylist.

    Parameters
    ----------
    list_handle : Var
        The handle to the anylist.
    index : int
        The index of the item.

    Returns
    -------
    call : PrimExpr
        The ``tir.anylist_resetitem`` call expression.
    """
    return call_intrin("int", "tir.anylist_resetitem", list_handle, index)


def anylist_setitem_call_packed(list_handle, index, func_name, *args):
    """Set an anylist item to the result of a packed function call.

    Parameters
    ----------
    list_handle : Var
        The handle to the anylist.
    index : int
        The index of the item.
    func_name : str
        The name of the function to be called.
    args :
        Extra arguments forwarded to the call.

    Returns
    -------
    call : PrimExpr
        The ``tir.anylist_setitem_call_packed`` call expression.
    """
    return call_intrin(
        "int", "tir.anylist_setitem_call_packed", list_handle, index, func_name, *args
    )


def anylist_setitem_call_cpacked(list_handle, index, func_name, *args):
    """Set an anylist item to the result of a cpacked function call.

    Parameters
    ----------
    list_handle : Var
        The handle to the anylist.
    index : int
        The index of the item.
    func_name : str
        The name of the function to be called.
    args :
        Extra arguments forwarded to the call.

    Returns
    -------
    call : PrimExpr
        The ``tir.anylist_setitem_call_cpacked`` call expression.
    """
    # The docstring previously said "packed call", copied from the packed variant.
    return call_intrin(
        "int", "tir.anylist_setitem_call_cpacked", list_handle, index, func_name, *args
    )
def vscale():
    """Return the target's vscale value.

    It is lowered to the llvm.vscale intrinsic
    (https://llvm.org/docs/LangRef.html#llvm-vscale-intrinsic).

    Returns
    -------
    call : PrimExpr
        An ``int32``-typed call to ``tir.vscale``.
    """
    result_dtype = "int32"
    return call_intrin(result_dtype, "tir.vscale")
def get_active_lane_mask(dtype, base, limit):
    """Compute a predicate mask from a current value (base) and an upper bound (limit).

    It is lowered to the llvm.get.active.lane.mask intrinsic
    (https://llvm.org/docs/LangRef.html#llvm-get-active-lane-mask-intrinsics).

    Parameters
    ----------
    dtype : str
        The data type of the result.
    base : PrimExpr
        An expression representing the base.
    limit : PrimExpr
        An expression representing the limit.

    Returns
    -------
    call : PrimExpr
        The ``tir.get_active_lane_mask`` call expression.
    """
    bounds = (base, limit)
    return call_intrin(dtype, "tir.get_active_lane_mask", *bounds)
def get_vscale_expr(dtype: Union[str, tvm.DataType], min_size: int = 128) -> PrimExpr:
    """Create a dtype-dependent scalable expression.

    The result is ``(min_size // bits) * vscale()``, where ``bits`` is the bit
    width of ``dtype``.

    Parameters
    ----------
    dtype : Union[str, tvm.DataType]
        Element data type; strings are converted to ``tvm.DataType``.
    min_size : int
        The minimum size of the scalable vector in bits.

    Returns
    -------
    expr : PrimExpr
        The scalable element-count expression.
    """
    data_type = tvm.DataType(dtype) if isinstance(dtype, str) else dtype
    elems_per_vscale = min_size // data_type.bits
    return elems_per_vscale * vscale()
def ignore_loop_partition(predicate) -> PrimExpr:
    """Annotate a predicate so that loop partition does not treat it as a target condition.

    Parameters
    ----------
    predicate : PrimExpr
        The predicate expression to annotate.

    Returns
    -------
    call : PrimExpr
        A ``bool``-typed call to ``tir.ignore_loop_partition``.
    """
    intrin_name = "tir.ignore_loop_partition"
    return call_intrin("bool", intrin_name, predicate)