# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=redefined-builtin, invalid-name, too-many-arguments
"""Operators used in TIR expression."""
from typing import Any, Optional, Union
import tvm._ffi
from tvm import tir
from tvm.ir import Array, Op, PrimExpr
from tvm.ir.base import Span
from tvm.runtime import const
from . import _ffi_api
from .buffer import Buffer
from .expr import Call, CommReducer, IntImm, PrimExprWithOp, Var
def _pack_buffer(buf, span=None):
    """Build the intrinsic call that packs a buffer into an array handle.

    Parameters
    ----------
    buf : Buffer
        The buffer whose data, shape, strides, dtype and element offset
        are packed.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        A ``tir.tvm_stack_make_array`` call of type ``handle``.
    """
    shape_handle = Call("handle", "tir.tvm_stack_make_shape", buf.shape, span)
    if buf.strides:
        strides_handle = Call("handle", "tir.tvm_stack_make_shape", buf.strides, span)
    else:
        # The buffer carries no strides: a literal 0 fills the strides slot.
        strides_handle = 0
    array_fields = [
        buf.data,
        shape_handle,
        strides_handle,
        len(buf.shape),
        const(0, dtype=buf.dtype),
        buf.elem_offset,
    ]
    return Call("handle", Op.get("tir.tvm_stack_make_array"), array_fields, span)
[文档]
def call_packed_lowered(*args, span=None):
    """Lowered version of call packed.

    The argument to packed function can be Expr or Buffer.
    The argument is the corresponding POD type when Expr is presented.
    When the argument is Buffer, the corresponding PackedFunc
    will receive a TVMArrayHandle whose content is valid during the callback period.
    If the PackedFunc is a python callback, then the corresponding argument is NDArray.

    Parameters
    ----------
    args : list of Expr or Buffer.
        Positional arguments.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.

    See Also
    --------
    te.extern : Create tensor with extern function call.
    """
    # Forward the span so that the buffer-packing intrinsics synthesized for
    # Buffer arguments carry the same source location as the enclosing call.
    call_args = [_pack_buffer(x, span) if isinstance(x, Buffer) else x for x in args]
    return Call("int32", Op.get("tir.tvm_call_packed_lowered"), call_args, span)
[文档]
def call_cpacked_lowered(*args, span=None):
    """Lowered version of call c-packed.

    Same as call_packed, except that the first argument is the function name
    (as in call_extern), and the last argument is the resource handle.

    Parameters
    ----------
    args : list of Expr or Buffer.
        Positional arguments.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.

    See Also
    --------
    te.extern : Create tensor with extern function call.
    """
    # Forward the span so that the buffer-packing intrinsics synthesized for
    # Buffer arguments carry the same source location as the enclosing call.
    call_args = [_pack_buffer(x, span) if isinstance(x, Buffer) else x for x in args]
    return Call("int32", Op.get("tir.tvm_call_cpacked_lowered"), call_args, span)
[文档]
def call_packed(*args, span=None):
    """Build expression by call an external packed function.

    The argument to packed function can be Expr or Buffer.
    The argument is the corresponding POD type when Expr is presented.
    When the argument is Buffer, the corresponding PackedFunc
    will receive a TVMArrayHandle whose content is valid during the callback period.
    If the PackedFunc is a python callback, then the corresponding argument is NDArray.

    Parameters
    ----------
    args : list of Expr or Buffer.
        Positional arguments.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.

    See Also
    --------
    te.extern : Create tensor with extern function call.
    """
    # Forward the span so that the buffer-packing intrinsics synthesized for
    # Buffer arguments carry the same source location as the enclosing call.
    call_args = [_pack_buffer(x, span) if isinstance(x, Buffer) else x for x in args]
    return Call("int32", Op.get("tir.tvm_call_packed"), call_args, span)
[文档]
def call_cpacked(*args, span=None):
    """Build expression by call an external packed function.

    Same as call_packed, except that the first argument is the function name
    (as in call_extern), and the last argument is the resource handle.

    Parameters
    ----------
    args : list of Expr or Buffer.
        Positional arguments.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.

    See Also
    --------
    te.extern : Create tensor with extern function call.
    """
    # Forward the span so that the buffer-packing intrinsics synthesized for
    # Buffer arguments carry the same source location as the enclosing call.
    call_args = [_pack_buffer(x, span) if isinstance(x, Buffer) else x for x in args]
    return Call("int32", Op.get("tir.tvm_call_cpacked"), call_args, span)
[文档]
def call_intrin(dtype, func_name, *args, span=None):
    """Build expression by calling an intrinsic function.

    Intrinsics can be overloaded with multiple data types via
    the intrinsic translation rule.

    Parameters
    ----------
    dtype : str
        The data type of the result.

    func_name: str
        The intrinsic function name.

    args : list
        Positional arguments.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    # The positional arguments are handed to Call as the collected tuple.
    intrin_call = Call(dtype, func_name, args, span=span)
    return intrin_call
[文档]
def call_pure_extern(dtype, func_name, *args, span=None):
    """Build expression by calling a pure extern function.

    Parameters
    ----------
    dtype : str
        The data type of the result.

    func_name: str
        The extern function name.

    args : list
        Positional arguments.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    extern_op = Op.get("tir.call_pure_extern")
    # The extern function name travels as the first call argument.
    extern_args = [func_name]
    extern_args.extend(args)
    return Call(dtype, extern_op, extern_args, span=span)
[文档]
def call_extern(dtype, func_name, *args, span=None):
    """Build expression by calling an extern function.

    Parameters
    ----------
    dtype : str
        The data type of the result.

    func_name: str
        The extern function name.

    args : list
        Positional arguments.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    extern_op = Op.get("tir.call_extern")
    # The extern function name travels as the first call argument.
    extern_args = [func_name]
    extern_args.extend(args)
    return Call(dtype, extern_op, extern_args, span=span)
[文档]
def call_llvm_intrin(dtype, name, *args, span=None):
    """Build expression by calling a llvm intrinsic function

    Parameters
    ----------
    dtype : str
        The data type of the result.

    name : str
        The name of the llvm intrinsic function. An ``IntImm`` or any other
        value is used as the intrinsic id directly instead of being looked up.

    args : list
        Positional arguments.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.

    Raises
    ------
    ValueError
        If the resolved intrinsic id equals 0.
    """
    # pylint: disable=import-outside-toplevel
    from tvm.target import codegen

    # Default: treat ``name`` itself as the id; refine for str / IntImm inputs.
    intrin_id = name
    if isinstance(name, str):
        intrin_id = codegen.llvm_lookup_intrinsic_id(name)
    elif isinstance(name, IntImm):
        intrin_id = name.value

    if intrin_id == 0:
        raise ValueError(f"Unknown llvm intrinsic function {name}")

    intrin_op = Op.get("tir.call_llvm_intrin")
    id_const = tvm.tir.const(intrin_id, "uint32")
    return call_intrin(dtype, intrin_op, id_const, *args, span=span)
[文档]
def call_llvm_pure_intrin(dtype, name, *args, span=None):
    """Build expression by calling a pure llvm intrinsic function

    Parameters
    ----------
    dtype : str
        The data type of the result.

    name : str
        The name of the llvm intrinsic function. An ``IntImm`` or any other
        value is used as the intrinsic id directly instead of being looked up.

    args : list
        Positional arguments.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.

    Raises
    ------
    ValueError
        If the resolved intrinsic id equals 0.
    """
    # pylint: disable=import-outside-toplevel
    from tvm.target import codegen

    # Default: treat ``name`` itself as the id; refine for str / IntImm inputs.
    intrin_id = name
    if isinstance(name, str):
        intrin_id = codegen.llvm_lookup_intrinsic_id(name)
    elif isinstance(name, IntImm):
        intrin_id = name.value

    if intrin_id == 0:
        raise ValueError(f"Unknown llvm intrinsic function {name}")

    intrin_op = Op.get("tir.call_llvm_pure_intrin")
    id_const = tvm.tir.const(intrin_id, "uint32")
    return call_intrin(dtype, intrin_op, id_const, *args, span=span)
[文档]
def tvm_check_return(expected, return_unexpected, nested_call):
    """Build a ``tir.tvm_check_return`` call around a nested call

    Parameters
    ----------
    expected : int
        The expected return code.

    return_unexpected : int
        The unexpected return code.

    nested_call : PrimExpr
        The call expression to check return.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    check_args = (expected, return_unexpected, nested_call)
    return call_intrin("int32", "tir.tvm_check_return", *check_args)
[文档]
def tvm_stack_alloca(dtype_str, num):
    """Return new on stack dtype[num]

    Parameters
    ----------
    dtype_str : str
        The data type of array.

    num : int
        The size of array.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    alloca_args = (dtype_str, num)
    return call_intrin("handle", "tir.tvm_stack_alloca", *alloca_args)
[文档]
def tvm_stack_make_shape(*args):
    """Allocate a shape tuple on stack, return the handle

    Parameters
    ----------
    args : int
        The tuple shape.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    shape_call = call_intrin("handle", "tir.tvm_stack_make_shape", *args)
    return shape_call
[文档]
def tvm_stack_make_array(data, shape, strides, ndim, arr_dtype, elem_offset):
    """Allocate a NDArray(DLTensor) on stack, return the handle

    Parameters
    ----------
    data : Expr
        The data of array.

    shape : Expr
        The shape of array.

    strides : Expr
        The strides of array.

    ndim : Expr
        The dimensions of array.

    arr_dtype : Expr
        The data type of array.

    elem_offset : Expr
        The element offset of array.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    array_fields = (data, shape, strides, ndim, arr_dtype, elem_offset)
    return call_intrin("handle", "tir.tvm_stack_make_array", *array_fields)
[文档]
def assume(cond=None):
    """Provide a true statement that can be used for simplifications

    Parameters
    ----------
    cond : Expr
        The constraint condition.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    assume_call = call_intrin("bool", "tir.assume", cond)
    return assume_call
[文档]
def undef():
    """Returns an initialized but arbitrary value

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    undef_call = call_intrin("int32", "tir.undef")
    return undef_call
[文档]
def call_tir(global_var: tvm.ir.GlobalVar, *args):
    """Performs a call into another PrimFunc in the same IRModule

    Parameters
    ----------
    global_var : tvm.ir.GlobalVar
        The global variable referring to the callee.

    args : list
        Positional arguments passed to the callee.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    assert isinstance(global_var, tvm.ir.GlobalVar)

    # The result type is "void" unless the callee's checked type exposes a
    # return type that carries a ``dtype`` attribute.
    result_dtype = "void"
    func_type = global_var.checked_type
    if func_type is not None:
        result_dtype = getattr(func_type.ret_type, "dtype", "void")
    return Call(dtype=result_dtype, op=global_var, args=args)
[文档]
def start_profile_intrinsic(id):
    """Start profile intrinsic.

    Parameters
    ----------
    id : int
        The intrinsic id.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    start_call = call_intrin("handle", "tir.start_profile_intrinsic", id)
    return start_call
[文档]
def end_profile_intrinsic(id):
    """End profile intrinsic.

    Parameters
    ----------
    id : int
        The intrinsic id.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    end_call = call_intrin("handle", "tir.end_profile_intrinsic", id)
    return end_call
[文档]
def tvm_tuple(*value):
    """Create a tuple structure in value field of AttrStmt

    Parameters
    ----------
    value : Expr
        The value in tuple.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    tuple_call = call_intrin("handle", "tir.tvm_tuple", *value)
    return tuple_call
[文档]
def tvm_struct_get(arr, index, field, dtype):
    """Get struct field value in array

    Parameters
    ----------
    arr : StructType*
        The array of struct.

    index : int
        The index of struct.

    field : int
        The field of struct.

    dtype : str
        The data type of the result.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    # Note: ``dtype`` is the last parameter here but the first intrinsic operand.
    location = (arr, index, field)
    return call_intrin(dtype, "tir.tvm_struct_get", *location)
[文档]
def tvm_struct_set(arr, index, field, value):
    """Set value in struct field in array

    Parameters
    ----------
    arr : StructType*
        The array of struct.

    index : int
        The index of struct.

    field : int
        The field of struct.

    value : Expr
        The value to be set in field.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    set_args = (arr, index, field, value)
    return call_intrin("int32", "tir.tvm_struct_set", *set_args)
[文档]
def address_of(buffer_load, span=None):
    """Returns the address of an element in the buffer

    Parameters
    ----------
    buffer_load: BufferLoad
        The buffer load.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    address_call = call_intrin("handle", "tir.address_of", buffer_load, span=span)
    return address_call
[文档]
def lookup_param(param_name, span=None):
    """Returns the param by name

    Parameters
    ----------
    param_name : str
        The name of param.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    lookup_call = call_intrin("handle", "tir.lookup_param", param_name, span=span)
    return lookup_call
[文档]
def tvm_thread_allreduce(*freduce_args):
    """Perform allreduce inside threadblock.

    Parameters
    ----------
    freduce_args : Expr
        The args.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    allreduce_call = call_intrin("handle", "tir.tvm_thread_allreduce", *freduce_args)
    return allreduce_call
def tvm_thread_invariant(cond):
    """Mark condition as thread invariant.

    Parameters
    ----------
    cond : Expr
        The condition. Must be a ``PrimExpr``.

    Returns
    -------
    call : PrimExpr
        The call expression, typed with the data type of ``cond``.
    """
    assert isinstance(cond, PrimExpr)
    cond_dtype = cond.dtype
    return call_intrin(cond_dtype, "tir.tvm_thread_invariant", cond)
def tvm_storage_sync(storage_scope):
    """Perform synchronization in specified scope.

    Parameters
    ----------
    storage_scope : str
        The storage scope to perform synchronization.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    sync_call = call_intrin("int32", "tir.tvm_storage_sync", storage_scope)
    return sync_call
def tvm_warp_shuffle(mask, value, warp_id, width, warp_size):
    """Exchange value between threads inside a warp.

    Parameters
    ----------
    mask : PrimExpr
        The warp mask indicates active threads inside warp.

    value : PrimExpr
        The value to exchange.

    warp_id : PrimExpr
        The source lane index to fetch value.

    width : PrimExpr
        The width of sub-sections to perform warp shuffle.

    warp_size : PrimExpr
        The warp size.

    Returns
    -------
    call : PrimExpr
        The call expression, typed with the data type of ``value``.
    """
    result_dtype = value.dtype
    shuffle_args = (mask, value, warp_id, width, warp_size)
    return call_intrin(result_dtype, "tir.tvm_warp_shuffle", *shuffle_args)
def tvm_warp_shuffle_up(mask, value, offset, width, warp_size):
    """Copy value from a lane with lower (by offset) index relative to caller.

    Parameters
    ----------
    mask : PrimExpr
        The warp mask indicates active threads inside warp.

    value : PrimExpr
        The value to exchange.

    offset : PrimExpr
        The difference between source lane index and destination lane index:
        `offset = dst_lane_idx - src_lane_idx`

    width : PrimExpr
        The width of sub-sections to perform warp shuffle.

    warp_size : PrimExpr
        The warp size.

    Returns
    -------
    call : PrimExpr
        The call expression, typed with the data type of ``value``.
    """
    result_dtype = value.dtype
    shuffle_args = (mask, value, offset, width, warp_size)
    return call_intrin(result_dtype, "tir.tvm_warp_shuffle_up", *shuffle_args)
def tvm_warp_shuffle_down(mask, value, offset, width, warp_size):
    """Copy value from a lane with higher (by offset) index relative to caller.

    Parameters
    ----------
    mask : PrimExpr
        The warp mask indicates active threads inside warp.

    value : PrimExpr
        The value to exchange.

    offset : PrimExpr
        The difference between source lane index and destination lane index:
        `offset = src_lane_idx - dst_lane_idx`

    width : PrimExpr
        The width of sub-sections to perform warp shuffle.

    warp_size : PrimExpr
        The warp size.

    Returns
    -------
    call : PrimExpr
        The call expression, typed with the data type of ``value``.
    """
    result_dtype = value.dtype
    shuffle_args = (mask, value, offset, width, warp_size)
    return call_intrin(result_dtype, "tir.tvm_warp_shuffle_down", *shuffle_args)
def tvm_warp_activemask():
    """Return a 32-bit mask indicating currently active threads in a calling warp.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    mask_call = call_intrin("uint32", "tir.tvm_warp_activemask")
    return mask_call
[文档]
def type_annotation(dtype):
    """Create a type annotation expression

    Parameters
    ----------
    dtype : Expr
        The data type; it is used as the result data type of the call.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    annotation_call = call_intrin(dtype, "tir.type_annotation")
    return annotation_call
[文档]
def tvm_access_ptr(ptype, data, offset, extent, rw_mask):
    """Get head access address with memory access pattern info

    Parameters
    ----------
    ptype : Expr
        The data type of pointer.

    data : DType*
        The data of pointer.

    offset : int
        The offset of pointer.

    extent : int
        The extent of pointer.

    rw_mask : int
        The read write mask.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    access_args = (ptype, data, offset, extent, rw_mask)
    return call_intrin("handle", "tir.tvm_access_ptr", *access_args)
[文档]
def tvm_throw_last_error():
    """Throw TVMGetLastError()

    Returns
    -------
    ret : PrimExpr
        The return expression
    """
    throw_call = call_intrin("handle", "tir.tvm_throw_last_error")
    return throw_call
[文档]
def tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout):
    """TVM intrinsic for tensor core load operators

    Parameters
    ----------
    fragment : Var
        The wmma fragment.

    m : UIntImm
        The shape of wmma fragment.

    n : UIntImm
        The shape of wmma fragment.

    k : UIntImm
        The shape of wmma fragment.

    index : Expr
        The fragment index.

    buffer_ptr : Expr
        The fragment buffer pointer.

    stride : Expr
        The fragment stride.

    layout : Literal["row_major", "column_major"]
        The fragment layout.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    load_args = (fragment, m, n, k, index, buffer_ptr, stride, layout)
    return call_intrin("handle", "tir.tvm_load_matrix_sync", *load_args)
[文档]
def tvm_mma_sync(
    fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c
):
    """TVM intrinsic for tensor core mma_sync operators

    Parameters
    ----------
    fragment_d : Var
        The wmma fragment_d.

    index_d : Expr
        The fragment_d index.

    fragment_a : Var
        The wmma fragment_a.

    index_a : Expr
        The fragment_a index.

    fragment_b : Var
        The wmma fragment_b.

    index_b : Expr
        The fragment_b index.

    fragment_c : Var
        The wmma fragment_c.

    index_c : Expr
        The fragment_c index.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    # Operands are passed as (fragment, index) pairs in d, a, b, c order.
    operands = (
        fragment_d,
        index_d,
        fragment_a,
        index_a,
        fragment_b,
        index_b,
        fragment_c,
        index_c,
    )
    return call_intrin("handle", "tir.tvm_mma_sync", *operands)
[文档]
def tvm_bmma_sync(
    fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c
):
    """TVM intrinsic for tensor core bmma_sync operators

    Parameters
    ----------
    fragment_d : Var
        The bwmma fragment_d.

    index_d : Expr
        The fragment_d index.

    fragment_a : Var
        The bwmma fragment_a.

    index_a : Expr
        The fragment_a index.

    fragment_b : Var
        The bwmma fragment_b.

    index_b : Expr
        The fragment_b index.

    fragment_c : Var
        The bwmma fragment_c.

    index_c : Expr
        The fragment_c index.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    # Operands are passed as (fragment, index) pairs in d, a, b, c order.
    operands = (
        fragment_d,
        index_d,
        fragment_a,
        index_a,
        fragment_b,
        index_b,
        fragment_c,
        index_c,
    )
    return call_intrin("handle", "tir.tvm_bmma_sync", *operands)
[文档]
def tvm_fill_fragment(fragment, m, n, k, index, value):
    """TVM intrinsic for tensor core fill_fragment operators

    Parameters
    ----------
    fragment : Var
        The wmma fragment

    m : UIntImm
        The shape of wmma fragment.

    n : UIntImm
        The shape of wmma fragment.

    k : UIntImm
        The shape of wmma fragment.

    index : Expr
        The fragment index.

    value : Expr
        The value to be filled in fragment.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    fill_args = (fragment, m, n, k, index, value)
    return call_intrin("handle", "tir.tvm_fill_fragment", *fill_args)
[文档]
def tvm_store_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout):
    """TVM intrinsic for tensor core store operators

    Parameters
    ----------
    fragment : Var
        The wmma fragment.

    m : UIntImm
        The shape of wmma fragment.

    n : UIntImm
        The shape of wmma fragment.

    k : UIntImm
        The shape of wmma fragment.

    index : Expr
        The fragment index.

    buffer_ptr : Expr
        The fragment buffer pointer.

    stride : Expr
        The fragment stride.

    layout : Literal["row_major", "column_major"]
        The fragment layout.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    store_args = (fragment, m, n, k, index, buffer_ptr, stride, layout)
    return call_intrin("handle", "tir.tvm_store_matrix_sync", *store_args)
[文档]
def ptx_mma(
    dtype,
    shape,
    A_layout,
    B_layout,
    A_dtype,
    B_dtype,
    C_dtype,
    multiplicand_a,
    a_index,
    multiplicand_b,
    b_index,
    accumulator,
    c_index,
    saturate,
    operator=None,
):
    """TVM intrinsic for ptx tensor core mma instructions
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma

    Parameters
    ----------
    dtype : str
        The data type of the result.

    shape : str
        The shape of mma fragment.

    A_layout : Literal["row", "col"]
        The layout of multiplicand fragment A.

    B_layout : Literal["row", "col"]
        The layout of multiplicand fragment B.

    A_dtype : str
        The data type of multiplicand fragment A.

    B_dtype : str
        The data type of multiplicand fragment B.

    C_dtype : str
        The data type of accumulator fragment C.

    multiplicand_a : Var
        The multiplicand fragment A variable.

    a_index : Expr
        The index of multiplicand fragment A.

    multiplicand_b : Var
        The multiplicand fragment B variable.

    b_index : Expr
        The index of multiplicand fragment B.

    accumulator : Var
        The accumulator fragment C variable.

    c_index : Expr
        The index of accumulator fragment C.

    saturate : bool
        The optional saturation at the output.

    operator : Optional[Literal["xor", "and"]]
        The 1-bit operator.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    mma_args = [
        shape,
        A_layout,
        B_layout,
        A_dtype,
        B_dtype,
        C_dtype,
        multiplicand_a,
        a_index,
        multiplicand_b,
        b_index,
        accumulator,
        c_index,
        saturate,
    ]
    # The 1-bit operator is only appended as a trailing operand when supplied.
    if operator is not None:
        mma_args.append(operator)
    return call_intrin(dtype, "tir.ptx_mma", *mma_args)
[文档]
def ptx_mma_sp(
    dtype,
    shape,
    A_layout,
    B_layout,
    A_dtype,
    B_dtype,
    C_dtype,
    multiplicand_a,
    a_index,
    multiplicand_b,
    b_index,
    accumulator,
    c_index,
    metadata,
    meta_index,
    sparse_selector,
    saturate,
):
    """TVM intrinsic for sparse tensor core ptx instructions
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-sparse-mma

    Parameters
    ----------
    dtype : str
        The data type of the result.

    shape : str
        The shape of mma fragment.

    A_layout : Literal["row", "col"]
        The layout of multiplicand fragment A.

    B_layout : Literal["row", "col"]
        The layout of multiplicand fragment B.

    A_dtype : str
        The data type of multiplicand fragment A.

    B_dtype : str
        The data type of multiplicand fragment B.

    C_dtype : str
        The data type of accumulator fragment C.

    multiplicand_a : Var
        The multiplicand fragment A variable.

    a_index : Expr
        The index of multiplicand fragment A.

    multiplicand_b : Var
        The multiplicand fragment B variable.

    b_index : Expr
        The index of multiplicand fragment B.

    accumulator : Var
        The accumulator fragment C variable.

    c_index : Expr
        The index of accumulator fragment C.

    metadata : Expr
        The metadata of operand.

    meta_index : Expr
        The metadata index of operand.

    sparse_selector : Expr
        The sparse selector indicating the thread that stores the metadata.

    saturate : bool
        The optional saturation at the output.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    mma_sp_args = (
        shape,
        A_layout,
        B_layout,
        A_dtype,
        B_dtype,
        C_dtype,
        multiplicand_a,
        a_index,
        multiplicand_b,
        b_index,
        accumulator,
        c_index,
        metadata,
        meta_index,
        sparse_selector,
        saturate,
    )
    return call_intrin(dtype, "tir.ptx_mma_sp", *mma_sp_args)
[文档]
def mma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride):
    """TVM intrinsic for storing the result of PTX MMA into a destination pointer

    Parameters
    ----------
    dtype : str
        The data type of the result.

    m : IntImm
        The shape of mma fragment.

    n : IntImm
        The shape of mma fragment.

    dst_ptr : Var
        The destination pointer variable.

    src_ptr : Var
        The source pointer variable.

    src_offset : Expr
        The source offset.

    dst_stride : Var
        The destination stride.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    store_args = (m, n, dst_ptr, src_ptr, src_offset, dst_stride)
    return call_intrin(dtype, "tir.mma_store", *store_args)
[文档]
def mma_fill(dtype, local_size, local_ptr, offset):
    """TVM intrinsic for zero-initializing an MMA accumulation register

    Parameters
    ----------
    dtype : str
        The data type of the result.

    local_size : IntImm
        The number of elements.

    local_ptr : Var
        The destination pointer variable.

    offset : Expr
        The destination offset.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    fill_args = (local_size, local_ptr, offset)
    return call_intrin(dtype, "tir.mma_fill", *fill_args)
[文档]
def ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, smem_offset):
    """TVM intrinsic for ptx load matrix from shared memory
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix

    Parameters
    ----------
    dtype : str
        The data type of the result.

    trans : bool
        The matrix is loaded in column-major format.

    num : IntImm
        The number of matrices.

    type : Literal[".b16"]
        The data type of the matrices.

    local_ptr : Var
        The local pointer variable.

    local_offset : Expr
        The offset of local pointer.

    smem_ptr : Var
        The shared memory pointer variable.

    smem_offset : Expr
        The offset of shared memory pointer.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    ldmatrix_args = (trans, num, type, local_ptr, local_offset, smem_ptr, smem_offset)
    return call_intrin(dtype, "tir.ptx_ldmatrix", *ldmatrix_args)
[文档]
def ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes):
    """TVM intrinsic for ptx async copy from global to shared memory using cp.async
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async

    Parameters
    ----------
    dtype : str
        The data type of the result.

    shared_ptr : Var
        The shared memory pointer variable.

    shared_offset : Expr
        The offset of shared memory pointer.

    global_ptr : Var
        The global memory pointer variable.

    global_offset : Expr
        The offset of global memory pointer.

    bytes : int
        The data size to copy.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    copy_args = (shared_ptr, shared_offset, global_ptr, global_offset, bytes)
    return call_intrin(dtype, "tir.ptx_cp_async", *copy_args)
[文档]
def ptx_cp_async_bulk(
    dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes, barrier_id
):
    """TVM intrinsic for ptx async copy from global to shared memory using cp.async.bulk
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk

    Parameters
    ----------
    dtype : str
        The data type of the result.

    shared_ptr : Var
        The shared memory pointer variable.

    shared_offset : Expr
        The offset of shared memory pointer.

    global_ptr : Var
        The global memory pointer variable.

    global_offset : Expr
        The offset of global memory pointer.

    bytes : int
        The data size to copy.

    barrier_id : int
        The ID of the barrier shared memory pointer.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    copy_args = (shared_ptr, shared_offset, global_ptr, global_offset, bytes, barrier_id)
    return call_intrin(dtype, "tir.ptx_cp_async_bulk", *copy_args)
[文档]
def ptx_commit_group():
    """TVM intrinsic for ptx async copy commit
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-commit-group

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    # The result dtype is passed as an empty string.
    commit_call = call_intrin("", "tir.ptx_commit_group")
    return commit_call
[文档]
def ptx_wait_group(num):
    """TVM intrinsic for ptx async copy wait
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-wait-group

    Parameters
    ----------
    num : int
        The number of the most recent uncommitted pending cp.async groups to wait.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    # The result dtype is passed as an empty string.
    wait_call = call_intrin("", "tir.ptx_wait_group", num)
    return wait_call
[文档]
def ptx_cp_async_barrier(barrier_id):
    """TVM intrinsic for ptx async copy barrier using cp.async.mbarrier.arrive
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-cp-async-mbarrier-arrive

    Parameters
    ----------
    barrier_id : int
        The ID of the barrier shared memory pointer.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    # The result dtype is passed as an empty string.
    barrier_call = call_intrin("", "tir.ptx_cp_async_barrier", barrier_id)
    return barrier_call
[文档]
def ptx_init_barrier_thread_count(barrier_id, thread_count):
    """TVM intrinsic for ptx barrier initialization of thread count using mbarrier.init
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init

    Parameters
    ----------
    barrier_id : int
        The ID of the barrier shared memory pointer.

    thread_count : int
        Number of threads expected to arrive at the barrier.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    # The result dtype is passed as an empty string.
    init_args = (barrier_id, thread_count)
    return call_intrin("", "tir.ptx_init_barrier_thread_count", *init_args)
[文档]
def ptx_arrive_barrier(barrier_id):
    """TVM intrinsic for ptx barrier arrival using mbarrier.arrive
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive

    Parameters
    ----------
    barrier_id : int
        The ID of the barrier shared memory pointer.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    # The result dtype is passed as an empty string.
    arrive_call = call_intrin("", "tir.ptx_arrive_barrier", barrier_id)
    return arrive_call
[文档]
def ptx_arrive_barrier_expect_tx(barrier_id, byte_count):
    """TVM intrinsic for ptx barrier arrival with expect tx using mbarrier.arrive.expect_tx
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-expect-tx-operation

    Parameters
    ----------
    barrier_id : int
        The ID of the barrier shared memory pointer.

    byte_count : int
        Increases the tx count of the mbarrier object to track completion of
        additional async transactions.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    # The result dtype is passed as an empty string.
    arrive_args = (barrier_id, byte_count)
    return call_intrin("", "tir.ptx_arrive_barrier_expect_tx", *arrive_args)
[文档]
def ptx_wait_barrier(barrier_id):
    """TVM intrinsic for ptx barrier wait using mbarrier.try_wait
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-mbarrier-try-wait

    Parameters
    ----------
    barrier_id : int
        The ID of the barrier shared memory pointer.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    # The result dtype is passed as an empty string.
    wait_call = call_intrin("", "tir.ptx_wait_barrier", barrier_id)
    return wait_call
[文档]
def create_barriers(barrier_count):
    """TVM intrinsic to create N barriers

    Parameters
    ----------
    barrier_count : int
        The number of barriers to create.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    # The result dtype is passed as an empty string.
    create_call = call_intrin("", "tir.create_barriers", barrier_count)
    return create_call
[文档]
def make_filled_simdgroup_matrix(
    d: Var,
    index: PrimExpr,
    value: PrimExpr,
    col: int = 8,
    row: int = 8,
):
    """Create a filled SIMDGroup matrix

    Parameters
    ----------
    d : var
        The simdgroup var

    index : PrimExpr
        The index of the matrix.

    value : PrimExpr
        The value to fill.

    col : int
        The number of columns.

    row : int
        The number of rows.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    fill_args = (d, index, value, col, row)
    return call_intrin("handle", "tir.make_filled_simdgroup_matrix", *fill_args)
[文档]
def simdgroup_load(
    d: Var,
    index: PrimExpr,
    ptr: PrimExpr,
    stride: PrimExpr,
    col: int = 8,
    row: int = 8,
    transpose_matrix: bool = False,
):
    """Load data from device memory or threadgroup memory to simdgroup

    Parameters
    ----------
    d : var
        The simdgroup var

    index : PrimExpr
        The index of the matrix.

    ptr : PrimExpr
        The pointer.

    stride : PrimExpr
        The stride.

    col : int
        The number of columns.

    row : int
        The number of rows.

    transpose_matrix : bool
        Whether to transpose the matrix.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    load_args = (d, index, ptr, stride, col, row, transpose_matrix)
    return call_intrin("handle", "tir.simdgroup_load", *load_args)
[文档]
def simdgroup_store(
    d: PrimExpr,
    index: PrimExpr,
    ptr: PrimExpr,
    stride: PrimExpr,
    col: int = 8,
    row: int = 8,
    transpose_matrix: bool = False,
):
    """Store data from simdgroup to device memory or threadgroup memory

    Parameters
    ----------
    d : PrimExpr
        The SIMDGroup.

    index : PrimExpr
        The index of the matrix.

    ptr : PrimExpr
        The pointer.

    stride : PrimExpr
        The stride.

    col : int
        The number of columns.

    row : int
        The number of rows.

    transpose_matrix : bool
        Whether to transpose the matrix.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    store_args = (d, index, ptr, stride, col, row, transpose_matrix)
    return call_intrin("handle", "tir.simdgroup_store", *store_args)
[文档]
def simdgroup_multiply_accumulate(
    d: Var,
    index_d: PrimExpr,
    a: Var,
    index_a: PrimExpr,
    b: Var,
    index_b: PrimExpr,
    c: Var,
    index_c: PrimExpr,
):
    """Build the simdgroup multiply-accumulate call, i.e. d = a * b + c.

    Parameters
    ----------
    d : Var
        The destination matrix.

    index_d : PrimExpr
        The index of the destination matrix.

    a : Var
        The first matrix.

    index_a : PrimExpr
        The index of the first matrix.

    b : Var
        The second matrix.

    index_b : PrimExpr
        The index of the second matrix.

    c : Var
        The third matrix.

    index_c : PrimExpr
        The index of the third matrix.

    Returns
    -------
    call : PrimExpr
        The call expression (dtype ``"handle"``).
    """
    # Each matrix var is immediately followed by its index.
    intrin_args = (d, index_d, a, index_a, b, index_b, c, index_c)
    return call_intrin("handle", "tir.simdgroup_multiply_accumulate", *intrin_args)
[文档]
def vectorlow(dtype, vec):
    """Build the call that gets the low half of a vector.

    Parameters
    ----------
    dtype : str
        The data type of the result.

    vec : list
        The input vector.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    low_half = call_intrin(dtype, "tir.vectorlow", vec)
    return low_half
[文档]
def vectorhigh(dtype, vec):
    """Build the call that gets the high half of a vector.

    Parameters
    ----------
    dtype : str
        The data type of the result.

    vec : list
        The input vector.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    high_half = call_intrin(dtype, "tir.vectorhigh", vec)
    return high_half
[文档]
def vectorcombine(dtype, vec1, vec2):
    """Build the call that concatenates two vectors.

    Parameters
    ----------
    dtype : str
        The data type passed to the intrinsic call as its result type.

    vec1 : list
        The input vector.

    vec2 : list
        The input vector.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    vectors = (vec1, vec2)
    return call_intrin(dtype, "tir.vectorcombine", *vectors)
[文档]
def dp4a(vec1, vec2, acc=0):
    """Dot product of two int8x4 vectors plus an optional accumulator.

    Parameters
    ----------
    vec1 : int8x4
        The input vector.

    vec2 : int8x4
        The input vector.

    acc : int32
        The accumulator. Defaults to 0.

    Returns
    -------
    call : PrimExpr
        The call expression (dtype ``"int32"``).
    """
    operands = (vec1, vec2, acc)
    return call_intrin("int32", "tir.dp4a", *operands)
[文档]
def ret(val, span=None):
    """Create a tir return expression.

    Parameters
    ----------
    val : Expr
        The returned tir expression, whose data type is int, float or void pointer.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    ret : PrimExpr
        The return expression.
    """
    return_expr = _ffi_api.ret(val, span)
    return return_expr
[文档]
def any(*args, span=None):
    """Create a new expression of the union of all conditions in the arguments.

    Parameters
    ----------
    args : list
        List of symbolic boolean expressions.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    expr : Expr
        Expression. A single argument is returned unchanged.

    Raises
    ------
    ValueError
        If no argument is given.
    """
    if not args:
        raise ValueError("Any must take at least 1 argument")
    # Left fold: ((a0 or a1) or a2) ...; with one argument the loop is skipped.
    val = args[0]
    for cond in args[1:]:
        val = _ffi_api._OpOr(val, cond, span)  # type: ignore
    return val
[文档]
def all(*args, span=None):
    """Create a new expression of the intersection of all conditions in the
    arguments.

    Parameters
    ----------
    args : list
        List of symbolic boolean expressions.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    expr : Expr
        Expression. A single argument is returned unchanged.

    Raises
    ------
    ValueError
        If no argument is given.
    """
    if not args:
        # The message previously named "Any", copied from the sibling function.
        raise ValueError("All must take at least 1 argument")
    # Left fold: ((a0 and a1) and a2) ...; with one argument the loop is skipped.
    val = args[0]
    for cond in args[1:]:
        val = _ffi_api._OpAnd(val, cond, span)  # type: ignore
    return val
@tvm._ffi.register_func("tvm.default_trace_action")
def _tvm_default_trace_action(*args):
    """Default trace action: print the received arguments as a list."""
    received = list(args)
    print(received)
[文档]
def trace(args, trace_action="tvm.default_trace_action"):
    """Trace tensor data at the runtime.

    The trace function allows to trace specific tensor at the
    runtime. The tracing value should come as last argument.
    The trace action should be specified, by default
    tvm.default_trace_action is used.

    Parameters
    ----------
    args : list of Expr or Buffers.
        Positional arguments. Buffers are packed with ``_pack_buffer``;
        other values are passed through unchanged.

    trace_action : str.
        The name of the trace action.

    Returns
    -------
    call : PrimExpr
        The call expression. Its dtype is the dtype of the last element of ``args``.

    Raises
    ------
    TypeError
        If ``args`` is not a list.

    See Also
    --------
    tvm.tir.call_packed : Creates packed function.
    """
    if not isinstance(args, list):
        # TypeError is a subclass of Exception, so existing handlers still catch it.
        raise TypeError("tvm.tir.trace consumes the args as list type")
    # The trace action name goes first, followed by the (possibly packed) arguments.
    call_args = [trace_action]
    call_args.extend(_pack_buffer(x) if isinstance(x, Buffer) else x for x in args)
    return tvm.tir.Call(args[-1].dtype, Op.get("tir.tvm_call_trace_packed"), call_args)
[文档]
def min_value(dtype, span=None):
    """Return the minimum value of ``dtype``.

    Parameters
    ----------
    dtype : str
        The data type.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    value : tvm.Expr
        The minimum value of dtype.
    """
    lowest = _ffi_api.min_value(dtype, span)  # type: ignore
    return lowest
[文档]
def max_value(dtype: str, span: Optional[Span] = None) -> Any:
    """Return the maximum value of ``dtype``.

    Parameters
    ----------
    dtype : str
        The data type.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    value : tvm.Expr
        The maximum value of dtype.
    """
    highest = _ffi_api.max_value(dtype, span)  # type: ignore
    return highest
[文档]
def infinity(dtype: str, span: Optional[Span] = None) -> Any:
    """Return the infinity value of ``dtype``.

    Parameters
    ----------
    dtype : str
        The data type.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    value : tvm.Expr
        The infinity value of dtype.
    """
    inf_value = _ffi_api.infinity(dtype, span)  # type: ignore
    return inf_value
[文档]
def reinterpret(dtype, value, span: Optional[Span] = None) -> Any:
    """Reinterpret cast ``value`` to ``dtype``.

    Parameters
    ----------
    dtype : str
        The data type.

    value : PrimExpr
        The input value.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    value : tvm.Expr
        The reinterpret cast value of dtype.
    """
    recast = _ffi_api.reinterpret(dtype, value, span)  # type: ignore
    return recast
[文档]
def exp(x):
    """Take exponential of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.exp", operand)
[文档]
def exp2(x):
    """Calculate 2**x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.exp2", operand)
[文档]
def exp10(x):
    """Calculate 10**x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.exp10", operand)
[文档]
def erf(x):
    """Take gauss error function of the input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.erf", operand)
[文档]
def tanh(x):
    """Take hyperbolic tanh of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.tanh", operand)
[文档]
def sigmoid(x):
    """Quick function to get sigmoid.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.sigmoid", operand)
[文档]
def log(x):
    """Take log of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.log", operand)
[文档]
def log2(x):
    """Take log2 of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.log2", operand)
[文档]
def log10(x):
    """Take log10 of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.log10", operand)
[文档]
def log1p(x):
    """Take log(x + 1) with respect to input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.log1p", operand)
[文档]
def tan(x):
    """Take tan of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.tan", operand)
[文档]
def cos(x):
    """Take cos of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.cos", operand)
[文档]
def cosh(x):
    """Take cosh of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.cosh", operand)
[文档]
def acos(x):
    """Take acos of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.acos", operand)
[文档]
def acosh(x):
    """Take acosh of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.acosh", operand)
[文档]
def sin(x):
    """Take sin of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.sin", operand)
[文档]
def sinh(x):
    """Take sinh of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.sinh", operand)
[文档]
def asin(x):
    """Take asin of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.asin", operand)
[文档]
def asinh(x):
    """Take asinh of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.asinh", operand)
[文档]
def atan(x):
    """Take atan of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.atan", operand)
[文档]
def atanh(x):
    """Take atanh of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.atanh", operand)
[文档]
def atan2(x1, x2):
    """Take arctan2(x1, x2).

    Parameters
    ----------
    x1 : PrimExpr
        Input argument.

    x2 : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted ``x1``.
    """
    first, second = tir.convert(x1), tir.convert(x2)
    return call_intrin(first.dtype, "tir.atan2", first, second)
[文档]
def sqrt(x):
    """Take square root of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.sqrt", operand)
[文档]
def rsqrt(x):
    """Take reciprocal of square root of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.rsqrt", operand)
[文档]
def clz(x):
    """Count leading zero bits of an integer x.

    Parameters
    ----------
    x : PrimExpr
        Input 32 or 64 bit integer.
        The result is undefined if the input is 0.

    Returns
    -------
    y : PrimExpr
        The result (dtype ``"int32"``).
    """
    leading_zeros = call_intrin("int32", "tir.clz", x)
    return leading_zeros
[文档]
def floor(x: PrimExprWithOp, span=None):
    """Take floor of float input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    y : PrimExpr
        The result.
    """
    floored = _ffi_api.floor(x, span)  # type: ignore
    return floored
[文档]
def ceil(x, span=None):
    """Take ceil of float input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    y : PrimExpr
        The result.
    """
    ceiled = _ffi_api.ceil(x, span)  # type: ignore
    return ceiled
[文档]
def trunc(x, span=None):
    """Get truncated value of the input.

    The truncated value of the scalar x is the
    nearest integer i which is closer to zero than x is.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    y : PrimExpr
        The result.
    """
    truncated = _ffi_api.trunc(x, span)  # type: ignore
    return truncated
[文档]
def abs(x, span=None):
    """Get absolute value of the input element-wise.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    y : PrimExpr
        The result.
    """
    magnitude = _ffi_api.abs(x, span)  # type: ignore
    return magnitude
[文档]
def bitwise_and(x, y, span=None):
    """Take bitwise and of two values.

    Parameters
    ----------
    x : PrimExpr
        Left operand.

    y : PrimExpr
        Right operand.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    res : PrimExpr
        The result.
    """
    res = _ffi_api.bitwise_and(x, y, span)
    return res
[文档]
def bitwise_not(x, span=None):
    """Take bitwise not of input value.

    Parameters
    ----------
    x : PrimExpr
        Input operand.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    res : PrimExpr
        The result.
    """
    res = _ffi_api.bitwise_not(x, span)
    return res
[文档]
def bitwise_or(x, y, span=None):
    """Take bitwise or of two values.

    Parameters
    ----------
    x : PrimExpr
        Left operand.

    y : PrimExpr
        Right operand.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    res : PrimExpr
        The result.
    """
    res = _ffi_api.bitwise_or(x, y, span)
    return res
[文档]
def bitwise_xor(x, y, span=None):
    """Take bitwise xor of two values.

    Parameters
    ----------
    x : PrimExpr
        Left operand.

    y : PrimExpr
        Right operand.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    res : PrimExpr
        The result.
    """
    res = _ffi_api.bitwise_xor(x, y, span)
    return res
[文档]
def round(x, span=None):
    """Round elements of the array to the nearest integer.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    y : PrimExpr
        The result.
    """
    rounded = _ffi_api.round(x, span)  # type: ignore
    return rounded
[文档]
def nearbyint(x, span=None):
    """Round elements of the array to the nearest integer.

    This intrinsic uses llvm.nearbyint instead of llvm.round,
    which is faster but gives results that differ from te.round.
    Notably nearbyint rounds according to the rounding mode,
    whereas te.round (llvm.round) ignores that.
    For differences between the two see:
    https://en.cppreference.com/w/cpp/numeric/math/round
    https://en.cppreference.com/w/cpp/numeric/math/nearbyint

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    y : PrimExpr
        The result.
    """
    rounded = _ffi_api.nearbyint(x, span)  # type: ignore
    return rounded
[文档]
def nextafter(x1, x2):
    """Return the next floating-point value after x1 towards x2.

    Parameters
    ----------
    x1 : PrimExpr
        Input argument.

    x2 : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted ``x1``.
    """
    first, second = tir.convert(x1), tir.convert(x2)
    return call_intrin(first.dtype, "tir.nextafter", first, second)  # type: ignore
[文档]
def hypot(x1, x2):
    """Equivalent to sqrt(x1**2 + x2**2), element-wise.

    Parameters
    ----------
    x1 : PrimExpr
        Input argument.

    x2 : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted ``x1``.
    """
    first, second = tir.convert(x1), tir.convert(x2)
    return call_intrin(first.dtype, "tir.hypot", first, second)  # type: ignore
[文档]
def copysign(x1, x2):
    """Change the sign of x1 to that of x2, element-wise.

    Parameters
    ----------
    x1 : PrimExpr
        Input argument.

    x2 : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted ``x1``.
    """
    first, second = tir.convert(x1), tir.convert(x2)
    return call_intrin(first.dtype, "tir.copysign", first, second)  # type: ignore
[文档]
def ldexp(x1, x2):
    """Returns x1 * (2 ** x2).

    Parameters
    ----------
    x1 : PrimExpr
        Input argument.

    x2 : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted ``x1``.
    """
    first, second = tir.convert(x1), tir.convert(x2)
    return call_intrin(first.dtype, "tir.ldexp", first, second)  # type: ignore
[文档]
def likely(cond, span=None):
    """Mark condition as likely.

    Parameters
    ----------
    cond : PrimExpr
        Input argument.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    y : PrimExpr
        The marked expression.
    """
    marked = _ffi_api.likely(cond, span)  # type: ignore
    return marked
[文档]
def isnan(x, span=None):
    """Check if input value is Nan.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    y : PrimExpr
        The result.
    """
    check = _ffi_api.isnan(x, span)  # type: ignore
    return check
[文档]
def isnullptr(x, span=None):
    """Check if input value is nullptr.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    y : PrimExpr
        The result (dtype ``"bool"``).
    """
    check = call_intrin("bool", "tir.isnullptr", x, span=span)  # type: ignore
    return check
[文档]
def isfinite(x, span=None):
    """Check if input value is finite.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    y : PrimExpr
        The result.
    """
    check = _ffi_api.isfinite(x, span)  # type: ignore
    return check
[文档]
def isinf(x, span=None):
    """Check if input value is infinite.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    y : PrimExpr
        The result.
    """
    check = _ffi_api.isinf(x, span)  # type: ignore
    return check
[文档]
def power(x, y, span=None):
    """Raise x to the power y.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    y : PrimExpr
        The exponent.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    z : PrimExpr
        The result.
    """
    z = _ffi_api._OpPow(x, y, span)  # type: ignore
    return z
[文档]
def pow(x, y, span=None):
    """Raise x to the power y.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    y : PrimExpr
        The exponent.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    z : PrimExpr
        The result.
    """
    z = _ffi_api._OpPow(x, y, span)  # type: ignore
    return z
[文档]
def popcount(x):
    """Count the number of set bits in input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the dtype of the converted input.
    """
    operand = tir.convert(x)
    return call_intrin(operand.dtype, "tir.popcount", operand)
[文档]
def q_multiply_shift(x, y, q, s):
    """Execute a multiplication between two Q-numbers x and y
    followed by a right shift s. The mathematical expression is:

       out = round(x*y*2^-s)

    More about Q-numbers here: https://en.wikipedia.org/wiki/Q_(number_format)
    The rounding rule is to the nearest value, rounding half up
    (i.e., round(x.1) = x and round (x.5) = x+1)

    Parameters
    ----------
    x : PrimExpr
        First Q-number.

    y : PrimExpr
        Second Q-number.

    q : PrimExpr
        Number of fractional bits in x and y. Needs to be > 0.

    s : PrimExpr
        Integer shift.

    Returns
    -------
    y : PrimExpr
        The result (dtype ``"int32"``).
    """
    operands = (x, y, q, s)
    return call_intrin("int32", "tir.q_multiply_shift", *operands)
[文档]
def q_multiply_shift_per_axis(
    x: PrimExpr,
    y: PrimExpr,
    ls: PrimExpr,
    rs: PrimExpr,
    q: IntImm,
    is_lshift_required: IntImm,
    is_rshift_required: IntImm,
):
    """Execute a multiplication between two Q-numbers x and y.

    Parameters
    ----------
    x : PrimExpr
        First Q-number.

    y : PrimExpr
        Second Q-number.

    ls : PrimExpr
        Integer left shift.

    rs : PrimExpr
        Integer right shift.

    q : IntImm
        Number of fractional bits in x and y. Needs to be > 0.

    is_lshift_required : IntImm
        Whether we need to do left shift or not.

    is_rshift_required : IntImm
        Whether we need to do right shift or not.

    Returns
    -------
    z : PrimExpr
        The result (dtype ``"int32"``).
    """
    operands = (x, y, ls, rs, q, is_lshift_required, is_rshift_required)
    return call_intrin("int32", "tir.q_multiply_shift_per_axis", *operands)
[文档]
def shift_left(x, y, span=None):
    """Return the result of x left shifted by y bits.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    y : PrimExpr
        Input argument.

    span : Optional[Span]
        Forwarded unchanged to the underlying ``left_shift`` FFI call.

    Returns
    -------
    z : PrimExpr
        The result.
    """
    z = _ffi_api.left_shift(x, y, span)
    return z
[文档]
def shift_right(x, y, span=None):
    """Return the result of x right shifted by y bits.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    y : PrimExpr
        Input argument.

    span : Optional[Span]
        Forwarded unchanged to the underlying ``right_shift`` FFI call.

    Returns
    -------
    z : PrimExpr
        The result.
    """
    z = _ffi_api.right_shift(x, y, span)
    return z
[文档]
def fmod(x, y):
    """Return the remainder of x divided by y with the same sign as x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    y : PrimExpr
        Input argument.

    Returns
    -------
    z : PrimExpr
        The result, with the dtype of the converted ``x``.
    """
    dividend, divisor = tir.convert(x), tir.convert(y)
    return call_intrin(dividend.dtype, "tir.fmod", dividend, divisor)
[文档]
def if_then_else(cond, t, f, span=None):
    """Conditional selection expression.

    Parameters
    ----------
    cond : PrimExpr
        The condition.

    t : PrimExpr
        The result expression if cond is true.

    f : PrimExpr
        The result expression if cond is false.

    span : Optional[Span]
        The location of this operator in the source.

    Returns
    -------
    result : Node
        The result of conditional expression.

    Note
    ----
    Unlike Select, if_then_else will not execute
    the branch that does not satisfy the condition.
    You can use it to guard against out of bound access.
    Unlike Select, if_then_else cannot be vectorized
    if some lanes in the vector have different conditions.
    """
    selected = _ffi_api._OpIfThenElse(cond, t, f, span)  # type: ignore
    return selected
[文档]
def div(a, b, span=None):
    """Compute a / b as in C/C++ semantics.

    Parameters
    ----------
    a : PrimExpr
        The left hand operand.

    b : PrimExpr
        The right hand operand.

    span : Optional[Span]
        The location of this operator in the source.

    Returns
    -------
    res : PrimExpr
        The result expression.

    Note
    ----
    When operands are integers, returns truncdiv(a, b, span).
    """
    res = _ffi_api._OpDiv(a, b, span)  # type: ignore
    return res
[文档]
def indexdiv(a, b, span=None):
    """Compute floor(a / b) where a and b are non-negative.

    Parameters
    ----------
    a : PrimExpr
        The left hand operand, known to be non-negative.

    b : PrimExpr
        The right hand operand, known to be non-negative.

    span : Optional[Span]
        The location of this operator in the source.

    Returns
    -------
    res : PrimExpr
        The result expression.

    Note
    ----
    Use this function to split non-negative indices.
    This function may take advantage of operands'
    non-negativeness.
    """
    res = _ffi_api._OpIndexDiv(a, b, span)  # type: ignore
    return res
[文档]
def indexmod(a, b, span=None):
    """Compute the remainder of indexdiv. a and b are non-negative.

    Parameters
    ----------
    a : PrimExpr
        The left hand operand, known to be non-negative.

    b : PrimExpr
        The right hand operand, known to be non-negative.

    span : Optional[Span]
        The location of this operator in the source.

    Returns
    -------
    res : PrimExpr
        The result expression.

    Note
    ----
    Use this function to split non-negative indices.
    This function may take advantage of operands'
    non-negativeness.
    """
    res = _ffi_api._OpIndexMod(a, b, span)  # type: ignore
    return res
[文档]
def truncdiv(a, b, span=None):
    """Compute the truncdiv of two expressions.

    Parameters
    ----------
    a : PrimExpr
        The left hand operand.

    b : PrimExpr
        The right hand operand.

    span : Optional[Span]
        The location of this operator in the source.

    Returns
    -------
    res : PrimExpr
        The result expression.

    Note
    ----
    This is the default integer division behavior in C.
    """
    res = _ffi_api._OpTruncDiv(a, b, span)  # type: ignore
    return res
[文档]
def truncmod(a, b, span=None):
    """Compute the truncmod of two expressions.

    Parameters
    ----------
    a : PrimExpr
        The left hand operand.

    b : PrimExpr
        The right hand operand.

    span : Optional[Span]
        The location of this operator in the source.

    Returns
    -------
    res : PrimExpr
        The result expression.

    Note
    ----
    This is the default integer division behavior in C.
    """
    res = _ffi_api._OpTruncMod(a, b, span)  # type: ignore
    return res
[文档]
def floordiv(a, b, span=None):
    """Compute the floordiv of two expressions.

    Parameters
    ----------
    a : PrimExpr
        The left hand operand.

    b : PrimExpr
        The right hand operand.

    span : Optional[Span]
        The location of this operator in the source.

    Returns
    -------
    res : PrimExpr
        The result expression.
    """
    res = _ffi_api._OpFloorDiv(a, b, span)  # type: ignore
    return res
[文档]
def floormod(a, b, span=None):
    """Compute the floormod of two expressions.

    Parameters
    ----------
    a : PrimExpr
        The left hand operand.

    b : PrimExpr
        The right hand operand.

    span : Optional[Span]
        The location of this operator in the source.

    Returns
    -------
    res : PrimExpr
        The result expression.
    """
    res = _ffi_api._OpFloorMod(a, b, span)  # type: ignore
    return res
[文档]
def ceildiv(lhs, rhs, span=None):
    """Generic ceildiv operator.

    Parameters
    ----------
    lhs : object
        The left operand.

    rhs : object
        The right operand.

    span : Optional[Span]
        The location of this operator in the source.

    Returns
    -------
    op : tvm.Expr
        The result Expr of ceildiv operation.
    """
    quotient = _ffi_api._OpCeilDiv(lhs, rhs, span)  # type: ignore
    return quotient
[文档]
def comm_reducer(fcombine, fidentity, name="reduce"):
    """Create a commutative reducer for reduction.

    Parameters
    ----------
    fcombine : function(Expr -> Expr -> Expr)
        A binary function which takes two Expr as input to return a Expr.
        Its two parameter names are used to name the reduction variables.

    fidentity : function(str -> Expr)
        A function which takes a type string as input to return a const Expr.

    name : str
        Name substituted into the docstring of the returned reducer.

    Returns
    -------
    reducer : function
        A function which creates a reduce expression over axis.
        There are two ways to use it:

        1. accept (expr, axis, where) to produce an Reduce Expr on
           specified axis;
        2. simply use it with multiple Exprs.

    Example
    -------
    .. code-block:: python

        n = te.var("n")
        m = te.var("m")
        mysum = te.comm_reducer(lambda x, y: x+y,
            lambda t: tvm.tir.const(0, dtype=t), name="mysum")
        A = te.placeholder((n, m), name="A")
        k = te.reduce_axis((0, m), name="k")
        B = te.compute((n,), lambda i: mysum(A[i, k], axis=k), name="B")
    """

    def _reduce_directly(*args):
        # A 3-argument call whose last value is None carries an unset `where`,
        # so only the first two values are operands.
        operands = args[:2] if len(args) == 3 and args[2] is None else args
        acc = operands[0]
        for operand in operands[1:]:
            acc = fcombine(acc, operand)
        return acc

    def _make_reduce(expr, axis, where=None, init=None):
        code = fcombine.__code__
        assert fcombine.__code__.co_argcount == 2
        expr = tir.convert(expr)
        if init is not None:
            init = tir.convert(init)

        if isinstance(expr, Array):
            # Multi-output reduction: one (lhs, rhs) variable pair per element.
            size = len(expr)
            lhs, rhs, dtypes = [], [], []
            for idx in range(size):
                elem_dtype = expr[idx].dtype
                dtypes.append(elem_dtype)
                lhs.append(Var(f"{code.co_varnames[0]}_{idx}", elem_dtype))
                rhs.append(Var(f"{code.co_varnames[1]}_{idx}", elem_dtype))
            if init is None:
                init = []
            result = fcombine(lhs, rhs)
            id_elem = fidentity(*dtypes)
        else:
            # Single-output reduction: wrap everything into one-element lists.
            assert isinstance(expr, tvm.ir.PrimExpr)
            size = 1
            elem_dtype = expr.dtype
            lvar = Var(code.co_varnames[0], elem_dtype)
            rvar = Var(code.co_varnames[1], elem_dtype)
            result = [fcombine(lvar, rvar)]
            id_elem = [fidentity(elem_dtype)]
            lhs, rhs = [lvar], [rvar]
            expr = [expr]
            if init is not None:
                init = [init]

        combiner = CommReducer(lhs, rhs, result, id_elem)
        if not isinstance(axis, (list, tuple, tvm.ir.Array)):
            axis = [axis]
        if where is None:
            where = tir.convert(True)
        # A missing init is passed to Reduce as an empty list.
        init_values = [] if init is None else init
        outputs = tuple(
            tvm.tir.Reduce(combiner, expr, axis, where, idx, init_values) for idx in range(size)
        )
        return outputs[0] if size == 1 else outputs

    # pylint: disable=keyword-arg-before-vararg
    def reducer(expr, axis, where=None, init=None, *args):
        # Mode 1: `axis` is an IterVar (or list/tuple) -> build a Reduce node.
        if isinstance(axis, (tvm.tir.IterVar, list, tuple)):
            assert not args
            return _make_reduce(expr, axis, where, init)

        # Mode 2: every positional value is an operand to combine directly.
        if where is not None and init is not None:
            return _reduce_directly(expr, axis, where, init, *args)
        assert not args
        if where is None:
            assert init is None
            return _reduce_directly(expr, axis)
        return _reduce_directly(expr, axis, where)

    doc_str = """Create a {0} expression over axis.
Parameters
----------
expr : PrimExpr
The source expression.
axis : IterVar
The reduction IterVar axis
where : optional, Expr
Filtering predicate of the reduction.
Returns
-------
value : PrimExpr
The result value.
Example
-------
.. code-block:: python
m = te.var("m")
n = te.var("n")
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n), name="k")
# there are two way to use this {0} reducer:
# mode 1, accept (expr, axis, where) to produce an Reduce Expr
# tvm.{0} represents tvm.te.{0} or tvm.tir.{0}.
B = te.compute((m,), lambda i: tvm.{0}(A[i, k], axis=k), name="B")
# mode 2, simply use it with multiple Exprs:
{0}_res = tvm.{0}(m, n)
"""
    reducer.__doc__ = doc_str.format(name)
    return reducer
[文档]
def TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, dtype_bits_hint):
    """Backend function to allocate temporal workspace.

    Parameters
    ----------
    device_type : int
        The device type which the space will be allocated.

    device_id : int
        The device id which the space will be allocated.

    nbytes : int
        The size of the space requested.

    dtype_code_hint : int
        The type code of the array elements. Only used in certain backends such as OpenGL.

    dtype_bits_hint : int
        The type bits of the array elements. Only used in certain backends such as OpenGL.

    Returns
    -------
    call : PrimExpr
        The call expression (dtype ``"handle"``).
    """
    alloc_args = (device_type, device_id, nbytes, dtype_code_hint, dtype_bits_hint)
    return call_intrin("handle", "tir.TVMBackendAllocWorkspace", *alloc_args)
[文档]
def TVMBackendFreeWorkspace(device_type, device_id, ptr):
    """Backend function to free temporal workspace.

    Parameters
    ----------
    device_type : int
        The device type which the space will be allocated.

    device_id : int
        The device id which the space will be allocated.

    ptr : Var
        The result allocated space pointer.

    Returns
    -------
    call : PrimExpr
        The call expression (dtype ``"int32"``).
    """
    free_args = (device_type, device_id, ptr)
    return call_intrin("int32", "tir.TVMBackendFreeWorkspace", *free_args)
def anylist_getitem(list_handle, index):
    """Returns an item from any list.

    Parameters
    ----------
    list_handle : Var
        The handle to anylist.

    index : int
        The index.

    Returns
    -------
    call : PrimExpr
        The call expression (dtype ``"handle"``).
    """
    item_call = call_intrin("handle", "tir.anylist_getitem", list_handle, index)
    return item_call
def anylist_resetitem(list_handle, index):
    """Reset an item from any list.

    Parameters
    ----------
    list_handle : Var
        The handle to anylist.

    index : int
        The index.

    Returns
    -------
    call : PrimExpr
        The call expression (dtype ``"int"``).
    """
    reset_call = call_intrin("int", "tir.anylist_resetitem", list_handle, index)
    return reset_call
def anylist_setitem_call_packed(list_handle, index, func_name, *args):
    """Set anylist item by result of packed call.

    Parameters
    ----------
    list_handle : Var
        The handle to anylist.

    index : int
        The index.

    func_name : str
        The name of the function to be called.

    args :
        Extra arguments.

    Returns
    -------
    call : PrimExpr
        The call expression (dtype ``"int"``).
    """
    # Fixed leading arguments first, then the caller's extra arguments.
    leading = (list_handle, index, func_name)
    return call_intrin("int", "tir.anylist_setitem_call_packed", *leading, *args)
def anylist_setitem_call_cpacked(list_handle, index, func_name, *args):
    """Set anylist item by result of a call through the cpacked intrinsic.

    Parameters
    ----------
    list_handle : Var
        The handle to anylist.

    index : int
        The index.

    func_name : str
        The name of the function to be called.

    args :
        Extra arguments.

    Returns
    -------
    call : PrimExpr
        The call expression (dtype ``"int"``).
    """
    # Fixed leading arguments first, then the caller's extra arguments.
    leading = (list_handle, index, func_name)
    return call_intrin("int", "tir.anylist_setitem_call_cpacked", *leading, *args)
[文档]
def vscale():
    """Get the target's vscale value. It will be lowered to llvm.vscale intrinsic
    (https://llvm.org/docs/LangRef.html#llvm-vscale-intrinsic)

    Returns
    -------
    call : PrimExpr
        Call to the vscale intrinsic (dtype ``"int32"``).
    """
    vscale_call = call_intrin("int32", "tir.vscale")
    return vscale_call
[文档]
def get_active_lane_mask(dtype, base, limit):
    """
    Calculate a predicate mask given an upper bound (limit) and a current value (base).

    It will be lowered to the llvm.get.active.lane.mask intrinsic.
    (https://llvm.org/docs/LangRef.html#llvm-get-active-lane-mask-intrinsics)

    Parameters
    ----------
    dtype : str
        The data type of the result.

    base : PrimExpr
        An expression representing the base.

    limit : PrimExpr
        An expression representing the limit.
    """
    bounds = (base, limit)
    return call_intrin(dtype, "tir.get_active_lane_mask", *bounds)
[文档]
def get_vscale_expr(dtype: Union[str, tvm.DataType], min_size: int = 128) -> PrimExpr:
    """
    Create a datatype dependent scalable expression.

    The result is ``min_size // dtype.bits`` multiplied by the vscale call.

    Parameters
    ----------
    dtype : Union[str, tvm.DataType]
        Element data type. A string is converted with ``tvm.DataType``.

    min_size : int
        The minimum size of the scalable vector in bits.
    """
    element_type = tvm.DataType(dtype) if isinstance(dtype, str) else dtype
    elements_per_unit = min_size // element_type.bits
    return elements_per_unit * vscale()
[文档]
def ignore_loop_partition(predicate) -> PrimExpr:
    """
    Annotate a predicate so it is not considered as a target condition of loop partition.

    Parameters
    ----------
    predicate : PrimExpr
        The annotated predicate expression.
    """
    annotated = call_intrin("bool", "tir.ignore_loop_partition", predicate)
    return annotated
# pylint: disable=unnecessary-lambda
# Module-level reducers built with comm_reducer. They deliberately reuse the
# builtin names sum/min/max (the file header disables pylint's redefined-builtin).
# NOTE: comm_reducer reads the lambdas' parameter names (x, y) to name the
# reduction variables, so renaming them changes the generated variable names.
# sum: combine with `+`, identity is const(0) of the requested dtype.
sum = comm_reducer(lambda x, y: x + y, lambda t: const(0, dtype=t), name="sum")
# min: the identity element is max_value(dtype).
min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y, None), max_value, name="min") # type: ignore
# max: the identity element is min_value(dtype).
max = comm_reducer(lambda x, y: _ffi_api._OpMax(x, y, None), min_value, name="max") # type: ignore