tvm.tir
Namespace for Tensor-level IR.
- class tvm.tir.Add(a, b, span=None)
Add node.
Parameters
- a : PrimExpr
The left-hand operand.
- b : PrimExpr
The right-hand operand.
- span : Optional[Span]
The location of this expression in the source code.
- class tvm.tir.Allocate(buffer_var, dtype, extents, condition, body, annotations=None, span=None)
Allocate node.
Parameters
- buffer_var : Var
The buffer variable.
- dtype : str
The data type of the buffer.
- extents : list of Expr
The extents of the allocation.
- condition : PrimExpr
The condition.
- body : Stmt
The body statement.
- annotations : Optional[Mapping[str, Object]]
Additional annotation hints.
- span : Optional[Span]
The location of the stmt in the source code.
- class tvm.tir.AllocateConst(buffer_var, dtype, extents, data_or_idx, body, annotations=None, span=None)
Allocate constant node.
Parameters
- buffer_var : Var
The buffer variable.
- dtype : str
The data type of the buffer.
- extents : list of Expr
The extents of the allocation.
- data_or_idx : Union[NDArray, int]
If an NDArray, this is the constant data associated with the allocation. If an integer, this is the index into the "constants" attribute of the IRModule that contains the AllocateConst.
- body : Stmt
The body statement.
- annotations : Optional[Mapping[str, Object]]
Additional annotations about the allocation.
- span : Optional[Span]
The location of the stmt in the source code.
- class tvm.tir.And(a, b, span=None)
And node.
Parameters
- a : PrimExpr
The left-hand operand.
- b : PrimExpr
The right-hand operand.
- span : Optional[Span]
The location of this expression in the source code.
- class tvm.tir.Any(span=None)
Any node.
Parameters
- span : Optional[Span]
The location of this expression in the source code.
- class tvm.tir.AssertStmt(condition, message, body, span=None)
AssertStmt node.
Parameters
- condition : PrimExpr
The assert condition.
- message : PrimExpr
The error message.
- body : tvm.tir.Stmt
The body statement.
- span : Optional[Span]
The location of the stmt in the source code.
- class tvm.tir.AttrStmt(node, attr_key, value, body, span=None)
AttrStmt node.
Parameters
- node : Object
The node annotated by the attribute.
- attr_key : str
The attribute type key.
- value : PrimExpr
The value of the attribute.
- body : Stmt
The body statement.
- span : Optional[Span]
The location of the stmt in the source code.
- class tvm.tir.BijectiveLayout
Bijective mapping between two layouts (src-layout and dst-layout). It provides shape and index conversion in both directions.
Do not construct directly; use bijective_layout instead. See the documentation of bijective_layout for more details.
Parameters
- src_layout : str or Layout
The source layout.
- dst_layout : str or Layout
The destination layout.
See Also
bijective_layout : Declare a layout
- backward_index(index)
Given the indices of the dst-layout, infer the src indices.
Parameters
- index : Array of Expr
The indices in dst-layout.
Returns
- src_index : Array of Expr
The inferred indices in src-layout.
- backward_shape(shape)
Given the shape of the dst-layout, infer the src shape.
Parameters
- shape : Array of Expr
The shape in dst-layout.
Returns
- src_shape : Array of Expr
The inferred shape in src-layout.
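The following plain-Python model makes the shape and index conversion concrete for the NCHW ↔ NCHW16c pair. It is not TVM's API. The forward direction splits C into C // 16 and a trailing subordinate axis of size 16, and the backward direction recombines them. The sketch assumes the channel extent is divisible by the factor, and all helper names are hypothetical.

```python
# Plain-Python model of a bijective layout mapping NCHW <-> NCHW16c.
# Hypothetical helpers for illustration only; TVM's BijectiveLayout works
# on symbolic PrimExprs rather than Python ints.

FACTOR = 16  # size of the subordinate axis "c" in "NCHW16c"


def forward_shape(shape):
    """NCHW shape -> NCHW16c shape (assumes C is divisible by FACTOR)."""
    n, c, h, w = shape
    assert c % FACTOR == 0, "channel extent must be divisible by the factor"
    return (n, c // FACTOR, h, w, FACTOR)


def backward_shape(shape):
    """NCHW16c shape -> NCHW shape."""
    n, c_outer, h, w, c_inner = shape
    return (n, c_outer * c_inner, h, w)


def forward_index(index):
    """Map an NCHW index to the corresponding NCHW16c index."""
    n, c, h, w = index
    return (n, c // FACTOR, h, w, c % FACTOR)


def backward_index(index):
    """Map an NCHW16c index back to the corresponding NCHW index."""
    n, c_outer, h, w, c_inner = index
    return (n, c_outer * FACTOR + c_inner, h, w)


print(forward_shape((1, 32, 8, 8)))     # (1, 2, 8, 8, 16)
print(backward_index((0, 1, 2, 3, 5)))  # (0, 21, 2, 3)
```

The index functions are exact inverses on valid inputs. That round-trip property is what "bijective" guarantees.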
- class tvm.tir.Block(iter_vars, reads, writes, name_hint, body, init=None, alloc_buffers=None, match_buffers=None, annotations=None, span=None)
Block node.
Parameters
- iter_vars : List[IterVar]
The block variables.
- reads : List[BufferRegion]
The read buffer regions of the block.
- writes : List[BufferRegion]
The write buffer regions of the block.
- name_hint : str
The name hint of the block.
- body : Stmt
The body of the block.
- init : Optional[Stmt]
The initialization statement of a reduction block.
- alloc_buffers : Optional[List[Buffer]]
The buffers allocated in the block.
- match_buffers : Optional[List[MatchBufferRegion]]
The subregion buffer matches.
- annotations : Optional[Mapping[str, Object]]
Additional annotation hints.
- span : Optional[Span]
The location of this block in the source code.
- class tvm.tir.BlockDependenceInfo(mod)
An object that helps build and query block-level dependences using the two core objects BlockScope and StmtSRef.
The data structures exposed are:
1. sref2scope: a mapping from srefs to their corresponding BlockScope.
2. stmt2ref: a mapping from blocks to their corresponding StmtSRefs.
Note that this object does not store SRefs to loops, because its only purpose is to expose block-level dependences. As a result, the scope block (parent block) of a given block sref can be accessed directly as sref->parent.
Parameters
- mod : IRModule
- class tvm.tir.BlockRealize(iter_values, predicate, block, span=None)
BlockRealize node.
Parameters
- iter_values : List[PrimExpr]
The binding values of the block vars.
- predicate : Union[PrimExpr, bool]
The predicate of the block.
- block : Block
The block to realize.
- span : Optional[Span]
The location of this block_realize in the source code.
- class tvm.tir.Broadcast(value, lanes, span=None)
Broadcast node.
Parameters
- value : PrimExpr
The scalar value to broadcast.
- lanes : PrimExpr
The number of lanes of the expression.
- span : Optional[Span]
The location of this expression in the source code.
- class tvm.tir.Buffer
Symbolic data buffer in TVM.
Buffer provides a way to represent the data layout specialization of a data structure in TVM.
Do not construct directly; use decl_buffer() instead. See the documentation of decl_buffer() for more details.
See Also
decl_buffer : Declare a buffer
- access_ptr(access_mask, ptr_type='handle', content_lanes=1, offset=0, extent=None)
Get an access pointer to the head of the buffer.
This is the recommended method for obtaining buffer data pointers when interacting with external functions.
Parameters
- access_mask : int or str
The access pattern mask. Indicates whether the access will read or write the data content.
- ptr_type : str, optional
The data type of the result pointer. Do not specify it unless you want to cast the pointer to a specific type.
- content_lanes : int, optional
The number of lanes for the data type. This value is greater than one for vector types.
- offset : Expr, optional
The offset of the pointer, used to shift the pointer by a number of elements from the buffer address.
- extent : Expr, optional
The extent of the pointer.
Examples
# Get access ptr for read
buffer.access_ptr("r")
# Get access ptr for read/write with bitmask
buffer.access_ptr(Buffer.READ | Buffer.WRITE)
# Get access ptr for read/write with str flag
buffer.access_ptr("rw")
# Get access ptr for read with offset
buffer.access_ptr("r", offset=100)
# Get access ptr for read with extent
buffer.access_ptr("r", extent=100)
- get_flattened_buffer()
Generate a Buffer that is a flattened version of this buffer.
Returns
- flattened : Buffer
The corresponding flat buffer.
- offset_of(indices)
Determine the offset of the provided indices in the flattened buffer.
Parameters
- indices : Union[PrimExpr, List[PrimExpr]]
The indices of the element in the original buffer.
Returns
- flattened_indices : List[PrimExpr]
The offset indices of the element in the flattened buffer.
- scope()
Return the storage scope associated with this buffer.
Returns
- scope : str
The storage scope associated with this buffer.
- vload(begin, dtype=None, predicate=None)
Generate an Expr that loads a value of type dtype starting at the begin index.
Parameters
- begin : Array of Expr
The beginning index, in units of Buffer.dtype.
- dtype : str
The data type to be loaded. It can be a vector type whose number of lanes is a multiple of that of Buffer.dtype.
- predicate : Optional[PrimExpr]
A vector mask of boolean values indicating which lanes of a vector are to be loaded. The number of lanes of the mask must equal the number of lanes being loaded.
Returns
- load : Expr
The corresponding load expression.
- vstore(begin, value, predicate=None)
Generate a Stmt that stores value starting at the begin index.
Parameters
- begin : Array of Expr
The beginning index, in units of Buffer.dtype.
- value : Expr
The value to be stored.
- predicate : Optional[PrimExpr]
A vector mask of boolean values indicating which lanes of a vector are to be stored. The number of lanes of the mask must equal the number of lanes in value.
Returns
- store : Stmt
The corresponding store stmt.
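The string form of access_mask in the access_ptr examples is shorthand for the bitmask form. The plain-Python sketch below assumes Buffer.READ == 1 and Buffer.WRITE == 2, which we understand to be the values used by TVM's Python Buffer class. It also models the single flattened index that offset_of would produce for a dense row-major buffer with no custom strides and no elem_offset. Both helpers are hypothetical illustrations, not TVM functions.

```python
# Hypothetical helpers modelling two Buffer behaviours in plain Python.
# Assumption: READ = 1 and WRITE = 2, mirroring Buffer.READ / Buffer.WRITE.
READ = 1
WRITE = 2


def access_mask_from_str(flags):
    """Convert an access string such as "r", "w" or "rw" into a bitmask."""
    mask = 0
    for flag in flags:
        if flag == "r":
            mask |= READ
        elif flag == "w":
            mask |= WRITE
        else:
            raise ValueError(f"unknown access flag: {flag!r}")
    return mask


def row_major_offset(shape, indices):
    """Flattened element offset of `indices` in a dense row-major buffer.

    Assumes no custom strides and no elem_offset; uses plain integers
    where TVM would build PrimExprs.
    """
    assert len(shape) == len(indices)
    offset = 0
    for extent, idx in zip(shape, indices):
        assert 0 <= idx < extent, "index out of bounds"
        offset = offset * extent + idx
    return offset


print(access_mask_from_str("rw") == READ | WRITE)  # True
print(row_major_offset((4, 8), (2, 3)))            # 19
```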
- class tvm.tir.BufferLoad(buffer, indices, predicate=None, span=None)
Buffer load node.
Parameters
- buffer : Buffer
The buffer to be loaded.
- indices : List[PrimExpr]
The buffer indices to load values from.
- span : Optional[Span]
The location of this expression in the source code.
- predicate : Optional[PrimExpr]
A vector mask of boolean values indicating which lanes of a vector are to be loaded. The number of lanes of the mask must equal the number of lanes being loaded.
- class tvm.tir.BufferRealize(buffer, bounds, condition, body, span=None)
Buffer realize node.
Parameters
- buffer : Buffer
The buffer.
- bounds : List[Range]
The bounds of the region to realize.
- condition : PrimExpr
The realize condition.
- body : Stmt
The body of the statement.
- span : Optional[Span]
The location of the stmt in the source code.
- class tvm.tir.BufferRegion(buffer, region)
BufferRegion node.
Parameters
- buffer : Buffer
The buffer of the buffer region.
- region : List[Range]
The region array of the buffer region.
- class tvm.tir.BufferStore(buffer, value, indices, predicate=None, span=None)
Buffer store node.
Parameters
- buffer : Buffer
The buffer.
- value : PrimExpr
The value to be stored.
- indices : List[PrimExpr]
The indices of the location to be stored.
- predicate : Optional[PrimExpr]
A vector mask of boolean values indicating which lanes of a vector are to be stored. The number of lanes of the mask must equal the number of lanes in value.
- span : Optional[Span]
The location of the stmt in the source code.
- class tvm.tir.Call(dtype, op, args, span=None)
Call node.
Parameters
- dtype : str
The return data type.
- op : Union[Op, str]
The function to be called, or the name of the global tvm.Op.
- args : list of Expr
The input arguments to the call.
- span : Optional[Span]
The location of this expression in the source code.
- class tvm.tir.Cast(dtype, value, span=None)
Cast expression.
Parameters
- dtype : str
The target data type.
- value : PrimExpr
The value to be cast.
- span : Optional[Span]
The location of this expression in the source code.
- class tvm.tir.CommReducer(lhs, rhs, result, identity_element, span=None)
Commutative reduce operator.
Parameters
- lhs : List[Var]
The left arguments of the reducer.
- rhs : List[Var]
The right arguments of the reducer.
- result : List[PrimExpr]
The reduction results.
- identity_element : List[PrimExpr]
The identity elements.
- span : Optional[Span]
The location of this expression in the source code.
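A CommReducer is defined by a combining expression over lhs and rhs together with an identity element. Reducing a sequence is therefore a fold that starts from the identity. The sketch below models that idea for a single value in plain Python; the CommReducerModel class is hypothetical. Because the combiner is commutative and associative, the result does not depend on the order of the inputs.

```python
from dataclasses import dataclass
from functools import reduce
from typing import Callable


@dataclass(frozen=True)
class CommReducerModel:
    """Hypothetical stand-in for a single-value tvm.tir.CommReducer."""
    combine: Callable[[float, float], float]  # result = combine(lhs, rhs)
    identity: float                           # identity_element

    def __call__(self, values):
        # Fold the inputs, starting from the identity element.
        return reduce(self.combine, values, self.identity)


sum_reducer = CommReducerModel(lambda lhs, rhs: lhs + rhs, 0.0)
max_reducer = CommReducerModel(lambda lhs, rhs: lhs if lhs > rhs else rhs, float("-inf"))

print(sum_reducer([1.0, 2.0, 3.0]))   # 6.0
print(max_reducer([4.0, -1.0, 7.0]))  # 7.0
print(max_reducer([]))                # -inf
```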
- class tvm.tir.DeclBuffer(buffer, body, span=None)
DeclBuffer node.
Parameters
- buffer : Buffer
The buffer being declared.
- body : Stmt
The body statement to be executed.
- span : Optional[Span]
The location of this DeclBuffer in the source code.
- class tvm.tir.Div(a, b, span=None)
Div node.
Parameters
- a : PrimExpr
The left-hand operand.
- b : PrimExpr
The right-hand operand.
- span : Optional[Span]
The location of this expression in the source code.
- class tvm.tir.EQ(a, b, span=None)
EQ node.
Parameters
- a : PrimExpr
The left-hand operand.
- b : PrimExpr
The right-hand operand.
- span : Optional[Span]
The location of this expression in the source code.
- class tvm.tir.Evaluate(value, span=None)
Evaluate node.
Parameters
- value : PrimExpr
The expression to be evaluated.
- span : Optional[Span]
The location of the stmt in the source code.
- class tvm.tir.FloatImm(dtype, value, span=None)
Float constant.
Parameters
- dtype : str
The data type.
- value : float
The constant value.
- span : Optional[Span]
The location of this expression in the source code.
- class tvm.tir.FloorDiv(a, b, span=None)
FloorDiv node.
Parameters
- a : PrimExpr
The left-hand operand.
- b : PrimExpr
The right-hand operand.
- span : Optional[Span]
The location of this expression in the source code.
- class tvm.tir.FloorMod(a, b, span=None)
FloorMod node.
Parameters
- a : PrimExpr
The left-hand operand.
- b : PrimExpr
The right-hand operand.
- span : Optional[Span]
The location of this expression in the source code.
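Div/Mod and FloorDiv/FloorMod differ only when an operand is negative. As we understand TVM's C++ sources, Div and Mod follow C semantics for integers and round toward zero. FloorDiv and FloorMod round toward negative infinity. The plain-Python sketch below models both pairs for integers. Python's own `//` and `%` already use floor semantics, so the truncating pair must be written by hand.

```python
def truncdiv(a: int, b: int) -> int:
    """Integer division rounding toward zero (C semantics, like tir.Div)."""
    q = abs(a) // abs(b)
    return q if (a >= 0) == (b > 0) else -q


def truncmod(a: int, b: int) -> int:
    """Remainder matching truncdiv: a == truncdiv(a, b) * b + truncmod(a, b)."""
    return a - truncdiv(a, b) * b


def floordiv(a: int, b: int) -> int:
    """Integer division rounding toward negative infinity (like tir.FloorDiv)."""
    return a // b


def floormod(a: int, b: int) -> int:
    """Remainder with the sign of the divisor (like tir.FloorMod)."""
    return a % b


for a, b in [(7, 2), (-7, 2), (7, -2)]:
    print(a, b, truncdiv(a, b), truncmod(a, b), floordiv(a, b), floormod(a, b))
```

For -7 and 2, the truncating pair gives -3 and -1, while the floor pair gives -4 and 1. The floor form keeps the remainder non-negative for positive divisors, which is why it is convenient for index arithmetic.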
- class tvm.tir.For(loop_var, min, extent, kind, body, thread_binding=None, annotations=None, span=None)
For node.
Parameters
- loop_var : Var
The loop variable.
- min : PrimExpr
The beginning value.
- extent : PrimExpr
The length of the loop.
- kind : ForKind
The kind of the loop.
- body : Stmt
The body statement.
- thread_binding : Optional[tir.IterVar]
The thread this loop binds to. Only valid if kind is ThreadBinding.
- annotations : Optional[Mapping[str, Object]]
Additional annotation hints.
- span : Optional[Span]
The location of the stmt in the source code.
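The min and extent parameters describe a half-open range. The loop variable takes the values min, min + 1, ..., min + extent - 1. The hypothetical helper below, which is not part of TVM, enumerates those values for a serial loop. Other ForKind values change how the iterations are executed, not which values are visited.

```python
def for_loop_values(min_value: int, extent: int) -> list:
    """Values taken by a serial For loop var: the half-open range [min, min + extent)."""
    # range() is already empty when extent <= 0, so the body would never run.
    return list(range(min_value, min_value + extent))


print(for_loop_values(2, 4))  # [2, 3, 4, 5]
```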
- class tvm.tir.ForKind(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)
The kind of the for loop.
Note
ForKind can change the control-flow semantics of the loop and must be considered in all TIR passes.
- class tvm.tir.GE(a, b, span=None)
GE node.
Parameters
- a : PrimExpr
The left-hand operand.
- b : PrimExpr
The right-hand operand.
- span : Optional[Span]
The location of this expression in the source code.
- class tvm.tir.GT(a, b, span=None)
GT node.
Parameters
- a : PrimExpr
The left-hand operand.
- b : PrimExpr
The right-hand operand.
- span : Optional[Span]
The location of this expression in the source code.
- class tvm.tir.IfThenElse(condition, then_case, else_case, span=None)
IfThenElse node.
Parameters
- condition : PrimExpr
The condition expression.
- then_case : Stmt
The statement to execute if condition is true.
- else_case : Optional[Stmt]
The statement to execute if condition is false.
- span : Optional[Span]
The location of the stmt in the source code.
- class tvm.tir.IndexMap(initial_indices, final_indices, inverse_index_map)
A mapping from multi-dimensional indices to another set of multi-dimensional indices.
Parameters
- initial_indices : List[Var]
Variables representing the indices prior to remapping.
- final_indices : List[PrimExpr]
Expressions defining the indices after remapping.
- inverse_index_map : Union[Callable, Optional[IndexMap]]
The optional pre-defined inverse index map. When this is defined, IndexMap::Inverse will return the pre-defined inverse index map. Otherwise, the inverse index map will be computed on the fly. It is the user's responsibility to ensure the correctness of the pre-defined inverse index map.
- static from_func(mapping_function, ndim=None, inverse_index_map=None, *, index_dtype='int64')
Create an index map from a function.
Parameters
- mapping_function : Callable
The function to map from source indices to target indices. The function should accept tir.Var parameters and return either a tir.PrimExpr or a list of tir.PrimExpr. Returning a tir.PrimExpr is equivalent to returning a list of length 1 containing that tir.PrimExpr.
- ndim : Optional[int]
The dimensionality of the buffer to which this transformation should be applied. If mapping_function uses the variadic argument *args, ndim must be specified. If mapping_function does not use variadic arguments, ndim is optional.
- inverse_index_map : Union[Callable, Optional[IndexMap]]
The optional pre-defined inverse index map. When this is defined, IndexMap::Inverse will return the pre-defined inverse index map. Otherwise, the inverse index map will be computed on the fly. It is the user's responsibility to ensure the correctness of the pre-defined inverse index map.
Returns
- index_map : IndexMap
An IndexMap representing the mapping_function.
- static from_func_with_separators(mapping_function, ndim=None, inverse_index_map=None, *, index_dtype='int64')
Create an index map from a function that may also return axis separators.
Parameters
- mapping_function : Callable
The function to map from source indices to target indices. The function should accept tir.Var parameters and return either a tir.PrimExpr or a list. Each element of the returned list should be either a tir.PrimExpr or the object IndexMap.AXIS_SEPARATOR. Returning a tir.PrimExpr is equivalent to returning a list of length 1 containing that tir.PrimExpr.
- ndim : Optional[int]
The dimensionality of the buffer to which this transformation should be applied. If mapping_function uses the variadic argument *args, ndim must be specified. If mapping_function does not use variadic arguments, ndim is optional.
- inverse_index_map : Union[Callable, Optional[IndexMap]]
The optional pre-defined inverse index map. When this is defined, IndexMap::Inverse will return the pre-defined inverse index map. Otherwise, the inverse index map will be computed on the fly. It is the user's responsibility to ensure the correctness of the pre-defined inverse index map.
- index_dtype : str
The default index dtype to use for input iters in the mapping function.
Returns
- ret : Tuple[IndexMap, List[int]]
A tuple whose first element is an IndexMap representing the mapping_function and whose second element is a list of the indices at which IndexMap.AXIS_SEPARATOR occurred.
- inverse(shape)
Return the inverse of the map.
Throws an error if the function is not bijective.
Parameters
- shape : List[Union[Range, PrimExpr]]
The region over which the inverse should be determined. Used to validate that the mapping is bijective over this range.
Returns
- inverse : IndexMap
The inverse.
- is_equivalent_to(other_map)
Return whether the index maps are equivalent.
Parameters
- other_map : IndexMap
The IndexMap to compare against.
Returns
- is_equivalent : bool
True if the two mappings represent the same transformation, otherwise False.
- map_indices(indices)
Apply the index map to a set of indices.
Parameters
- indices : List[PrimExpr]
The indices to be mapped.
Returns
- result : List[PrimExpr]
The mapped indices.
- map_ndarray(arr_src)
Apply this index map to transform the layout of the input NDArray.
Parameters
- arr_src : runtime.NDArray
The NDArray to be transformed.
Returns
- arr_dst : runtime.NDArray
The transformed NDArray.
- map_shape(shape)
Apply the index map to a buffer shape.
Parameters
- shape : List[PrimExpr]
The buffer shape to be mapped.
Returns
- result : List[PrimExpr]
The mapped shape.
- non_surjective_inverse(shape)
Return the inverse of the map.
Can be applied to transformations that introduce padding.
Parameters
- shape : List[Union[Range, PrimExpr]]
The region over which the inverse should be determined. Used to determine the predicate.
Returns
- result : Tuple[IndexMap, PrimExpr]
The inverse, and a predicate for which the inverse maps to a valid index in the input range.
Examples
index_map = IndexMap.from_func(lambda i: [i // 4, i % 4])
inverse_map, predicate = index_map.non_surjective_inverse([14])
assert inverse_map.is_equivalent_to(IndexMap.from_func(lambda j, k: [4 * j + k]))
print(predicate)  # Prints "(axis0==3) && (axis2 >= 2)"
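In the non_surjective_inverse example, i -> [i // 4, i % 4] maps a length-14 axis onto a 4 x 4 grid. Because 4 * 4 = 16 > 14, two grid positions have no preimage. The plain-Python sketch below needs no TVM install. It brute-forces the inverse j, k -> 4 * j + k over the grid and lists the positions that fall outside [0, 14). Those positions are exactly the ones with j == 3 and k >= 2, which matches the condition printed in the example up to variable naming.

```python
N = 14       # extent of the original axis
FACTOR = 4   # split factor used by the mapping i -> [i // 4, i % 4]


def forward(i):
    return (i // FACTOR, i % FACTOR)


def inverse(j, k):
    return FACTOR * j + k


# Shape of the transformed buffer: ceil(N / FACTOR) x FACTOR.
outer = (N + FACTOR - 1) // FACTOR
grid = [(j, k) for j in range(outer) for k in range(FACTOR)]

# The inverse undoes the forward map on every valid input index.
assert all(inverse(*forward(i)) == i for i in range(N))

# Grid positions whose inverse lands outside [0, N) are padding.
padding = [(j, k) for (j, k) in grid if not 0 <= inverse(j, k) < N]
print(padding)  # [(3, 2), (3, 3)]
```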
- class tvm.tir.IntImm(dtype, value, span=None)
Int constant.
Parameters
- dtype : str
The data type.
- value : int
The constant value.
- span : Optional[Span]
The location of this expression in the source code.
- class tvm.tir.IterVar(dom, var, iter_type, thread_tag='', span=None)
Represents an iteration variable.
IterVar represents axis iterations in the computation.
Parameters
- dom : Range
The domain of the iteration.
- var : Union[Var, str]
The internal variable used for iteration.
- iter_type : int
The iteration type.
- thread_tag : str
The thread type tag.
- span : Optional[Span]
The location of this expression in the source code.
See Also
te.thread_axis : Create a thread axis IterVar.
te.reduce_axis : Create a reduce axis IterVar.
- class tvm.tir.LE(a, b, span=None)
LE node.
Parameters
- a : PrimExpr
The left-hand operand.
- b : PrimExpr
The right-hand operand.
- span : Optional[Span]
The location of this expression in the source code.
- class tvm.tir.LT(a, b, span=None)
LT node.
Parameters
- a : PrimExpr
The left-hand operand.
- b : PrimExpr
The right-hand operand.
- span : Optional[Span]
The location of this expression in the source code.
- class tvm.tir.Layout
A layout is composed of upper-case letters, lower-case letters and numbers. An upper-case letter indicates a primal axis, and the corresponding lower-case letter, preceded by its factor size, indicates a subordinate axis. For example, NCHW16c describes a 5-D tensor of [batch_size, channel, height, width, channel_block]. Here the subordinate axis channel_block=16 is the factor size of the primal axis C (channel).
See Also
layout : Declare a layout
- factor_of(axis)
Get the factor size of the subordinate axis.
Parameters
- axis : str
The axis name; must be in [a-z, A-Z].
Returns
- factor : int
The size of the subordinate axis of axis (if axis is a primal axis), or the size of axis itself (if axis is a subordinate axis). Returns -1 if axis is not in the layout.
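The Layout grammar can be illustrated with a small plain-Python parser. The parse_layout and factor_of helpers below are hypothetical and only model the documented behaviour of Layout.factor_of. The sketch assumes that a primal axis without a subordinate axis (such as N in NCHW16c) also yields -1. That is our reading of the documented rule and is worth verifying against TVM.

```python
import re


def parse_layout(layout: str) -> dict:
    """Map each axis letter in a layout string to its size (None if unsized).

    Primal axes are upper case; a subordinate axis is a lower-case letter
    preceded by its factor, e.g. "16c" in "NCHW16c".
    """
    axes = {}
    for factor, name in re.findall(r"(\d*)([A-Za-z])", layout):
        if name.islower() and not factor:
            raise ValueError(f"subordinate axis {name!r} needs a factor")
        axes[name] = int(factor) if factor else None
    return axes


def factor_of(layout: str, axis: str) -> int:
    """Model of Layout.factor_of: the subordinate-axis size, or -1 if absent."""
    axes = parse_layout(layout)
    sub = axis.lower()  # both "C" and "c" refer to the subordinate axis "c"
    if axis not in axes or sub not in axes:
        return -1
    return axes[sub]


print(factor_of("NCHW16c", "C"))  # 16
print(factor_of("NCHW16c", "c"))  # 16
print(factor_of("NCHW16c", "N"))  # -1 (no subordinate axis "n")
print(factor_of("NCHW16c", "D"))  # -1 (axis not in layout)
```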
- class tvm.tir.Let(var, value, body, span=None)
Let node.
Parameters
- var : Var
The variable in the binding.
- value : PrimExpr
The value to be bound.
- body : PrimExpr
The body expression.
- span : Optional[Span]
The location of this expression in the source code.
- class tvm.tir.LetStmt(var, value, body, span=None)
LetStmt node.
Parameters
- var : Var
The variable in the binding.
- value : PrimExpr
The value to be bound.
- body : Stmt
The body statement.
- span : Optional[Span]
The location of the stmt in the source code.
- class tvm.tir.MatchBufferRegion(buffer, source)
MatchBufferRegion node.
Parameters
- buffer : Buffer
The target buffer.
- source : BufferRegion
The region of the source buffer.
- class tvm.tir.Max(a, b, span=None)
Max node.
Parameters
- a : PrimExpr
The left-hand operand.
- b : PrimExpr
The right-hand operand.
- span : Optional[Span]
The location of this expression in the source code.
- class tvm.tir.Min(a, b, span=None)
Min node.
Parameters
- a : PrimExpr
The left-hand operand.
- b : PrimExpr
The right-hand operand.
- span : Optional[Span]
The location of this expression in the source code.
- class tvm.tir.Mod(a, b, span=None)
Mod node.
Parameters
- a : PrimExpr
The left-hand operand.
- b : PrimExpr
The right-hand operand.
- span : Optional[Span]
The location of this expression in the source code.
- class tvm.tir.Mul(a, b, span=None)
Mul node.
Parameters
- a : PrimExpr
The left-hand operand.
- b : PrimExpr
The right-hand operand.
- span : Optional[Span]
The location of this expression in the source code.
- class tvm.tir.NE(a, b, span=None)
NE node.
Parameters
- a : PrimExpr
The left-hand operand.
- b : PrimExpr
The right-hand operand.
- span : Optional[Span]
The location of this expression in the source code.
- class tvm.tir.Not(a, span=None)
Not node.
Parameters
- a : PrimExpr
The input value.
- span : Optional[Span]
The location of this expression in the source code.
- class tvm.tir.Or(a, b, span=None)
Or node.
Parameters
- a : PrimExpr
The left-hand operand.
- b : PrimExpr
The right-hand operand.
- span : Optional[Span]
The location of this expression in the source code.
- class tvm.tir.Prefetch(buffer, bounds, span=None)
Prefetch node.
Parameters
- buffer : Buffer
The buffer to be prefetched.
- bounds : List[Range]
The bounds to be prefetched.
- span : Optional[Span]
The location of the stmt in the source code.
- class tvm.tir.PrimFunc(params, body, ret_type=None, buffer_map=None, attrs=None, span=None)
A function declaration expression.
Parameters
- params : List[Union[tvm.tir.Var, tvm.tir.Buffer]]
List of input parameters to the function.
- body : tvm.tir.Stmt
The body of the function.
- ret_type : tvm.ir.Type
The return type annotation of the function.
- buffer_map : Map[tvm.tir.Var, tvm.tir.Buffer]
The buffer binding map.
- attrs : Optional[tvm.Attrs]
Attributes of the function; can be None.
- span : Optional[Span]
The location of this function in the source code.
- specialize(param_map)
Specialize parameters of a PrimFunc.
Parameters
- param_map : Mapping[Var, Union[PrimExpr, Buffer]]
The mapping from function parameters to their instances.
Examples
We can define a Meta TIR function with a symbolic shape:
@T.prim_func
def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32) -> None:
    A = T.match_buffer(a, (m, n), "float32")
    B = T.match_buffer(b, (m, n), "float32")
    for i, j in T.grid(m, n):
        with T.block():
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj]
Then we can specialize it with given shapes or buffers.
a, _, m, n = mem_copy.params
func = mem_copy.specialize({a: tir.decl_buffer((16, 16))})
# or
func = mem_copy.specialize({n: 16, m: 16})
The specialized function:
@T.prim_func
def mem_copy_16_16(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float32")
    B = T.match_buffer(b, (16, 16), "float32")
    for i, j in T.grid(16, 16):
        with T.block():
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj]
Returns
- func : PrimFunc
The new function with the parameters specialized.
- class tvm.tir.ProducerLoad(producer, indices, span=None)
Producer load node.
Parameters
- producer : DataProducer
The producer to load from.
- indices : List[PrimExpr]
The buffer indices.
- span : Optional[Span]
The location of this expression in the source code.
- class tvm.tir.ProducerRealize(producer, bounds, condition, body, storage_scope='', span=None)
ProducerRealize node.
Parameters
- producer : DataProducer
The data producer.
- bounds : List[Range]
The bounds of the realization.
- condition : PrimExpr
The realize condition.
- body : Stmt
The realize body.
- storage_scope : str
The storage scope associated with this realization.
- span : Optional[Span]
The location of the stmt in the source code.
- class tvm.tir.ProducerStore(producer, value, indices, span=None)
ProducerStore node.
Parameters
- producer : DataProducer
The data producer.
- value : PrimExpr
The value to be stored.
- indices : list of Expr
The index arguments of the store.
- span : Optional[Span]
The location of the stmt in the source code.
- class tvm.tir.Ramp(base, stride, lanes, span=None)
Ramp node.
Parameters
- base : PrimExpr
The base expression.
- stride : PrimExpr
The stride of the ramp.
- lanes : PrimExpr
The number of lanes of the expression.
- span : Optional[Span]
The location of this expression in the source code.
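Ramp and Broadcast are the two basic ways to build a vector value in TIR. Ramp(base, stride, lanes) yields base, base + stride, ..., base + (lanes - 1) * stride. Broadcast(value, lanes) repeats a scalar across every lane. The plain-Python sketch below models both with lists. It illustrates the lane semantics only and does not use TVM.

```python
def ramp(base, stride, lanes):
    """Lane i of a Ramp holds base + i * stride."""
    return [base + i * stride for i in range(lanes)]


def broadcast(value, lanes):
    """Every lane of a Broadcast holds the same scalar value."""
    return [value] * lanes


# A contiguous 4-lane index vector starting at 8, as used by a vectorized access.
print(ramp(8, 1, 4))      # [8, 9, 10, 11]
print(broadcast(0.5, 4))  # [0.5, 0.5, 0.5, 0.5]
```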
- class tvm.tir.Reduce(combiner, src, rdom, condition, value_index, init=None, span=None)
Reduce node.
Parameters
- combiner : CommReducer
The combiner.
- src : list of Expr
The source expressions.
- rdom : list of IterVar
The iteration domain.
- condition : PrimExpr
The reduce condition.
- value_index : int
The index of the combiner result that this node produces.
- init : list of Expr
The initial value for the output. This can be an int, float or ProducerLoad.
- span : Optional[Span]
The location of this expression in the source code.
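CommReducer takes lists for lhs, rhs, result and identity_element, so a single reduction can carry several values at once. value_index then selects which combined result a given Reduce node produces. The plain-Python sketch below models an argmax-style reducer over (index, value) pairs. All names are hypothetical, and the tie-breaking rule (keep the smaller index, which keeps the combiner commutative) is an assumption of this sketch.

```python
from functools import reduce


def argmax_combine(lhs, rhs):
    """Combine two (index, value) pairs, keeping the larger value.

    Ties keep the smaller index so the combiner stays commutative.
    """
    (li, lv), (ri, rv) = lhs, rhs
    if lv > rv or (lv == rv and li < ri):
        return (li, lv)
    return (ri, rv)


IDENTITY = (-1, float("-inf"))  # identity element for both tuple slots


def reduce_argmax(values, value_index):
    pairs = list(enumerate(values))
    result = reduce(argmax_combine, pairs, IDENTITY)
    return result[value_index]  # 0 -> argmax, 1 -> max


data = [3.0, 9.0, 2.0, 9.0]
print(reduce_argmax(data, 0))  # 1
print(reduce_argmax(data, 1))  # 9.0
```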
- class tvm.tir.Select(condition, true_value, false_value, span=None)
Select node.
Note
Select may compute both true_value and false_value. Use tvm.tir.if_then_else instead if you want a conditional expression that evaluates only the correct branch.
Parameters
- condition : PrimExpr
The condition expression.
- true_value : PrimExpr
The value to take when condition is true.
- false_value : PrimExpr
The value to take when condition is false.
- span : Optional[Span]
The location of this expression in the source code.
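The Select note matters when one branch is unsafe to evaluate, such as a division guarded by a zero check. The plain-Python sketch below mimics the difference by passing branches as zero-argument callables. select_model evaluates both branches before choosing, as Select may. if_then_else_model evaluates only the chosen branch. Both helpers are hypothetical; in TIR the distinction concerns the generated code, not Python callables.

```python
def select_model(cond, true_branch, false_branch):
    """Like Select: both branches may be computed before choosing."""
    t = true_branch()
    f = false_branch()
    return t if cond else f


def if_then_else_model(cond, true_branch, false_branch):
    """Like tir.if_then_else: only the selected branch is evaluated."""
    return true_branch() if cond else false_branch()


x = 0
safe = if_then_else_model(x != 0, lambda: 10 // x, lambda: 0)
print(safe)  # 0

try:
    select_model(x != 0, lambda: 10 // x, lambda: 0)
except ZeroDivisionError:
    print("select_model evaluated the unsafe branch")
```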
- class tvm.tir.SeqStmt(seq, span=None)
Sequence of statements.
Parameters
- seq : List[Stmt]
The statements.
- span : Optional[Span]
The location of the stmt in the source code.
- class tvm.tir.Shuffle(vectors, indices, span=None)
Shuffle node.
Parameters
- vectors : List[PrimExpr]
The input vectors.
- indices : List[PrimExpr]
The indices of the lanes to select.
- span : Optional[Span]
The location of this expression in the source code.
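As we understand it, a Shuffle concatenates its input vectors and then picks lanes from the result by position. Its indices therefore refer to lanes of the concatenation. The plain-Python sketch below models that behaviour with lists. It illustrates the semantics only and is not TVM code.

```python
from itertools import chain


def shuffle(vectors, indices):
    """Concatenate the input vectors, then gather lanes by index."""
    concat = list(chain.from_iterable(vectors))
    return [concat[i] for i in indices]


a = [0, 1, 2, 3]
b = [10, 11, 12, 13]
print(shuffle([a, b], [0, 4, 1, 5]))  # interleave low halves: [0, 10, 1, 11]
print(shuffle([a], [3, 2, 1, 0]))     # reverse lanes: [3, 2, 1, 0]
```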
- class tvm.tir.SizeVar(name, dtype, span=None)
Symbolic variable to represent a tensor index size, which is greater than or equal to zero.
Parameters
- name : str
The name.
- dtype : Union[str, ir.Type]
The data type.
- span : Optional[Span]
The location of this expression in the source code.
- class tvm.tir.StringImm(value, span=None)
String constant.
Parameters
- value : str
The string value.
- span : Optional[Span]
The location of this expression in the source code.
- class tvm.tir.Sub(a, b, span=None)
Sub node.
Parameters
- a : PrimExpr
The left-hand operand.
- b : PrimExpr
The right-hand operand.
- span : Optional[Span]
The location of this expression in the source code.
- class tvm.tir.TensorIntrin(desc, impl)
A tensor intrinsic.
Parameters
- desc : PrimFunc
The function that describes the computation.
- impl : PrimFunc
The function that implements the computation for execution.
- static get(name, allow_missing=False)
Look up a tensor intrinsic by its name.
Parameters
- name : str
The name of the TensorIntrin to look up.
- allow_missing : bool
Whether to allow a missing tensor intrinsic. If False, raise an error if the tensor intrinsic doesn't exist.
Returns
- result : Optional[TensorIntrin]
The TensorIntrin with the specified name, or None if not found.
- static register(name, desc, impl, override=False)
Register a tensor intrinsic with its name.
Parameters
- name : str
The name of the TensorIntrin to register.
- desc : PrimFunc
The function that describes the computation.
- impl : PrimFunc
The function that implements the computation for execution.
- override : bool
Whether to override an existing intrinsic.
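register and get behave like a global name-to-intrinsic registry. Registering an existing name fails unless override=True. Looking up an unknown name either raises or returns None, depending on allow_missing. The IntrinRegistry class below is a hypothetical plain-Python model of that contract. The desc and impl objects are placeholder strings rather than PrimFuncs, and the exception types are assumptions of this sketch.

```python
from typing import Dict, Optional, Tuple


class IntrinRegistry:
    """Hypothetical model of TensorIntrin.register / TensorIntrin.get."""

    def __init__(self) -> None:
        self._table: Dict[str, Tuple[object, object]] = {}

    def register(self, name: str, desc: object, impl: object, override: bool = False) -> None:
        if name in self._table and not override:
            raise ValueError(f"tensor intrinsic {name!r} is already registered")
        self._table[name] = (desc, impl)

    def get(self, name: str, allow_missing: bool = False) -> Optional[Tuple[object, object]]:
        if name not in self._table:
            if allow_missing:
                return None
            raise KeyError(f"tensor intrinsic {name!r} not found")
        return self._table[name]


registry = IntrinRegistry()
registry.register("dot_4x4", "desc_v1", "impl_v1")
registry.register("dot_4x4", "desc_v2", "impl_v2", override=True)
print(registry.get("dot_4x4"))                      # ('desc_v2', 'impl_v2')
print(registry.get("missing", allow_missing=True))  # None
```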
- class tvm.tir.Var(name, dtype, span=None)[源代码]#
Symbolic variable.
Parameters#
- namestr
The name
- dtypeUnion[str, ir.Type]
The data type
- spanOptional[Span]
The location of this expression in the source code.
- class tvm.tir.While(condition, body, span=None)[源代码]#
While node.
Parameters#
- conditionPrimExpr
The termination condition.
- bodyStmt
The body statement.
- spanOptional[Span]
The location of the stmt in the source code.
- 参数:
condition (PrimExpr)
body (Stmt)
span (Span | None)
- tvm.tir.TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, dtype_bits_hint)[源代码]#
Backend function to allocate temporal workspace
Parameters#
- device_typeint
The device type which the space will be allocated.
- device_idint
The device id which the space will be allocated.
- nbytesint
The size of the space requested.
- dtype_code_hintint
The type code of the array elements. Only used in certain backends such as OpenGL.
- dtype_bits_hintint
The type bits of the array elements. Only used in certain backends such as OpenGL.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.TVMBackendFreeWorkspace(device_type, device_id, ptr)
Backend function to free temporary workspace.
Parameters#
- device_typeint
The device type on which the space was allocated.
- device_idint
The device id on which the space was allocated.
- ptrVar
The pointer to the allocated space that is being freed.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.abs(x, span=None)
Get absolute value of the input element-wise.
Parameters#
- xPrimExpr
Input argument.
- spanOptional[Span]
The location of this operator in the source code.
Returns#
- yPrimExpr
The result.
- tvm.tir.acos(x)
Take acos of input x.
Parameters#
- xPrimExpr
Input argument.
Returns#
- yPrimExpr
The result.
- tvm.tir.acosh(x)
Take acosh of input x.
Parameters#
- xPrimExpr
Input argument.
Returns#
- yPrimExpr
The result.
- tvm.tir.add(lhs, rhs, span=None)
Generic add operator.
Parameters#
- lhsobject
The left operand.
- rhsobject
The right operand.
- spanOptional[Span]
The location of this operator in the source.
Returns#
- optvm.Expr
The result Expr of the add operation.
- tvm.tir.address_of(buffer_load, span=None)
Return the address of an element in the buffer.
Parameters#
- buffer_load: BufferLoad
The buffer load.
- spanOptional[Span]
The location of this operator in the source code.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.all(*args, span=None)
Create a new expression of the intersection of all conditions in the arguments.
Parameters#
- argslist
List of symbolic boolean expressions
- spanOptional[Span]
The location of this operator in the source code.
Returns#
- expr: Expr
The resulting boolean expression.
- tvm.tir.any(*args, span=None)
Create a new expression of the union of all conditions in the arguments.
Parameters#
- argslist
List of symbolic boolean expressions
- spanOptional[Span]
The location of this operator in the source code.
Returns#
- expr: Expr
The resulting boolean expression.
- tvm.tir.asin(x)
Take asin of input x.
Parameters#
- xPrimExpr
Input argument.
Returns#
- yPrimExpr
The result.
- tvm.tir.asinh(x)
Take asinh of input x.
Parameters#
- xPrimExpr
Input argument.
Returns#
- yPrimExpr
The result.
- tvm.tir.assume(cond=None)
Provide a statement that is assumed to be true, which can be used for simplifications.
Parameters#
- condExpr
The constraint condition.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.atan(x)
Take atan of input x.
Parameters#
- xPrimExpr
Input argument.
Returns#
- yPrimExpr
The result.
- tvm.tir.atan2(x1, x2)
Take arctan2(x1, x2).
Parameters#
- x1PrimExpr
Input argument.
- x2PrimExpr
Input argument.
Returns#
- yPrimExpr
The result.
- tvm.tir.atanh(x)
Take atanh of input x.
Parameters#
- xPrimExpr
Input argument.
Returns#
- yPrimExpr
The result.
- tvm.tir.bijective_layout(src_layout, dst_layout)
Create a bijective layout mapping.
Parameters#
- src_layoutstr or Layout
The source layout.
- dst_layoutstr or Layout
The destination layout.
Returns#
- bijective_layoutBijectiveLayout
The created bijective layout.
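A bijective layout lets indices be converted in both directions between two layouts. As an illustration, the hypothetical pure-Python functions below spell out the index mapping that a bijective layout between "NCHW" and "NCHW16c" would express; they do not call tvm.tir.bijective_layout, and the block factor 16 is taken from the layout example given for tvm.tir.layout.

```python
BLOCK = 16  # factor of the subordinate axis "16c" (assumed example layout)


def nchw_to_nchw16c(n, c, h, w):
    """Forward map: split channel c into (outer channel block, inner lane)."""
    return (n, c // BLOCK, h, w, c % BLOCK)


def nchw16c_to_nchw(n, c_outer, h, w, c_inner):
    """Backward map: fuse the two channel pieces back into one channel index."""
    return (n, c_outer * BLOCK + c_inner, h, w)
```

Because the two maps are inverses of each other, any NCHW index survives a forward-then-backward round trip, which is exactly what makes the mapping bijective.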
- tvm.tir.bitwise_and(x, y, span=None)
Take the bitwise AND of two values.
Parameters#
- xPrimExpr
Left operand
- yPrimExpr
Right operand
- spanOptional[Span]
The location of this operator in the source code.
Returns#
- resPrimExpr
The result.
- tvm.tir.bitwise_not(x, span=None)
Take the bitwise NOT of the input value.
Parameters#
- xPrimExpr
Input operand
- spanOptional[Span]
The location of this operator in the source code.
Returns#
- resPrimExpr
The result.
- tvm.tir.bitwise_or(x, y, span=None)
Take the bitwise OR of two values.
Parameters#
- xPrimExpr
Left operand
- yPrimExpr
Right operand
- spanOptional[Span]
The location of this operator in the source code.
Returns#
- resPrimExpr
The result.
- tvm.tir.bitwise_xor(x, y, span=None)
Take the bitwise XOR of two values.
Parameters#
- xPrimExpr
Left operand
- yPrimExpr
Right operand
- spanOptional[Span]
The location of this operator in the source code.
Returns#
- resPrimExpr
The result.
- tvm.tir.call_cpacked(*args, span=None)
Build an expression by calling an external packed function.
Same as call_packed, except that the first argument is the function name (as in call_extern), and the last argument is the resource handle.
Parameters#
- argslist of Expr or Buffer.
Positional arguments.
- spanOptional[Span]
The location of this operator in the source code.
Returns#
- callPrimExpr
The call expression.
See Also#
te.extern : Create tensor with extern function call.
- tvm.tir.call_cpacked_lowered(*args, span=None)
Lowered version of call_cpacked. Same as call_packed, except that the first argument is the function name (as in call_extern), and the last argument is the resource handle.
Parameters#
- argslist of Expr or Buffer.
Positional arguments.
- spanOptional[Span]
The location of this operator in the source code.
Returns#
- callPrimExpr
The call expression.
See Also#
te.extern : Create tensor with extern function call.
- tvm.tir.call_extern(dtype, func_name, *args, span=None)
Build an expression by calling an extern function.
Parameters#
- dtypestr
The data type of the result.
- func_name: str
The extern function name.
- argslist
Positional arguments.
- spanOptional[Span]
The location of this operator in the source code.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.call_intrin(dtype, func_name, *args, span=None)
Build expression by calling an intrinsic function.
Intrinsics can be overloaded with multiple data types via the intrinsic translation rule.
Parameters#
- dtypestr
The data type of the result.
- func_name: str
The intrinsic function name.
- argslist
Positional arguments.
- spanOptional[Span]
The location of this operator in the source code.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.call_llvm_intrin(dtype, name, *args, span=None)
Build an expression by calling an LLVM intrinsic function.
Parameters#
- dtypestr
The data type of the result.
- namestr
The name of the LLVM intrinsic function.
- argslist
Positional arguments.
- spanOptional[Span]
The location of this operator in the source code.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.call_llvm_pure_intrin(dtype, name, *args, span=None)
Build an expression by calling a pure LLVM intrinsic function.
Parameters#
- dtypestr
The data type of the result.
- namestr
The name of the LLVM intrinsic function.
- argslist
Positional arguments.
- spanOptional[Span]
The location of this operator in the source code.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.call_packed(*args, span=None)
Build an expression by calling an external packed function.
Each argument to the packed function can be an Expr or a Buffer. When an Expr is passed, the function receives the corresponding POD value.
When the argument is a Buffer, the corresponding PackedFunc receives a TVMArrayHandle whose content is valid during the callback period. If the PackedFunc is a Python callback, the corresponding argument is an NDArray.
Parameters#
- argslist of Expr or Buffer.
Positional arguments.
- spanOptional[Span]
The location of this operator in the source code.
Returns#
- callPrimExpr
The call expression.
See Also#
te.extern : Create tensor with extern function call.
- tvm.tir.call_packed_lowered(*args, span=None)
Lowered version of call_packed. Each argument to the packed function can be an Expr or a Buffer. When an Expr is passed, the function receives the corresponding POD value. When the argument is a Buffer, the corresponding PackedFunc receives a TVMArrayHandle whose content is valid during the callback period. If the PackedFunc is a Python callback, the corresponding argument is an NDArray.
Parameters#
- argslist of Expr or Buffer.
Positional arguments.
- spanOptional[Span]
The location of this operator in the source code.
Returns#
- callPrimExpr
The call expression.
See Also#
te.extern : Create tensor with extern function call.
- tvm.tir.call_pure_extern(dtype, func_name, *args, span=None)
Build expression by calling a pure extern function.
Parameters#
- dtypestr
The data type of the result.
- func_name: str
The extern function name.
- argslist
Positional arguments.
- spanOptional[Span]
The location of this operator in the source code.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.call_tir(global_var, *args)
Perform a call into another PrimFunc in the same IRModule.
Returns#
- callPrimExpr
The call expression.
- Parameters:
global_var (GlobalVar)
- tvm.tir.ceil(x, span=None)
Take ceil of float input x.
Parameters#
- xPrimExpr
Input argument.
- spanOptional[Span]
The location of this operator in the source code.
Returns#
- yPrimExpr
The result.
- tvm.tir.ceildiv(lhs, rhs, span=None)
Generic ceildiv operator.
Parameters#
- lhsobject
The left operand.
- rhsobject
The right operand.
- spanOptional[Span]
The location of this operator in the source.
Returns#
- optvm.Expr
The result Expr of the ceildiv operation.
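For integer operands, ceildiv rounds the quotient toward positive infinity. The plain-Python reference below reproduces that integer arithmetic for illustration only (it is not TVM's implementation and ignores floating-point operands); a typical use is counting how many fixed-size tiles cover a loop extent.

```python
def ceildiv(lhs: int, rhs: int) -> int:
    """Integer ceiling division: ceil(lhs / rhs) computed without floats."""
    # -floor(-a / b) == ceil(a / b); Python's // is floor division.
    return -(-lhs // rhs)


# Number of 32-wide tiles needed to cover 100 elements.
num_tiles = ceildiv(100, 32)
```

Here num_tiles is 4: three full tiles of 32 cover 96 elements, so a fourth, partial tile is needed for the remaining 4.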
- tvm.tir.clz(x)
Count leading zero bits of an integer x.
Parameters#
- xPrimExpr
Input 32- or 64-bit integer. The result is undefined if the input is 0.
Returns#
- yPrimExpr
The result.
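To make the bit-counting rule concrete, here is a plain-Python model of clz for a fixed bit width. It is an illustrative sketch, not the TVM lowering: negative inputs are reinterpreted as two's complement, and the model raises on 0 because the description above leaves that case undefined.

```python
def clz(x: int, bits: int = 32) -> int:
    """Count leading zero bits of x viewed as an unsigned integer of width bits."""
    x &= (1 << bits) - 1  # reinterpret negative values as two's complement
    if x == 0:
        # tvm.tir.clz leaves this case undefined; the model fails loudly instead.
        raise ValueError("clz is undefined for 0")
    return bits - x.bit_length()
```

For example, clz(1) is 31 for a 32-bit input and 63 for a 64-bit input, while any value with the top bit set (including -1) has no leading zeros.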
- tvm.tir.comm_reducer(fcombine, fidentity, name='reduce')
Create a commutative reducer for reduction.
Parameters#
- fcombinefunction(Expr -> Expr -> Expr)
A binary function which takes two Exprs as input and returns an Expr.
- fidentityfunction(str -> Expr)
A function which takes a type string as input and returns a const Expr.
Returns#
- reducerfunction
A function which creates a reduce expression over an axis. There are two ways to use it:
1. accept (expr, axis, where) to produce a Reduce Expr on the specified axis;
2. simply call it with multiple Exprs.
Example#
import tvm
from tvm import te

n = te.var("n")
m = te.var("m")
mysum = te.comm_reducer(lambda x, y: x + y, lambda t: tvm.tir.const(0, dtype=t), name="mysum")
A = te.placeholder((n, m), name="A")
k = te.reduce_axis((0, m), name="k")
B = te.compute((n,), lambda i: mysum(A[i, k], axis=k), name="B")
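Numerically, a reducer built by comm_reducer folds the values along the reduction axis with fcombine, starting from the value fidentity produces for the dtype. The plain-Python sketch below models only that arithmetic on a list of numbers; it is an illustration, not TVM's implementation (the real reducer builds a symbolic Reduce expression), and the dtype default is an assumption made to mirror fidentity's str -> Expr signature.

```python
from functools import reduce


def comm_reducer(fcombine, fidentity, dtype="float32"):
    """Return a reducer that folds values with fcombine, seeded by fidentity(dtype)."""
    def reducer(values):
        # Seeding with the identity keeps an empty reduction well defined.
        return reduce(fcombine, values, fidentity(dtype))
    return reducer


mysum = comm_reducer(lambda x, y: x + y, lambda t: 0)
mymax = comm_reducer(max, lambda t: float("-inf"))
```

The identity seed is what lets an empty or split reduction still produce the right answer, and fcombine should be commutative (and associative), as the name suggests, so that reordering the fold does not change the result.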
- tvm.tir.copysign(x1, x2)
Change the sign of x1 to that of x2, element-wise.
Parameters#
- x1PrimExpr
Input argument.
- x2PrimExpr
Input argument.
Returns#
- yPrimExpr
The result.
- tvm.tir.cos(x)
Take cos of input x.
Parameters#
- xPrimExpr
Input argument.
Returns#
- yPrimExpr
The result.
- tvm.tir.cosh(x)
Take cosh of input x.
Parameters#
- xPrimExpr
Input argument.
Returns#
- yPrimExpr
The result.
- tvm.tir.create_barriers(barrier_count)
TVM intrinsic to create barrier_count barriers.
Parameters#
- barrier_countint
The number of barriers to create.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.decl_buffer(shape, dtype=None, name='buffer', data=None, strides=None, elem_offset=None, scope='', data_alignment=-1, offset_factor=0, buffer_type='', axis_separators=None, span=None)
Declare a new symbolic buffer.
Normally, buffers are created automatically during lowering and building. This function is only needed if the user wants to specify their own buffer layout.
See the note below for a detailed discussion of buffer usage.
Parameters#
- shapetuple of Expr
The shape of the buffer.
- dtypestr, optional
The data type of the buffer.
- namestr, optional
The name of the buffer.
- datatir.Var, optional
The data pointer in the buffer.
- strides: array of Expr
The strides of the buffer.
- elem_offset: Expr, optional
The offset of the array's first element relative to data, measured in number of elements of dtype.
- scope: str, optional
The storage scope of the buffer, if not global. If scope is the empty string, the buffer is in global memory.
- data_alignment: int, optional
The alignment of data pointer in bytes. If -1 is passed, the alignment will be set to TVM's internal default.
- offset_factor: int, optional
The factor of the elem_offset field. When set, elem_offset is required to be a multiple of offset_factor. If 0 is passed, the alignment will be set to 1. If a non-zero value is passed, a Var is created for elem_offset if elem_offset is None.
- buffer_type: str, optional, {"", "auto_broadcast"}
An auto_broadcast buffer allows one to implement broadcast computation without considering whether a dimension's size equals one. TVM maps buffer[i][j][k] -> buffer[i][0][k] if dimension j's shape equals 1.
- axis_separatorslist of int, optional
If passed, a list of separators between groups of axes, each of which is flattened to an output axis. For flat memory spaces, this should be either None or an empty list.
- span: Optional[Span]
The location of the decl_buffer creation in the source.
Returns#
- buffertvm.tir.Buffer
The created buffer
Example#
Here's an example of how a broadcast buffer can be used to define a symbolic broadcast operation:
import numpy as np
import tvm
import tvm.testing
from tvm import te

m0, m1, m2 = te.var("m0"), te.var("m1"), te.var("m2")
n0, n1, n2 = te.var("n0"), te.var("n1"), te.var("n2")
o0, o1, o2 = te.var("o0"), te.var("o1"), te.var("o2")
A = te.placeholder((m0, m1, m2), name="A")
B = te.placeholder((n0, n1, n2), name="B")
C = te.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name="C")
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast")
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast")
s = te.create_schedule(C.op)
fadd = tvm.build(s, [A, B, C], target="llvm", name="bcast_add", binds={A: Ab, B: Bb})
dev = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(size=(2, 4, 3)).astype(A.dtype), dev)
b = tvm.nd.array(np.random.uniform(size=(2, 1, 3)).astype(B.dtype), dev)
c = tvm.nd.array(np.zeros((2, 4, 3), dtype=C.dtype), dev)
fadd(a, b, c)
tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy())
Note#
Buffer data structure reflects the DLTensor structure in dlpack. While the DLTensor data structure is very general, it is usually helpful to create a function that only handles a specific case of the data structure, so that the compiled function can benefit from it.
If the user passes None for both strides and elem_offset when constructing the function, the function will be specialized for DLTensors that are compact and aligned. If the user passes a fully generic symbolic array as strides, the resulting function becomes fully generic.
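The auto_broadcast rule (buffer[i][j][k] -> buffer[i][0][k] when dimension j has extent 1) can be reproduced with a few lines of plain Python. This is only a model of the index rule on nested lists, using the (2, 1, 3) operand shape from the example above; it says nothing about how TVM realizes the rule in generated code.

```python
def auto_broadcast_index(index, shape):
    """Map a logical index onto a buffer whose extent-1 dimensions are broadcast."""
    # A dimension of extent 1 always reads element 0, whatever the loop index is.
    return tuple(0 if extent == 1 else i for i, extent in zip(index, shape))


# B has shape (2, 1, 3) while the loop nest runs over (2, 4, 3).
B = [[[10, 11, 12]], [[20, 21, 22]]]
B_SHAPE = (2, 1, 3)


def load_b(i, j, k):
    bi, bj, bk = auto_broadcast_index((i, j, k), B_SHAPE)
    return B[bi][bj][bk]
```

Every j in range(4) therefore reads the single middle-axis row of B, which is exactly the broadcast that the bcast_add example relies on.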
- tvm.tir.div(a, b, span=None)
Compute a / b with C/C++ semantics.
Parameters#
- aPrimExpr
The left hand operand, known to be non-negative.
- bPrimExpr
The right hand operand, known to be non-negative.
- spanOptional[Span]
The location of this operator in the source.
Returns#
- resPrimExpr
The result expression.
Note#
When operands are integers, returns truncdiv(a, b, span).
- tvm.tir.dp4a(vec1, vec2, acc=0)
Compute the dot product of two int8x4 vectors and add an optional accumulator.
Parameters#
- vec1int8x4
The input vector.
- vec2int8x4
The input vector.
- accint32
The accumulator.
Returns#
- callPrimExpr
The call expression.
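The arithmetic that dp4a stands for can be written out directly. The function below is a plain-Python model of "dot product of two int8x4 vectors plus an accumulator", written for illustration; it checks the int8 lane range but does not model int32 overflow wrap-around, which real hardware would apply.

```python
def dp4a(vec1, vec2, acc=0):
    """Dot product of two 4-lane int8 vectors plus an integer accumulator."""
    assert len(vec1) == len(vec2) == 4, "dp4a works on int8x4 vectors"
    assert all(-128 <= v <= 127 for v in (*vec1, *vec2)), "lanes must fit in int8"
    # Each lane product fits easily in int32; the four products are summed into acc.
    return acc + sum(a * b for a, b in zip(vec1, vec2))
```

For example, dp4a([1, 2, 3, 4], [5, 6, 7, 8]) is 5 + 12 + 21 + 32 = 70.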
- tvm.tir.end_profile_intrinsic(id)
End profile intrinsic.
Parameters#
- idint
The intrinsic id.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.erf(x)
Take the Gauss error function of the input x.
Parameters#
- xPrimExpr
Input argument.
Returns#
- yPrimExpr
The result.
- tvm.tir.exp(x)
Take exponential of input x.
Parameters#
- xPrimExpr
Input argument.
Returns#
- yPrimExpr
The result.
- tvm.tir.exp10(x)
Calculate 10**x
Parameters#
- xPrimExpr
Input argument.
Returns#
- yPrimExpr
The result.
- tvm.tir.exp2(x)
Calculate 2**x
Parameters#
- xPrimExpr
Input argument.
Returns#
- yPrimExpr
The result.
- tvm.tir.floor(x, span=None)
Take floor of float input x.
Parameters#
- xPrimExpr
Input argument.
- spanOptional[Span]
The location of this operator in the source code.
Returns#
- yPrimExpr
The result.
- Parameters:
x (PrimExprWithOp)
- tvm.tir.floordiv(a, b, span=None)
Compute the floordiv of two expressions.
Parameters#
- aPrimExpr
The left hand operand
- bPrimExpr
The right hand operand
- spanOptional[Span]
The location of this operator in the source.
Returns#
- resPrimExpr
The result expression.
- tvm.tir.floormod(a, b, span=None)
Compute the floormod of two expressions.
Parameters#
- aPrimExpr
The left hand operand
- bPrimExpr
The right hand operand
- spanOptional[Span]
The location of this operator in the source.
Returns#
- resPrimExpr
The result expression.
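The integer division helpers differ only in how they round negative quotients: div on integers truncates toward zero (C/C++ truncdiv), while floordiv rounds toward negative infinity and floormod takes the sign of the divisor. The plain-Python reference below illustrates those rules; it is not TVM code, and it relies on Python's own // and % already following floor semantics.

```python
def truncdiv(a: int, b: int) -> int:
    """C/C++-style integer division: the quotient is rounded toward zero."""
    q = abs(a) // abs(b)
    return q if (a >= 0) == (b >= 0) else -q


def truncmod(a: int, b: int) -> int:
    """Remainder paired with truncdiv; it takes the sign of the dividend a."""
    return a - truncdiv(a, b) * b


def floordiv(a: int, b: int) -> int:
    """Quotient rounded toward negative infinity (Python's //)."""
    return a // b


def floormod(a: int, b: int) -> int:
    """Remainder paired with floordiv; it takes the sign of the divisor b."""
    return a % b
```

For -7 and 2 the two families disagree: truncdiv/truncmod give (-3, -1) while floordiv/floormod give (-4, 1), yet both pairs satisfy quotient * b + remainder == a.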
- tvm.tir.fmod(x, y)
Return the remainder of x divided by y with the same sign as x.
Parameters#
- xPrimExpr
Input argument.
- yPrimExpr
Input argument.
Returns#
- zPrimExpr
The result.
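Assuming tvm.tir.fmod follows C's fmod (the description "same sign as x" matches it), Python's math.fmod is a convenient scalar reference. The sketch contrasts it with Python's % operator, whose sign rule (the remainder follows the divisor) matches floormod; it only illustrates the sign conventions and is not TVM code.

```python
import math

pairs = [(5.5, 2.0), (-5.5, 2.0), (5.5, -2.0)]

# fmod: the remainder takes the sign of x (the dividend), like C's fmod.
fmod_results = [math.fmod(x, y) for x, y in pairs]

# floormod-style (Python %): the remainder takes the sign of y (the divisor).
floormod_results = [x % y for x, y in pairs]
```

The two conventions agree when both operands are positive and differ otherwise: fmod gives [1.5, -1.5, 1.5], while the floormod-style results are [1.5, 0.5, -0.5].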
- tvm.tir.get_active_lane_mask(dtype, base, limit)
Calculate a predicate mask given an upper bound (limit) and a current value (base).
It is lowered to the llvm.get.active.lane.mask intrinsic (https://llvm.org/docs/LangRef.html#llvm-get-active-lane-mask-intrinsics).
Parameters#
- dtypestr
The data type of the result.
- basePrimExpr
An expression representing the base.
- limitPrimExpr
An expression representing the limit.
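The LLVM intrinsic this lowers to marks lane i as active when base + i < limit (an unsigned comparison in LLVM). The plain-Python model below reproduces that rule for non-negative values only, ignoring unsigned wrap-around; the lane count is an explicit parameter here because the model has no vector dtype to read it from.

```python
def active_lane_mask(lanes: int, base: int, limit: int):
    """Return one bool per lane: lane i is active when base + i < limit."""
    return [base + i < limit for i in range(lanes)]


# Processing 10 elements with 4 lanes: the last iteration starts at base 8.
tail_mask = active_lane_mask(4, 8, 10)
```

In a vectorized loop over 10 elements with 4 lanes, the iterations at base 0 and 4 are fully active, and tail_mask is [True, True, False, False], so the final iteration touches only the two remaining elements.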
- tvm.tir.get_vscale_expr(dtype, min_size=128)
Create a datatype dependent scalable expression.
Parameters#
- dtypeUnion[str, tvm.DataType]
Element data type.
- min_sizeint
The minimum size of the scalable vector in bits.
- tvm.tir.hypot(x1, x2)
Equivalent to sqrt(x1**2 + x2**2), element-wise.
Parameters#
- x1PrimExpr
Input argument.
- x2PrimExpr
Input argument.
Returns#
- yPrimExpr
The result.
- tvm.tir.if_then_else(cond, t, f, span=None)
Conditional selection expression.
Parameters#
- condPrimExpr
The condition
- tPrimExpr
The result expression if cond is true.
- fPrimExpr
The result expression if cond is false.
- spanOptional[Span]
The location of this operator in the source.
Returns#
- resultNode
The result of conditional expression.
Note#
Unlike Select, if_then_else does not evaluate the branch that does not satisfy the condition, so it can be used to guard against out-of-bounds access. Also unlike Select, if_then_else cannot be vectorized if lanes in the vector have different conditions.
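The evaluation-order difference can be illustrated in plain Python, with callables standing in for lazily evaluated branches. The function names mirror tvm.tir.Select and tvm.tir.if_then_else, but this is a hypothetical model, not TVM code, and the padded-load scenario is made up for illustration.

```python
def select(cond, true_value, false_value):
    """Eager: both operands have already been evaluated before the choice."""
    return true_value if cond else false_value


def if_then_else(cond, then_fn, else_fn):
    """Lazy: only the branch that is taken gets evaluated."""
    return then_fn() if cond else else_fn()


data = [1, 2, 3, 4]


def padded_load(i):
    # Guarded access: data[i] is only evaluated when i is in bounds.
    return if_then_else(i < len(data), lambda: data[i], lambda: 0)
```

Writing select(i < len(data), data[i], 0) instead would raise IndexError for i = 4 or 5, because data[i] is evaluated before the condition is consulted.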
- tvm.tir.indexdiv(a, b, span=None)
Compute floor(a / b) where a and b are non-negative.
Parameters#
- aPrimExpr
The left hand operand, known to be non-negative.
- bPrimExpr
The right hand operand, known to be non-negative.
- spanOptional[Span]
The location of this operator in the source.
Returns#
- resPrimExpr
The result expression.
Note#
Use this function to split non-negative indices. This function may take advantage of the operands' non-negativity.
- tvm.tir.indexmod(a, b, span=None)
Compute the remainder of indexdiv. a and b are non-negative.
Parameters#
- aPrimExpr
The left hand operand, known to be non-negative.
- bPrimExpr
The right hand operand, known to be non-negative.
- spanOptional[Span]
The location of this operator in the source.
Returns#
- resPrimExpr
The result expression.
Note#
Use this function to split non-negative indices. This function may take advantage of the operands' non-negativity.
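The typical use of indexdiv and indexmod is splitting a flattened, non-negative index back into coordinates. The plain-Python sketch below shows that pattern with // and %, which for non-negative operands agree with both truncating and flooring division; the helper name split_index is hypothetical.

```python
def split_index(flat: int, inner_extent: int):
    """Split a flattened, non-negative index into (outer, inner) coordinates."""
    assert flat >= 0 and inner_extent > 0, "indexdiv/indexmod assume non-negative operands"
    outer = flat // inner_extent  # plays the role of indexdiv(flat, inner_extent)
    inner = flat % inner_extent   # plays the role of indexmod(flat, inner_extent)
    return outer, inner
```

For a row-major array with 4 columns, flat index 13 maps to row 3, column 1, and outer * inner_extent + inner always reconstructs the original index.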
- tvm.tir.infinity(dtype, span=None)
Return the infinity value of dtype.
Parameters#
- dtypestr
The data type.
- spanOptional[Span]
The location of this operator in the source code.
Returns#
- valuetvm.Expr
The infinity value of dtype.
- tvm.tir.isfinite(x, span=None)
Check if input value is finite.
Parameters#
- xPrimExpr
Input argument.
- spanOptional[Span]
The location of this operator in the source code.
Returns#
- yPrimExpr
The result.
- tvm.tir.isinf(x, span=None)
Check if input value is infinite.
Parameters#
- xPrimExpr
Input argument.
- spanOptional[Span]
The location of this operator in the source code.
Returns#
- yPrimExpr
The result.
- tvm.tir.isnan(x, span=None)
Check if input value is NaN.
Parameters#
- xPrimExpr
Input argument.
- spanOptional[Span]
The location of this operator in the source code.
Returns#
- yPrimExpr
The result.
- tvm.tir.isnullptr(x, span=None)
Check if input value is nullptr.
Parameters#
- xPrimExpr
Input argument.
- spanOptional[Span]
The location of this operator in the source code.
Returns#
- yPrimExpr
The result.
- tvm.tir.layout(layout_str, dtype='int32')
Create a layout node from a string.
Parameters#
- layout_strstr
A layout representation is composed of uppercase letters, lowercase letters and numbers, where an uppercase letter indicates a primal axis and the corresponding lowercase letter, prefixed by its factor size, indicates the subordinate axis. For example, NCHW16c can describe a 5-D tensor of [batch_size, channel, height, width, channel_block]. Here the subordinate axis channel_block=16 is the factor size of the primal axis C (channel).
- dtypestr
The dtype of generated axes vars in the returned layout. It is required to be integer type.
Returns#
- layoutLayout
The created layout
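The naming rule for layout strings can be checked with a small parser. The function below is a hypothetical sketch that reproduces only the rule described above (uppercase primal axes, lowercase subordinate axes prefixed by their factor); it is not tvm.tir.layout and does not perform TVM's other validations.

```python
import re


def parse_layout(layout: str):
    """Split a layout string such as "NCHW16c" into (axis, factor) pairs.

    Uppercase letters are primal axes (factor None); a lowercase letter must be
    preceded by its block factor and refers to the primal axis of the same letter.
    """
    tokens = re.findall(r"(\d*)([A-Za-z])", layout)
    if "".join(f + a for f, a in tokens) != layout:
        raise ValueError(f"malformed layout string: {layout!r}")
    axes = []
    for factor, axis in tokens:
        if axis.isupper():
            if factor:
                raise ValueError(f"primal axis {axis} cannot carry a factor")
            axes.append((axis, None))
        else:
            if not factor:
                raise ValueError(f"subordinate axis {axis} needs a factor")
            axes.append((axis, int(factor)))
    primal = {a for a, f in axes if f is None}
    for axis, factor in axes:
        if factor is not None and axis.upper() not in primal:
            raise ValueError(f"subordinate axis {axis} has no primal axis {axis.upper()}")
    return axes
```

For "NCHW16c" this yields four primal axes followed by ("c", 16), matching the 5-D [batch_size, channel, height, width, channel_block] reading given above.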
- tvm.tir.ldexp(x1, x2)
Returns x1 * (2 ** x2).
Parameters#
- x1PrimExpr
Input argument.
- x2PrimExpr
Input argument.
Returns#
- yPrimExpr
The result.
- tvm.tir.likely(cond, span=None)
Mark condition as likely.
Parameters#
- condPrimExpr
Input argument.
- spanOptional[Span]
The location of this operator in the source code.
Returns#
- yPrimExpr
The marked expression.
- tvm.tir.log(x)
Take log of input x.
Parameters#
- xPrimExpr
Input argument.
Returns#
- yPrimExpr
The result.
- tvm.tir.log10(x)
Take log10 of input x.
Parameters#
- xPrimExpr
Input argument.
Returns#
- yPrimExpr
The result.
- tvm.tir.log1p(x)
Take log(x + 1) with respect to input x.
Parameters#
- xPrimExpr
Input argument.
Returns#
- yPrimExpr
The result.
- tvm.tir.log2(x)
Take log2 of input x.
Parameters#
- xPrimExpr
Input argument.
Returns#
- yPrimExpr
The result.
- tvm.tir.lookup_param(param_name, span=None)
Look up a parameter by name.
Parameters#
- param_namestr
The name of param.
- spanOptional[Span]
The location of this operator in the source code.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.make_filled_simdgroup_matrix(d, index, value, col=8, row=8)
Create a filled SIMDGroup matrix.
Parameters#
- dvar
The simdgroup var
- indexPrimExpr
The index of the matrix.
- valuePrimExpr
The value to fill.
- colint
The number of columns.
- rowint
The number of rows.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.max(expr, axis, where=None, init=None, *args)
Create a max expression over axis.
Parameters#
- exprPrimExpr
The source expression.
- axisIterVar
The reduction IterVar axis
- whereoptional, Expr
Filtering predicate of the reduction.
Returns#
- valuePrimExpr
The result value.
Example#
import tvm
from tvm import te

m = te.var("m")
n = te.var("n")
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n), name="k")
# There are two ways to use this max reducer:
# mode 1, accept (expr, axis, where) to produce a Reduce Expr.
# tvm.te.max may be used in place of tvm.tir.max here.
B = te.compute((m,), lambda i: tvm.tir.max(A[i, k], axis=k), name="B")
# mode 2, simply use it with multiple Exprs:
max_res = tvm.tir.max(m, n)
- tvm.tir.max_value(dtype, span=None)
Return the maximum value of dtype.
Parameters#
- dtypestr
The data type.
- spanOptional[Span]
The location of this operator in the source code.
Returns#
- valuetvm.Expr
The maximum value of dtype.
- tvm.tir.min(expr, axis, where=None, init=None, *args)
Create a min expression over axis.
Parameters#
- exprPrimExpr
The source expression.
- axisIterVar
The reduction IterVar axis
- whereoptional, Expr
Filtering predicate of the reduction.
Returns#
- valuePrimExpr
The result value.
Example#
import tvm
from tvm import te

m = te.var("m")
n = te.var("n")
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n), name="k")
# There are two ways to use this min reducer:
# mode 1, accept (expr, axis, where) to produce a Reduce Expr.
# tvm.te.min may be used in place of tvm.tir.min here.
B = te.compute((m,), lambda i: tvm.tir.min(A[i, k], axis=k), name="B")
# mode 2, simply use it with multiple Exprs:
min_res = tvm.tir.min(m, n)
- tvm.tir.min_value(dtype, span=None)
Return the minimum value of dtype.
Parameters#
- dtypestr
The data type.
- spanOptional[Span]
The location of this operator in the source code.
Returns#
- valuetvm.Expr
The minimum value of dtype.
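For integer dtypes, the values that min_value and max_value represent follow directly from the bit width: two's complement for signed types and [0, 2**bits - 1] for unsigned ones. The plain-Python helper below computes those bounds from a dtype string; the name int_range is hypothetical, and float dtypes are deliberately not handled.

```python
def int_range(dtype: str):
    """Return (min_value, max_value) for an integer dtype string like "int8" or "uint16"."""
    if dtype.startswith("uint"):
        bits = int(dtype[4:])
        return 0, (1 << bits) - 1
    if dtype.startswith("int"):
        bits = int(dtype[3:])
        # Two's complement: one more negative value than positive values.
        return -(1 << (bits - 1)), (1 << (bits - 1)) - 1
    raise ValueError(f"not an integer dtype: {dtype}")
```

For example, int8 spans -128 to 127 and uint8 spans 0 to 255.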
- tvm.tir.mma_fill(dtype, local_size, local_ptr, offset)
TVM intrinsic for zero-initializing an MMA accumulation register.
Parameters#
- dtypestr
The data type of the result.
- local_sizeIntImm
The number of elements.
- local_ptrVar
The destination pointer variable.
- offsetExpr
The destination offset.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.mma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride)
TVM intrinsic for storing the result of PTX MMA into a destination pointer.
Parameters#
- dtypestr
The data type of the result.
- mIntImm
The m dimension of the MMA fragment shape.
- nIntImm
The n dimension of the MMA fragment shape.
- dst_ptrVar
The destination pointer variable.
- src_ptrVar
The source pointer variable.
- src_offsetExpr
The source offset.
- dst_strideVar
The destination stride.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.multiply(lhs, rhs, span=None)
Generic multiply operator.
Parameters#
- lhsobject
The left operand.
- rhsobject
The right operand.
- spanOptional[Span]
The location of this operator in the source.
Returns#
- optvm.Expr
The result Expr of the multiply operation.
- tvm.tir.nearbyint(x, span=None)
Round elements of the array to the nearest integer. This intrinsic uses llvm.nearbyint instead of llvm.round, which is faster but may produce results that differ from te.round. Notably, nearbyint rounds according to the current rounding mode, whereas te.round (llvm.round) ignores it. For the differences between the two, see https://en.cppreference.com/w/cpp/numeric/math/round and https://en.cppreference.com/w/cpp/numeric/math/nearbyint
Parameters#
- xPrimExpr
Input argument.
- spanOptional[Span]
The location of this operator in the source code.
Returns#
- yPrimExpr
The result.
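The practical difference shows up on halfway cases. The plain-Python sketch below models llvm.round (C round(): ties move away from zero) and nearbyint under the default floating-point environment (round-to-nearest, ties-to-even), which Python's built-in round() also uses for floats. It assumes the default rounding mode and is only an illustration, not TVM's lowering.

```python
import math


def round_half_away(x: float) -> float:
    """Model of llvm.round / C round(): halfway cases move away from zero."""
    r = math.floor(abs(x))
    if abs(x) - r >= 0.5:  # exact for floats, so no double-rounding surprises
        r += 1
    return math.copysign(r, x)


def nearbyint_default(x: float) -> float:
    """Model of C nearbyint() under the default ties-to-even rounding mode."""
    return float(round(x))  # Python's round() breaks ties to even
```

On 2.5 the two disagree (3.0 versus 2.0), and on -2.5 as well (-3.0 versus -2.0), while non-tie inputs such as 2.4 round identically.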
- tvm.tir.nextafter(x1, x2)
Return the next floating-point value after x1 towards x2.
Parameters#
- x1PrimExpr
Input argument.
- x2PrimExpr
Input argument.
Returns#
- yPrimExpr
The result.
- tvm.tir.popcount(x)
Count the number of set bits in input x.
Parameters#
- xPrimExpr
Input argument.
Returns#
- yPrimExpr
The result.
- tvm.tir.pow(x, y, span=None)
Compute x raised to the power y.
Parameters#
- xPrimExpr
Input argument.
- yPrimExpr
The exponent
- spanOptional[Span]
The location of this operator in the source code.
Returns#
- zPrimExpr
The result.
- tvm.tir.power(x, y, span=None)
Compute x raised to the power y.
Parameters#
- xPrimExpr
Input argument.
- yPrimExpr
The exponent
- spanOptional[Span]
The location of this operator in the source code.
Returns#
- zPrimExpr
The result.
- tvm.tir.ptx_arrive_barrier(barrier_id)
TVM intrinsic for ptx barrier arrival using mbarrier.arrive https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
Parameters#
- barrier_idint
The ID of the barrier shared memory pointer.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.ptx_arrive_barrier_expect_tx(barrier_id, byte_count)
TVM intrinsic for ptx barrier arrival with expect tx using mbarrier.arrive.expect_tx https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-expect-tx-operation
Parameters#
- barrier_idint
The ID of the barrier shared memory pointer.
- byte_countint
Increases the tx count of the mbarrier object to track completion of additional async transactions.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.ptx_commit_group()
TVM intrinsic for ptx async copy commit https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-commit-group
Returns#
- callPrimExpr
The call expression.
- tvm.tir.ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes)
TVM intrinsic for ptx async copy from global to shared memory using cp.async https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async
Parameters#
- dtypestr
The data type of the result.
- shared_ptrVar
The shared memory pointer variable.
- shared_offsetExpr
The offset of shared memory pointer.
- global_ptrVar
The global memory pointer variable.
- global_offsetExpr
The offset of global memory pointer.
- bytesint
The data size to copy.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.ptx_cp_async_barrier(barrier_id)
TVM intrinsic for ptx async copy barrier using cp.async.mbarrier.arrive https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-cp-async-mbarrier-arrive
Parameters#
- barrier_idint
The ID of the barrier shared memory pointer.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.ptx_cp_async_bulk(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes, barrier_id)
TVM intrinsic for ptx async copy from global to shared memory using cp.async.bulk https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk
Parameters#
- dtypestr
The data type of the result.
- shared_ptrVar
The shared memory pointer variable.
- shared_offsetExpr
The offset of shared memory pointer.
- global_ptrVar
The global memory pointer variable.
- global_offsetExpr
The offset of global memory pointer.
- bytesint
The data size to copy.
- barrier_idint
The ID of the barrier shared memory pointer.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.ptx_init_barrier_thread_count(barrier_id, thread_count)
TVM intrinsic for ptx barrier initialization of thread count using mbarrier.init https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init
Parameters#
- barrier_idint
The ID of the barrier shared memory pointer.
- thread_countint
Number of threads expected to arrive at the barrier.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, smem_offset)
TVM intrinsic for ptx load matrix from shared memory https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix
Parameters#
- dtypestr
The data type of the result.
- transbool
Whether the matrix is loaded in column-major format.
- numIntImm
The number of matrices.
- typeLiteral[".b16"]
The data type of the matrices.
- local_ptrVar
The local pointer variable.
- local_offsetExpr
The offset of local pointer.
- smem_ptrVar
The shared memory pointer variable.
- smem_offsetExpr
The offset of shared memory pointer.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.ptx_mma(dtype, shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, multiplicand_a, a_index, multiplicand_b, b_index, accumulator, c_index, saturate, operator=None)
TVM intrinsic for ptx tensor core mma instructions https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma
Parameters#
- dtypestr
The data type of the result.
- shapestr
The shape of mma fragment.
- A_layoutLiteral["row", "col"]
The layout of multiplicand fragment A.
- B_layoutLiteral["row", "col"]
The layout of multiplicand fragment B.
- A_dtypestr
The data type of multiplicand fragment A.
- B_dtypestr
The data type of multiplicand fragment B.
- C_dtypestr
The data type of accumulator fragment C.
- multiplicand_aVar
The multiplicand fragment A variable.
- a_indexExpr
The index of multiplicand fragment A.
- multiplicand_bVar
The multiplicand fragment B variable.
- b_indexExpr
The index of multiplicand fragment B.
- accumulatorVar
The accumulator fragment C variable.
- c_indexExpr
The index of accumulator fragment C.
- saturatebool
The optional saturation at the output.
- operatorOptional[Literal["xor", "and"]]
The 1-bit operator.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.ptx_mma_sp(dtype, shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, multiplicand_a, a_index, multiplicand_b, b_index, accumulator, c_index, metadata, meta_index, sparse_selector, saturate)
TVM intrinsic for sparse tensor core ptx instructions https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-sparse-mma
Parameters#
- dtypestr
The data type of the result.
- shapestr
The shape of mma fragment.
- A_layoutLiteral["row", "col"]
The layout of multiplicand fragment A.
- B_layoutLiteral["row", "col"]
The layout of multiplicand fragment B.
- A_dtypestr
The data type of multiplicand fragment A.
- B_dtypestr
The data type of multiplicand fragment B.
- C_dtypestr
The data type of accumulator fragment C.
- multiplicand_aVar
The multiplicand fragment A variable.
- a_indexExpr
The index of multiplicand fragment A.
- multiplicand_bVar
The multiplicand fragment B variable.
- b_indexExpr
The index of multiplicand fragment B.
- accumulatorVar
The accumulator fragment C variable.
- c_indexExpr
The index of accumulator fragment C.
- metadataExpr
The metadata of the operand.
- meta_indexExpr
The metadata index of operand.
- sparse_selectorExpr
The sparse selector indicating the thread that stores the metadata.
- saturatebool
The optional saturation at the output.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.ptx_wait_barrier(barrier_id)[source]#
TVM intrinsic for ptx barrier wait using mbarrier.try_wait https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-mbarrier-try-wait
Parameters#
- barrier_idint
The ID of the barrier shared memory pointer.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.ptx_wait_group(num)[source]#
TVM intrinsic for ptx async copy wait https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-wait-group
Parameters#
- numint
The maximum number of the most recent cp.async groups that may still be pending when the wait completes; all older groups are complete.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.q_multiply_shift(x, y, q, s)[source]#
Execute a multiplication between two Q-numbers x and y, followed by a right shift s. The mathematical expression is:
out = round(x*y*2^-s)
For more about Q-numbers, see https://en.wikipedia.org/wiki/Q_(number_format). The rounding rule is round to nearest, with ties rounded up (i.e., round(x.1) = x and round(x.5) = x+1).
Parameters#
- xPrimExpr
First Q-number
- yPrimExpr
Second Q-number
- qPrimExpr
Number of fractional bits in x and y. Must be > 0.
- sPrimExpr
Integer shift
Returns#
- yPrimExpr
The result.
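The documented formula and rounding rule can be checked with a small pure-Python model. This is a sketch only: it treats x and y as plain integers and implements just out = round(x*y*2^-s) with ties rounded up. The helper names `round_half_up_shift` and `reference` are hypothetical, and TVM's actual lowering (including how q enters the computation) is not reproduced here.

```python
from fractions import Fraction
import math


def round_half_up_shift(x: int, y: int, s: int) -> int:
    """Hypothetical model of out = round(x * y * 2^-s), ties rounded up.

    Only non-negative right shifts s are modeled.
    """
    if s < 0:
        raise ValueError("this sketch only models non-negative right shifts")
    prod = x * y
    if s == 0:
        return prod
    # Adding half of 2^s before an arithmetic right shift rounds to nearest,
    # with ties going toward +infinity (Python's >> floors negatives as well).
    return (prod + (1 << (s - 1))) >> s


def reference(x: int, y: int, s: int) -> int:
    # Exact rational arithmetic: floor(v + 1/2) is round-half-up.
    return math.floor(Fraction(x * y, 2**s) + Fraction(1, 2))


print(round_half_up_shift(3, 5, 2))   # 15 / 4 = 3.75 -> 4
print(round_half_up_shift(3, 3, 1))   # 9 / 2 = 4.5 -> 5 (tie rounds up)
print(round_half_up_shift(-3, 3, 1))  # -9 / 2 = -4.5 -> -4 (tie rounds up)
```

The integer form (add half, then shift) avoids floating point entirely, which is why fixed-point kernels typically use it.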
- tvm.tir.q_multiply_shift_per_axis(x, y, ls, rs, q, is_lshift_required, is_rshift_required)[source]#
Execute a multiplication between two Q-numbers x and y, with an optional left shift ls and right shift rs.
Parameters#
- xPrimExpr
First Q-number.
- yPrimExpr
Second Q-number.
- lsPrimExpr
Integer left shift.
- rsPrimExpr
Integer right shift.
- qIntImm
Number of fractional bits in x and y. Must be > 0.
- is_lshift_requiredIntImm
Whether a left shift is required.
- is_rshift_requiredIntImm
Whether a right shift is required.
Returns#
- zPrimExpr
The result.
- tvm.tir.reinterpret(dtype, value, span=None)[source]#
Reinterpret the bits of value as the given data type (a bitwise cast, not a value conversion).
Parameters#
- dtypestr
The data type.
- valuePrimExpr
The input value.
- spanOptional[Span]
The location of this operator in the source code.
Returns#
- valuetvm.Expr
The value reinterpreted as dtype.
- Parameters:
span (Span | None)
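To illustrate what a reinterpret (bitwise) cast does, as opposed to a value-converting cast, the following stdlib-only Python sketch reinterprets the bits of a float32 as an int32 and back. It models the semantics only and does not call TVM. The helper names are hypothetical.

```python
import struct


def f32_bits_as_i32(value: float) -> int:
    """Reinterpret the IEEE-754 float32 bit pattern of `value` as an int32."""
    (bits,) = struct.unpack("<i", struct.pack("<f", value))
    return bits


def i32_bits_as_f32(bits: int) -> float:
    """Reinterpret an int32 bit pattern as a float32 value."""
    (value,) = struct.unpack("<f", struct.pack("<i", bits))
    return value


print(f32_bits_as_i32(1.0))         # 1065353216 (0x3F800000); a value cast would give 1
print(f32_bits_as_i32(-2.0))        # -1073741824 (bit pattern 0xC0000000)
print(i32_bits_as_f32(0x3F800000))  # 1.0
```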
- tvm.tir.ret(val, span=None)[source]#
Create a TIR return expression.
Parameters#
- valExpr
The returned TIR expression, whose data type is int, float, or void pointer.
- spanOptional[Span]
The location of this operator in the source code.
Returns#
- retPrimExpr
The return expression
- tvm.tir.round(x, span=None)[source]#
Round elements of the array to the nearest integer.
Parameters#
- xPrimExpr
Input argument.
- spanOptional[Span]
The location of this operator in the source code.
Returns#
- yPrimExpr
The result.
- tvm.tir.rsqrt(x)[source]#
Take the reciprocal of the square root of input x.
Parameters#
- xPrimExpr
Input argument.
Returns#
- yPrimExpr
The result.
- tvm.tir.shift_left(x, y, span=None)[source]#
Return the result of x left shifted by y bits.
Parameters#
- xPrimExpr
Input argument.
- yPrimExpr
Input argument.
Returns#
- zPrimExpr
The result.
- tvm.tir.shift_right(x, y, span=None)[source]#
Return the result of x right shifted by y bits.
Parameters#
- xPrimExpr
Input argument.
- yPrimExpr
Input argument.
Returns#
- zPrimExpr
The result.
- tvm.tir.sigmoid(x)[source]#
Compute the sigmoid of input x, i.e. 1 / (1 + exp(-x)).
Parameters#
- xPrimExpr
Input argument.
Returns#
- yPrimExpr
The result.
- tvm.tir.simdgroup_load(d, index, ptr, stride, col=8, row=8, transpose_matrix=False)[source]#
Load data from device memory or threadgroup memory into a simdgroup matrix.
Parameters#
- dVar
The simdgroup variable.
- indexPrimExpr
The index of the matrix.
- ptrPrimExpr
The pointer.
- stridePrimExpr
The stride.
- colint
The number of columns.
- rowint
The number of rows.
- transpose_matrixbool
Whether to transpose the matrix.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.simdgroup_multiply_accumulate(d, index_d, a, index_a, b, index_b, c, index_c)[source]#
Multiply and accumulate two matrices in a simdgroup, i.e. d = a * b + c.
Parameters#
- dVar
The destination matrix.
- index_dPrimExpr
The index of the destination matrix.
- aVar
The first matrix.
- index_aPrimExpr
The index of the first matrix.
- bVar
The second matrix.
- index_bPrimExpr
The index of the second matrix.
- cVar
The third matrix.
- index_cPrimExpr
The index of the third matrix.
Returns#
- callPrimExpr
The call expression.
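The multiply-accumulate semantics d = a * b + c can be pinned down with a plain-Python reference model on small row-major matrices, which can serve as a test oracle when checking lowered kernels. This sketch ignores simdgroup storage, matrix indices, and the 8x8 fragment layout. The function name `mma_reference` is hypothetical.

```python
from typing import List

Matrix = List[List[float]]


def mma_reference(a: Matrix, b: Matrix, c: Matrix) -> Matrix:
    """Return d = a @ b + c for row-major lists (a: MxK, b: KxN, c: MxN)."""
    m, k, n = len(a), len(b), len(b[0])
    assert all(len(row) == k for row in a), "a must be M x K"
    assert len(c) == m and all(len(row) == n for row in c), "c must be M x N"
    d = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = c[i][j]  # start from the accumulator value
            for p in range(k):
                acc += a[i][p] * b[p][j]
            d[i][j] = acc
    return d


a = [[1.0, 2.0], [3.0, 4.0]]
b = [[5.0, 6.0], [7.0, 8.0]]
c = [[1.0, 1.0], [1.0, 1.0]]
print(mma_reference(a, b, c))  # [[20.0, 23.0], [44.0, 51.0]]
```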
- tvm.tir.simdgroup_store(d, index, ptr, stride, col=8, row=8, transpose_matrix=False)[source]#
Store data from a simdgroup matrix to device memory or threadgroup memory.
Parameters#
- dPrimExpr
The SIMDGroup.
- indexPrimExpr
The index of the matrix.
- ptrPrimExpr
The pointer.
- stridePrimExpr
The stride.
- colint
The number of columns.
- rowint
The number of rows.
- transpose_matrixbool
Whether to transpose the matrix.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.sin(x)[source]#
Take sin of input x.
Parameters#
- xPrimExpr
Input argument.
Returns#
- yPrimExpr
The result.
- tvm.tir.sinh(x)[source]#
Take sinh of input x.
Parameters#
- xPrimExpr
Input argument.
Returns#
- yPrimExpr
The result.
- tvm.tir.sqrt(x)[source]#
Take square root of input x.
Parameters#
- xPrimExpr
Input argument.
Returns#
- yPrimExpr
The result.
- tvm.tir.start_profile_intrinsic(id)[source]#
Start profile intrinsic.
Parameters#
- idint
The intrinsic id.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.stmt_list(stmt)[source]#
Make a list of statements from blocks.
Parameters#
- stmtStmt
The input statement.
Returns#
- stmt_listList[Stmt]
The unpacked list of statements
- tvm.tir.stmt_seq(*args)[source]#
Make a sequence of statements.
Parameters#
- *argsUnion[PrimExpr, Stmt]
List of statements to be combined as sequence.
Returns#
- stmtStmt
The combined statement.
- tvm.tir.subtract(lhs, rhs, span=None)[source]#
Generic subtract operator.
Parameters#
- lhsobject
The left operand.
- rhsobject
The right operand.
- spanOptional[Span]
The location of this operator in the source.
Returns#
- optvm.Expr
The result Expr of the subtract operation.
- tvm.tir.sum(expr, axis, where=None, init=None, *args)#
Create a sum expression over axis.
Parameters#
- exprPrimExpr
The source expression.
- axisIterVar
The reduction IterVar axis
- whereoptional, Expr
Filtering predicate of the reduction.
Returns#
- valuePrimExpr
The result value.
Example#
from tvm import te

m = te.var("m")
n = te.var("n")
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n), name="k")
# There are two ways to use this sum reducer:
# mode 1, accept (expr, axis, where) to produce a Reduce expr.
# tvm.te.sum refers to the same reducer as tvm.tir.sum.
B = te.compute((m,), lambda i: te.sum(A[i, k], axis=k), name="B")
# mode 2, simply use it with multiple Exprs:
sum_res = te.sum(m, n)
- tvm.tir.tan(x)[source]#
Take tan of input x.
Parameters#
- xPrimExpr
Input argument.
Returns#
- yPrimExpr
The result.
- tvm.tir.tanh(x)[source]#
Take the hyperbolic tangent of input x.
Parameters#
- xPrimExpr
Input argument.
Returns#
- yPrimExpr
The result.
- tvm.tir.trace(args, trace_action='tvm.default_trace_action')[source]#
Trace tensor data at runtime.
The trace function allows tracing a specific tensor at runtime. The value to trace should be passed as the last argument. A trace action can be specified; by default, tvm.default_trace_action is used.
Parameters#
- argslist of Expr or Buffer
Positional arguments.
- trace_actionstr
The name of the trace action.
Returns#
- callPrimExpr
The call expression.
See Also#
tvm.tir.call_packed : Creates packed function.
- tvm.tir.trunc(x, span=None)[source]#
Get truncated value of the input.
The truncated value of the scalar x is the nearest integer i which is closer to zero than x is.
Parameters#
- xPrimExpr
Input argument.
- spanOptional[Span]
The location of this operator in the source code.
Returns#
- yPrimExpr
The result.
- tvm.tir.truncdiv(a, b, span=None)[source]#
Compute the truncdiv of two expressions.
Parameters#
- aPrimExpr
The left hand operand
- bPrimExpr
The right hand operand
- spanOptional[Span]
The location of this operator in the source.
Returns#
- resPrimExpr
The result expression.
Note#
This is the default integer division behavior in C.
- tvm.tir.truncmod(a, b, span=None)[source]#
Compute the truncmod of two expressions.
Parameters#
- aPrimExpr
The left hand operand
- bPrimExpr
The right hand operand
- spanOptional[Span]
The location of this operator in the source.
Returns#
- resPrimExpr
The result expression.
Note#
This is the default integer remainder behavior in C (the % operator).
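Truncated division rounds the quotient toward zero, so the truncated remainder takes the sign of the dividend. This differs from floor division (Python's // and %, and tvm.tir.floordiv / tvm.tir.floormod) only when the operands have opposite signs. The stdlib-only Python sketch below models the semantics of truncdiv and truncmod; the helper names are hypothetical and the code does not call TVM.

```python
def truncdiv(a: int, b: int) -> int:
    """Integer division rounding toward zero, as C's `/` does for ints."""
    if b == 0:
        raise ZeroDivisionError("integer division by zero")
    q = abs(a) // abs(b)
    return q if (a >= 0) == (b >= 0) else -q


def truncmod(a: int, b: int) -> int:
    """Remainder matching truncdiv: a == truncdiv(a, b) * b + truncmod(a, b)."""
    return a - truncdiv(a, b) * b


for a, b in [(7, 2), (-7, 2), (7, -2), (-7, -2)]:
    # columns: a, b, truncdiv, truncmod, floor quotient, floor remainder
    print(a, b, truncdiv(a, b), truncmod(a, b), a // b, a % b)
# 7 2 3 1 3 1
# -7 2 -3 -1 -4 1
# 7 -2 -3 1 -4 -1
# -7 -2 3 -1 3 -1
```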
- tvm.tir.tvm_access_ptr(ptype, data, offset, extent, rw_mask)[source]#
Get the head access address with memory access pattern info.
Parameters#
- ptypeExpr
The data type of pointer.
- dataDType*
The data of pointer.
- offsetint
The offset of pointer.
- extentint
The extent of pointer.
- rw_maskint
The read write mask.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.tvm_bmma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c)[source]#
TVM intrinsic for tensor core bmma_sync operators
Parameters#
- fragment_dVar
The bwmma fragment_d.
- index_dExpr
The fragment_d index.
- fragment_aVar
The bwmma fragment_a.
- index_aExpr
The fragment_a index.
- fragment_bVar
The bwmma fragment_b.
- index_bExpr
The fragment_b index.
- fragment_cVar
The bwmma fragment_c.
- index_cExpr
The fragment_c index.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.tvm_check_return(expected, return_unexpected, nested_call)[source]#
Check the return code of a nested call; if it differs from expected, return return_unexpected.
Parameters#
- expectedint
The expected return code.
- return_unexpectedint
The unexpected return code.
- nested_callPrimExpr
The call expression to check return.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.tvm_fill_fragment(fragment, m, n, k, index, value)[source]#
TVM intrinsic for tensor core fill_fragment operators
Parameters#
- fragmentVar
The wmma fragment
- mUIntImm
The shape of wmma fragment.
- nUIntImm
The shape of wmma fragment.
- kUIntImm
The shape of wmma fragment.
- indexExpr
The fragment index.
- valueExpr
The value to be filled in fragment.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout)[source]#
TVM intrinsic for tensor core load operators
Parameters#
- fragmentVar
The wmma fragment.
- mUIntImm
The shape of wmma fragment.
- nUIntImm
The shape of wmma fragment.
- kUIntImm
The shape of wmma fragment.
- indexExpr
The fragment index.
- buffer_ptrExpr
The fragment buffer pointer.
- strideExpr
The fragment stride.
- layoutLiteral["row_major", "column_major"]
The fragment layout.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.tvm_mma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c)[source]#
TVM intrinsic for tensor core mma_sync operators
Parameters#
- fragment_dVar
The wmma fragment_d.
- index_dExpr
The fragment_d index.
- fragment_aVar
The wmma fragment_a.
- index_aExpr
The fragment_a index.
- fragment_bVar
The wmma fragment_b.
- index_bExpr
The fragment_b index.
- fragment_cVar
The wmma fragment_c.
- index_cExpr
The fragment_c index.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.tvm_stack_alloca(dtype_str, num)[source]#
Allocate a new dtype[num] array on the stack and return it.
Parameters#
- dtype_strstr
The data type of array.
- numint
The size of array.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.tvm_stack_make_array(data, shape, strides, ndim, arr_dtype, elem_offset)[source]#
Allocate an NDArray (DLTensor) on the stack and return the handle.
Parameters#
- dataExpr
The data of array.
- shapeExpr
The shape of array.
- stridesExpr
The strides of array.
- ndimExpr
The dimensions of array.
- arr_dtypeExpr
The data type of array.
- elem_offsetExpr
The element offset of array.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.tvm_stack_make_shape(*args)[source]#
Allocate a shape tuple on the stack and return the handle.
Parameters#
- argsint
The tuple shape.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.tvm_store_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout)[source]#
TVM intrinsic for tensor core store operators
Parameters#
- fragmentVar
The wmma fragment.
- mUIntImm
The shape of wmma fragment.
- nUIntImm
The shape of wmma fragment.
- kUIntImm
The shape of wmma fragment.
- indexExpr
The fragment index.
- buffer_ptrExpr
The fragment buffer pointer.
- strideExpr
The fragment stride.
- layoutLiteral["row_major", "column_major"]
The fragment layout.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.tvm_struct_get(arr, index, field, dtype)[source]#
Get the value of a struct field in an array.
Parameters#
- dtypestr
The data type of the result.
- arrStructType*
The array of struct.
- indexint
The index of struct.
- fieldint
The field of struct.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.tvm_struct_set(arr, index, field, value)[source]#
Set the value of a struct field in an array.
Parameters#
- arrStructType*
The array of struct.
- indexint
The index of struct.
- fieldint
The field of struct.
- valueExpr
The value to be set in field.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.tvm_thread_allreduce(*freduce_args)[source]#
Perform an allreduce inside a thread block.
Parameters#
- freduce_argsExpr
The args.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.tvm_throw_last_error()[source]#
Throw TVMGetLastError()
Returns#
- retPrimExpr
The return expression
- tvm.tir.tvm_tuple(*value)[source]#
Create a tuple structure in the value field of an AttrStmt.
Parameters#
- valueExpr
The value in tuple.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.type_annotation(dtype)[source]#
Create a type annotation expression
Parameters#
- dtypeExpr
The data type.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.undef()[source]#
Returns an initialized but arbitrary value
Returns#
- callPrimExpr
The call expression.
- tvm.tir.vectorcombine(dtype, vec1, vec2)[source]#
Concatenate two vectors.
Parameters#
- dtypestr
The data type of the result.
- vec1list
The input vector.
- vec2list
The input vector.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.vectorhigh(dtype, vec)[source]#
Get the high half of the vector.
Parameters#
- dtypestr
The data type of the result.
- veclist
The input vector.
Returns#
- callPrimExpr
The call expression.
- tvm.tir.vectorlow(dtype, vec)[source]#
Get the low half of the vector.
Parameters#
- dtypestr
The data type of the result.
- veclist
The input vector.
Returns#
- callPrimExpr
The call expression.
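The lane semantics of vectorlow, vectorhigh, and vectorcombine can be sketched with Python lists. This model assumes that lane 0 is the lowest lane, that the "low" half is the first n/2 lanes and the "high" half the last n/2, and that the vector has an even lane count; these are assumptions about the lowering, not statements from this reference. The helper names are hypothetical.

```python
from typing import List, Sequence


def vector_low(vec: Sequence[int]) -> List[int]:
    """Lower half of the lanes (lanes 0 .. n/2 - 1), under the stated assumption."""
    assert len(vec) % 2 == 0, "sketch assumes an even lane count"
    return list(vec[: len(vec) // 2])


def vector_high(vec: Sequence[int]) -> List[int]:
    """Upper half of the lanes (lanes n/2 .. n - 1), under the stated assumption."""
    assert len(vec) % 2 == 0, "sketch assumes an even lane count"
    return list(vec[len(vec) // 2 :])


def vector_combine(vec1: Sequence[int], vec2: Sequence[int]) -> List[int]:
    """Concatenate vec1 (low lanes) followed by vec2 (high lanes)."""
    return list(vec1) + list(vec2)


v = [10, 11, 12, 13, 14, 15, 16, 17]
print(vector_low(v))   # [10, 11, 12, 13]
print(vector_high(v))  # [14, 15, 16, 17]
print(vector_combine(vector_low(v), vector_high(v)) == v)  # True
```

Under this model, vectorcombine is the inverse of splitting a vector with vectorlow and vectorhigh.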
- tvm.tir.vscale()[source]#
Get the target's vscale value. It is lowered to the llvm.vscale intrinsic (https://llvm.org/docs/LangRef.html#llvm-vscale-intrinsic).
Returns#
- callPrimExpr
Call to the vscale intrinsic.