tvm.tir

Namespace for Tensor-level IR

Exceptions:

ScheduleError

Error that happens during TensorIR scheduling.

Classes:

Add(a, b[, span])

Add node.

Allocate(buffer_var, dtype, extents, ...[, ...])

Allocate node.

AllocateConst(buffer_var, dtype, extents, ...)

Allocate constant node.

And(a, b[, span])

And node.

Any([span])

Any node.

AssertStmt(condition, message, body[, span])

AssertStmt node.

AttrStmt(node, attr_key, value, body[, span])

AttrStmt node.

BijectiveLayout()

Bijective mapping for two layouts (src-layout and dst-layout).

Block(iter_vars, reads, writes, name_hint, body)

Block node.

BlockDependenceInfo(mod)

An object that helps build and query block-level dependences using the two core objects BlockScope and StmtSRef.

BlockRealize(iter_values, predicate, block)

BlockRealize node.

BlockScope()

An object corresponding to each block sref in the sref tree, which tracks the producer-consumer dependencies between blocks.

Broadcast(value, lanes[, span])

Broadcast node.

Buffer()

Symbolic data buffer in TVM.

BufferLoad(buffer, indices[, span])

Buffer load node.

BufferRealize(buffer, bounds, condition, body)

Buffer realize node.

BufferRegion(buffer, region)

BufferRegion node.

BufferStore(buffer, value, indices[, span])

Buffer store node.

Call(dtype, op, args[, span])

Call node.

CallEffectKind()

Possible kinds of Call effects.

Cast(dtype, value[, span])

Cast expression.

CommReducer(lhs, rhs, result, identity_element)

Commutative reduce operator

DataProducer()

DeclBuffer(buffer, body[, span])

DeclBuffer node.

Div(a, b[, span])

Div node.

EQ(a, b[, span])

EQ node.

Evaluate(value[, span])

Evaluate node.

FloatImm(dtype, value[, span])

Float constant.

FloorDiv(a, b[, span])

FloorDiv node.

FloorMod(a, b[, span])

FloorMod node.

For(loop_var, min, extent, kind, body[, ...])

For node.

ForKind(value[, names, module, qualname, ...])

The kind of the for loop.

GE(a, b[, span])

GE node.

GT(a, b[, span])

GT node.

IfThenElse(condition, then_case, else_case)

IfThenElse node.

IndexMap(initial_indices, final_indices, ...)

A mapping from multi-dimensional indices to another set of multi-dimensional indices

IntImm(dtype, value[, span])

Int constant.

IterVar(dom, var, iter_type[, thread_tag, span])

Represent iteration variable.

LE(a, b[, span])

LE node.

LT(a, b[, span])

LT node.

Layout()

Layout is composed of upper cases, lower cases and numbers, where upper case indicates a primal axis and the corresponding lower case with factor size indicates the subordinate axis.

Let(var, value, body[, span])

Let node.

LetStmt(var, value, body[, span])

LetStmt node.

MatchBufferRegion(buffer, source)

MatchBufferRegion node.

Max(a, b[, span])

Max node.

Min(a, b[, span])

Min node.

Mod(a, b[, span])

Mod node.

Mul(a, b[, span])

Mul node.

NE(a, b[, span])

NE node.

Not(a[, span])

Not node.

Or(a, b[, span])

Or node.

Prefetch(buffer, bounds[, span])

Prefetch node.

PrimFunc(params, body[, ret_type, ...])

A function declaration expression.

ProducerLoad(producer, indices[, span])

Producer load node.

ProducerRealize(producer, bounds, condition, ...)

ProducerRealize node.

ProducerStore(producer, value, indices[, span])

ProducerStore node.

Ramp(base, stride, lanes[, span])

Ramp node.

Reduce(combiner, src, rdom, condition, ...)

Reduce node.

Schedule(mod, *[, seed, debug_mask, ...])

The user-facing schedule class

ScheduleState(mod, *[, debug_mask, enable_check])

The state of scheduling, which exposes a Replace method as the primary resort for all the scheduling primitives to manipulate the TensorIR.

Select(condition, true_value, false_value[, ...])

Select node.

SeqStmt(seq[, span])

Sequence of statements.

Shuffle(vectors, indices[, span])

Shuffle node.

SizeVar(name, dtype[, span])

Symbolic variable to represent a tensor index size

Stmt()

Base class of all the statements.

StmtSRef()

An object that refers to schedulable elements in the TensorIR, aka "sref".

StringImm(value[, span])

String constant.

Sub(a, b[, span])

Sub node.

TensorIntrin(desc, impl)

A tensor intrinsic.

Var(name, dtype[, span])

Symbolic variable.

While(condition, body[, span])

While node.

Functions:

TVMBackendAllocWorkspace(device_type, ...)

Backend function to allocate temporal workspace

TVMBackendFreeWorkspace(device_type, ...)

Backend function to free temporal workspace.

abs(x[, span])

Get absolute value of the input element-wise.

acos(x)

Take acos of input x.

acosh(x)

Take acosh of input x.

add(lhs, rhs[, span])

Generic add operator.

address_of(buffer_load[, span])

Returns the address of an element in the buffer

all(*args[, span])

Create a new expression of the intersection of all conditions in the arguments.

any(*args[, span])

Create a new expression of the union of all conditions in the arguments.

asin(x)

Take asin of input x.

asinh(x)

Take asinh of input x.

assume([cond])

Provide a true statement that can be used for simplifications

atan(x)

Take atan of input x.

atan2(x1, x2)

Take arctan2(x1, x2).

atanh(x)

Take atanh of input x.

bijective_layout(src_layout, dst_layout)

Create a bijective layout mapping.

bitwise_and(x, y[, span])

Take bitwise and of two values

bitwise_not(x[, span])

Take bitwise not of input value

bitwise_or(x, y[, span])

Take bitwise or of two values

bitwise_xor(x, y[, span])

Take bitwise xor of two values

call_cpacked(*args[, span])

Build expression by calling an external packed function.

call_cpacked_lowered(*args[, span])

Lowered version of call c-packed.

call_extern(dtype, func_name, *args[, span])

Build expression by calling an extern function.

call_intrin(dtype, func_name, *args[, span])

Build expression by calling an intrinsic function.

call_llvm_intrin(dtype, name, *args[, span])

Build expression by calling an llvm intrinsic function.

call_llvm_pure_intrin(dtype, name, *args[, span])

Build expression by calling a pure llvm intrinsic function

call_packed(*args[, span])

Build expression by calling an external packed function.

call_packed_lowered(*args[, span])

Lowered version of call packed.

call_pure_extern(dtype, func_name, *args[, span])

Build expression by calling a pure extern function.

call_tir(global_var, *args)

Performs a call into another PrimFunc in the same IRModule

ceil(x[, span])

Take ceil of float input x.

ceildiv(lhs, rhs[, span])

Generic ceildiv operator.

clz(x)

Count leading zero bits of an integer x.

comm_reducer(fcombine, fidentity[, name])

Create a commutative reducer for reduction.

copysign(x1, x2)

Change the sign of x1 to that of x2, element-wise.

cos(x)

Take cos of input x.

cosh(x)

Take cosh of input x.

create_barriers(barrier_count)

TVM intrinsic to create N barriers

decl_buffer(shape[, dtype, name, data, ...])

Declare a new symbolic buffer.

div(a, b[, span])

Compute a / b as in C/C++ semantics.

end_profile_intrinsic(id)

End profile intrinsic. id (int) is the intrinsic id; returns the call expression (PrimExpr).

erf(x)

Take gauss error function of the input x.

exp(x)

Take exponential of input x.

exp10(x)

Calculate 10**x

exp2(x)

Calculate 2**x

floor(x[, span])

Take floor of float input x.

floordiv(a, b[, span])

Compute the floordiv of two expressions.

floormod(a, b[, span])

Compute the floormod of two expressions.

fmod(x, y)

Return the remainder of x divided by y with the same sign as x.

get_active_lane_mask(dtype, base, limit)

Calculate a predicate mask given an upper bound (limit) and a current value (base).

hypot(x1, x2)

Equivalent to sqrt(x1**2 + x2**2), element-wise.

if_then_else(cond, t, f[, span])

Conditional selection expression.

indexdiv(a, b[, span])

Compute floor(a / b) where a and b are non-negative.

indexmod(a, b[, span])

Compute the remainder of indexdiv.

infinity(dtype[, span])

infinity value of dtype

isfinite(x[, span])

Check if input value is finite.

isinf(x[, span])

Check if input value is infinite.

isnan(x[, span])

Check if input value is NaN.

isnullptr(x[, span])

Check if input value is nullptr.

layout(layout_str[, dtype])

Create a layout node from a string.

ldexp(x1, x2)

Returns x1 * (2 ** x2).

likely(cond[, span])

Mark condition as likely.

log(x)

Take log of input x.

log10(x)

Take log10 of input x.

log1p(x)

Take log(x + 1) with respect to input x.

log2(x)

Take log2 of input x.

lookup_param(param_name[, span])

Returns the param by name

max(expr, axis[, where, init])

Create a max expression over axis.

max_value(dtype[, span])

maximum value of dtype

min(expr, axis[, where, init])

Create a min expression over axis.

min_value(dtype[, span])

minimum value of dtype

mma_fill(dtype, local_size, local_ptr, offset)

TVM intrinsic for zero-initializing an MMA accumulation register

mma_store(dtype, m, n, dst_ptr, src_ptr, ...)

TVM intrinsic for storing the result of PTX MMA into a destination pointer

multiply(lhs, rhs[, span])

Generic multiply operator.

nearbyint(x[, span])

Round elements of the array to the nearest integer.

nextafter(x1, x2)

Return the next floating-point value after x1 towards x2.

popcount(x)

Count the number of set bits in input x.

pow(x, y[, span])

x power y

power(x, y[, span])

x power y

ptx_arrive_barrier(barrier_id)

TVM intrinsic for ptx barrier arrival using mbarrier.arrive https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive

ptx_arrive_barrier_expect_tx(barrier_id, ...)

TVM intrinsic for ptx barrier arrival with expect tx using mbarrier.arrive.expect_tx https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-expect-tx-operation

ptx_commit_group()

TVM intrinsic for ptx async copy commit https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-commit-group

ptx_cp_async(dtype, shared_ptr, ...)

TVM intrinsic for ptx async copy from global to shared memory using cp.async https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async

ptx_cp_async_barrier(barrier_id)

TVM intrinsic for ptx async copy barrier using cp.async.mbarrier.arrive https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-cp-async-mbarrier-arrive

ptx_cp_async_bulk(dtype, shared_ptr, ...)

TVM intrinsic for ptx async copy from global to shared memory using cp.async.bulk https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk

ptx_init_barrier_thread_count(barrier_id, ...)

TVM intrinsic for ptx barrier initialization of thread count using mbarrier.init https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init

ptx_ldmatrix(dtype, trans, num, type, ...)

TVM intrinsic for ptx load matrix from shared memory https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix

ptx_mma(dtype, shape, A_layout, B_layout, ...)

TVM intrinsic for ptx tensor core mma instructions https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma

ptx_mma_sp(dtype, shape, A_layout, B_layout, ...)

TVM intrinsic for sparse tensor core ptx instructions https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-sparse-mma

ptx_wait_barrier(barrier_id)

TVM intrinsic for ptx barrier wait using mbarrier.try_wait https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-mbarrier-try-wait

ptx_wait_group(num)

TVM intrinsic for ptx async copy wait https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-wait-group

q_multiply_shift(x, y, q, s)

Execute a multiplication between two Q-numbers x and y followed by a right shift s.

q_multiply_shift_per_axis(x, y, ls, rs, q, ...)

Execute a multiplication between two Q-numbers x and y

reinterpret(dtype, value[, span])

Reinterpret the bits of the value as the target dtype.

ret(val)

Create a tir return expression

round(x[, span])

Round elements of the array to the nearest integer.

rsqrt(x)

Take reciprocal of square root of input x.

shift_left(x, y[, span])

Return the result of x left shifted by y bits.

shift_right(x, y[, span])

Return the result of x right shifted by y bits.

sigmoid(x)

Quick function to get sigmoid

sin(x)

Take sin of input x.

sinh(x)

Take sinh of input x.

sqrt(x)

Take square root of input x.

start_profile_intrinsic(id)

Start profile intrinsic. id (int) is the intrinsic id; returns the call expression (PrimExpr).

stmt_list(stmt)

Make list of stmt from blocks.

stmt_seq(*args)

Make sequence of statements

subtract(lhs, rhs[, span])

Generic subtract operator.

sum(expr, axis[, where, init])

Create a sum expression over axis.

tan(x)

Take tan of input x.

tanh(x)

Take hyperbolic tangent of input x.

trace(args[, trace_action])

Trace tensor data at the runtime.

trunc(x[, span])

Get truncated value of the input.

truncdiv(a, b[, span])

Compute the truncdiv of two expressions.

truncmod(a, b[, span])

Compute the truncmod of two expressions.

tvm_access_ptr(ptype, data, offset, extent, ...)

Get head access address with memory access pattern info

tvm_bmma_sync(fragment_d, index_d, ...)

TVM intrinsic for tensor core bmma_sync operators

tvm_check_return(expected, ...)

Check the return code of a nested call. expected (int) is the expected return code, return_unexpected (int) the unexpected return code, and nested_call (PrimExpr) the call whose return value is checked; returns the call expression (PrimExpr).

tvm_fill_fragment(fragment, m, n, k, index, ...)

TVM intrinsic for tensor core fill_fragment operators

tvm_load_matrix_sync(fragment, m, n, k, ...)

TVM intrinsic for tensor core load operators

tvm_mma_sync(fragment_d, index_d, ...)

TVM intrinsic for tensor core mma_sync operators

tvm_stack_alloca(dtype_str, num)

Return new on stack dtype[num]

tvm_stack_make_array(data, shape, strides, ...)

Allocate an NDArray (DLTensor) on stack, return the handle

tvm_stack_make_shape(*args)

Allocate a shape tuple on stack, return the handle

tvm_store_matrix_sync(fragment, m, n, k, ...)

TVM intrinsic for tensor core store operators

tvm_struct_get(arr, index, field, dtype)

Get struct field value in array

tvm_struct_set(arr, index, field, value)

Set value in struct field in array

tvm_thread_allreduce(*freduce_args)

Perform allreduce inside threadblock.

tvm_throw_last_error()

Throw TVMGetLastError()

tvm_tuple(*value)

Create a tuple structure in value field of AttrStmt

type_annotation(dtype)

Create a type annotation expression

undef()

Returns an initialized but arbitrary value

vectorcombine(dtype, vec1, vec2)

Concat two vectors

vectorhigh(dtype, vec)

Get the high half of the vector

vectorlow(dtype, vec)

Get the low half of the vector

vscale()

Get the target's vscale value; it will be lowered to the llvm.vscale intrinsic (https://llvm.org/docs/LangRef.html#llvm-vscale-intrinsic). Returns the call expression (PrimExpr).

exception tvm.tir.ScheduleError[源代码]#

Error that happens during TensorIR scheduling.

class tvm.tir.Add(a, b, span=None)[源代码]#

Add node.

Parameters#

aPrimExpr

The left hand operand.

bPrimExpr

The right hand operand.

spanOptional[Span]

The location of this expression in the source code.
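
As a quick illustrative sketch (not taken from the original reference), these expression nodes can be constructed directly or produced implicitly through Python operator overloading on PrimExpr:

import tvm

a = tvm.tir.Var("a", "int32")
b = tvm.tir.Var("b", "int32")
expr = tvm.tir.Add(a, b)  # explicit node construction
same = a + b              # operator overloading builds an equivalent Add node
print(expr.dtype)         # "int32"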

参数:
class tvm.tir.Allocate(buffer_var, dtype, extents, condition, body, annotations=None, span=None)[源代码]#

Allocate node.

Parameters#

buffer_varVar

The buffer variable.

dtypestr

The data type of the buffer.

extentslist of Expr

The extents of the allocate

conditionPrimExpr

The condition.

bodyStmt

The body statement.

annotations: Optional[Mapping[str, Object]]

Additional annotation hints

spanOptional[Span]

The location of the stmt in the source code.

参数:
class tvm.tir.AllocateConst(buffer_var, dtype, extents, data_or_idx, body, annotations=None, span=None)[源代码]#

Allocate constant node.

Parameters#

buffer_varVar

The buffer variable.

dtypestr

The data type of the buffer.

extentslist of Expr

The extents of the allocate

data_or_idxUnion[NDArray, int]

If an NDArray, this is the const data associated with the constant. If an integer, this is the index into the “constants” attribute of the IRModule that contains the AllocateConst.

bodyStmt

The body statement.

annotationsOptional[Mapping[str, Object]]

Additional annotations about the allocation.

spanOptional[Span]

The location of the stmt in the source code.

参数:
class tvm.tir.And(a, b, span=None)[源代码]#

And node.

Parameters#

aPrimExpr

The left hand operand.

bPrimExpr

The right hand operand.

spanOptional[Span]

The location of this expression in the source code.

参数:
class tvm.tir.Any(span=None)[源代码]#

Any node.

spanOptional[Span]

The location of this expression in the source code.

Parameters:

span (Span | None)

class tvm.tir.AssertStmt(condition, message, body, span=None)[源代码]#

AssertStmt node.

Parameters#

conditionPrimExpr

The assert condition.

messagePrimExpr

The error message.

bodytvm.tir.Stmt

The body statement.

spanOptional[Span]

The location of the stmt in the source code.

参数:
class tvm.tir.AttrStmt(node, attr_key, value, body, span=None)[源代码]#

AttrStmt node.

Parameters#

nodeObject

The node to annotate the attribute

attr_keystr

Attribute type key.

valuePrimExpr

The value of the attribute

bodyStmt

The body statement.

spanOptional[Span]

The location of the stmt in the source code.

参数:
class tvm.tir.BijectiveLayout[源代码]#

Bijective mapping for two layouts (src-layout and dst-layout). It provides shape and index conversion between each other.

Do not construct directly, use bijective_layout instead. See the documentation of bijective_layout for more details.

Parameters#

src_layoutstr or Layout

source layout.

dst_layoutstr or Layout

destination layout.

See Also#

bijective_layout : Declare a layout

Methods:

backward_index(index)

Given the indices of the dst-layout, infer the src index.

backward_shape(shape)

Given the shape of the dst-layout, infer the src shape.

forward_index(index)

Given the indices of the src-layout, infer the dst index.

forward_shape(shape)

Given the shape of the src-layout, infer the dst shape.

backward_index(index)[源代码]#

Given the indices of the dst-layout, infer the src index.

Parameters#

index: Array of Expr

The indices in dst-layout.

Returns#

src_index: Array of Expr

The inferred indices in src-layout.

backward_shape(shape)[源代码]#

Given the shape of the dst-layout, infer the src shape.

Parameters#

shape: Array of Expr

The shape in dst-layout.

Returns#

src_shape: Array of Expr

The inferred shape in src-layout.

forward_index(index)[源代码]#

Given the indices of the src-layout, infer the dst index.

Parameters#

index: Array of Expr

The indices in src-layout.

Returns#

dst_index: Array of Expr

The inferred indices in dst-layout.

forward_shape(shape)[源代码]#

Given the shape of the src-layout, infer the dst shape.

Parameters#

shape: Array of Expr

The shape in src-layout.

Returns#

dst_shape: Array of Expr

The inferred shape in dst-layout.
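
A minimal usage sketch (the layouts are chosen purely for illustration): create the mapping with bijective_layout, then convert shapes and indices between the two layouts:

import tvm

bl = tvm.tir.bijective_layout("NCHW", "NCHW16c")
print(bl.forward_shape((1, 32, 224, 224)))   # shape in NCHW16c: [1, 2, 224, 224, 16]
print(bl.backward_index([0, 1, 6, 6, 2]))    # indices mapped back to NCHW: [0, 18, 6, 6]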

class tvm.tir.Block(iter_vars, reads, writes, name_hint, body, init=None, alloc_buffers=None, match_buffers=None, annotations=None, span=None)[源代码]#

Block node.

Parameters#

iter_varsList[IterVar]

The block Variable.

readsList[BufferRegion]

The read buffer regions of the block.

writes: List[BufferRegion]

The write buffer regions of the block.

name_hint: str

the name_hint of the block.

body: Stmt

The body of the block.

init: Optional[Stmt]

The init block of the reduction block

alloc_buffers: Optional[list[Buffer]]

The buffer allocations

match_buffers: Optional[List[MatchBufferRegion]]

The subregion buffer match

annotations: Optional[Mapping[str, Object]]

Additional annotation hints.

spanOptional[Span]

The location of this block in the source code.

参数:
class tvm.tir.BlockDependenceInfo(mod)[源代码]#

An object that helps build and query block-level dependences using the two core objects BlockScope and StmtSRef.

The data structures exposed are: 1) sref2scope: mapping from an sref to its corresponding BlockScope; 2) stmt2ref: mapping from blocks to their corresponding StmtSRefs.

Note that this object does not store SRefs to loops, as the purpose is only to expose block-level dependences. This provides the advantage that the scope block (parent block) for a given block sref can be directly accessed as sref->parent.

Methods:

get_block_scope(block_sref)

Get the BlockScope corresponding to the block sref

get_sref(block)

Return the corresponding sref that points to the block

Parameters:

mod (IRModule)

get_block_scope(block_sref)[源代码]#

Get the BlockScope corresponding to the block sref

Parameters#

block_srefStmtSRef

The block sref to be retrieved

Returns#

scopeStmtSRef

The corresponding BlockScope

Parameters:

block_sref (StmtSRef)

Return type:

BlockScope

get_sref(block)[源代码]#

Return the corresponding sref that points to the block

Parameters#

stmtBlock

The block for which the sref is to be retrieved

Returns#

srefStmtSRef

The corresponding sref

Parameters:

block (Block)

Return type:

StmtSRef | None
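
A hedged usage sketch (mod and block are assumptions standing in for an existing IRModule and a tir.Block contained in it):

from tvm import tir

info = tir.BlockDependenceInfo(mod)   # mod: an existing IRModule (assumption)
sref = info.get_sref(block)           # block: a tir.Block contained in mod (assumption)
scope = info.get_block_scope(sref)    # BlockScope tracking producer-consumer dependences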

class tvm.tir.BlockRealize(iter_values, predicate, block, span=None)[源代码]#

BlockRealize node.

Parameters#

iter_valuesList[PrimExpr]

The binding values of the block var.

predicateUnion[PrimExpr, bool]

The predicate of the block.

blockBlock

The block to realize

spanOptional[Span]

The location of this block_realize in the source code.

参数:
class tvm.tir.BlockScope[源代码]#

An object corresponding to each block sref in the sref tree, which tracks the producer-consumer dependencies between blocks.

Glossary:

  • Block scope: A contiguous subtree of the sref tree, rooted at each block sref, whose components are:

    • scope root: a block sref

    • internal srefs: loop srefs

    • scope leaves: block srefs

  • Child block: The scope leaf blocks under the scope root or a specific internal sref

Methods:

get_deps_by_dst(block)

Get all dependencies whose dst is the target block.

get_deps_by_src(block)

Get all dependencies whose src is the target block.

get_deps_by_dst(block)[源代码]#

Get all dependencies whose dst is the target block.

Parameters#

block: StmtSRef

The queried block

Returns#

blocks: List[Dependency]

The dependencies

Parameters:

block (StmtSRef)

Return type:

List[Dependency]

get_deps_by_src(block)[源代码]#

Get all dependencies whose src is the target block.

Parameters#

block: StmtSRef

The queried block

Returns#

blocks: List[Dependency]

The dependencies

Parameters:

block (StmtSRef)

Return type:

List[Dependency]

class tvm.tir.Broadcast(value, lanes, span=None)[源代码]#

Broadcast node.

Parameters#

valuePrimExpr

The value of the expression.

lanesPrimExpr

The lanes of the expression.

spanOptional[Span]

The location of this expression in the source code.
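
For illustration (not part of the original reference), a Broadcast builds a vector expression whose dtype carries the lane count:

from tvm import tir

ones = tir.Broadcast(tir.FloatImm("float32", 1.0), 8)
print(ones.dtype)   # "float32x8"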

参数:
class tvm.tir.Buffer[源代码]#

Symbolic data buffer in TVM.

Buffer provides a way to represent data layout specialization of a data structure in TVM.

Do not construct directly, use decl_buffer() instead. See the documentation of decl_buffer() for more details.

See Also#

decl_buffer : Declare a buffer

Methods:

access_ptr(access_mask[, ptr_type, ...])

Get an access pointer to the head of buffer.

get_flattened_buffer()

Generate a Buffer that is a flattened version of this buffer.

offset_of(indices)

Determine the offset of the provided indices in the flattened buffer.

scope()

Return the storage scope (str) associated with this buffer.

vload(begin[, dtype])

Generate an Expr that loads dtype from begin index.

vstore(begin, value)

Generate a Stmt that stores value into begin index.

access_ptr(access_mask, ptr_type='handle', content_lanes=1, offset=0, extent=None)[源代码]#

Get an access pointer to the head of buffer.

This is the recommended method to get the buffer data pointer when interacting with external functions.

Parameters#

access_maskint

The access pattern MASK. Indicate whether the access will read or write to the data content.

ptr_typestr, optional

The data type of the result pointer. Do not specify unless we want to cast pointer to specific type.

content_lanes: int, optional

The number of lanes for the data type. This value is greater than one for vector types.

offset: Expr, optional

The offset of pointer. We can use it to offset by the number of elements from the address of ptr.

extent: Expr, optional

The extent of pointer.

Examples#

# Get access ptr for read
buffer.access_ptr("r")
# Get access ptr for read/write with bitmask
buffer.access_ptr(Buffer.READ | Buffer.WRITE)
# Get access ptr for read/write with str flag
buffer.access_ptr("rw")
# Get access ptr for read with offset
buffer.access_ptr("r", offset = 100)
# Get access ptr for read with extent
buffer.access_ptr("r", extent = 100)
get_flattened_buffer()[源代码]#

Generate a Buffer that is a flattened version of this buffer.

Returns#

flattenedBuffer

The corresponding flat buffer.

offset_of(indices)[源代码]#

Determine the offset of the provided indices in the flattened buffer.

Parameters#

indices : Union[PrimExpr, List[PrimExpr]]

The indices of the element in the original buffer.

Returns#

flattened_indices: List[PrimExpr]

The offset indices of the element in the flattened buffer.

scope()[源代码]#

Return the storage scope associated with this buffer.

Returns#

scope : str

The storage scope associated with this buffer.

vload(begin, dtype=None)[源代码]#

Generate an Expr that loads dtype from begin index.

Parameters#

beginArray of Expr

The beginning index in unit of Buffer.dtype

dtypestr

The data type to be loaded; it can be a vector type whose lanes are a multiple of the lanes of Buffer.dtype

Returns#

loadExpr

The corresponding load expression.

vstore(begin, value)[源代码]#

Generate a Stmt that stores value into begin index.

Parameters#

beginArray of Expr

The beginning index in unit of Buffer.dtype

valueExpr

The value to be stored.

Returns#

storeStmt

The corresponding store stmt.
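
A minimal sketch tying these methods together with decl_buffer (shapes and values are illustrative):

from tvm import tir

A = tir.decl_buffer((4, 8), "float32", name="A")
load = A.vload([2, 3])                                   # BufferLoad of A[2, 3]
store = A.vstore([2, 3], tir.FloatImm("float32", 1.0))   # BufferStore statement
print(A.offset_of([2, 3]))                               # flattened offset: [19]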

class tvm.tir.BufferLoad(buffer, indices, span=None)[源代码]#

Buffer load node.

Parameters#

bufferBuffer

The buffer to be loaded.

indicesList[PrimExpr]

The buffer indices.

spanOptional[Span]

The location of this expression in the source code.

参数:
class tvm.tir.BufferRealize(buffer, bounds, condition, body, span=None)[源代码]#

Buffer realize node.

Parameters#

bufferBuffer

The buffer.

boundsList[Range]

The bounds of the region to be realized.

conditionPrimExpr

The realize condition.

bodyStmt

The body of the statement.

spanOptional[Span]

The location of the stmt in the source code.

参数:
class tvm.tir.BufferRegion(buffer, region)[源代码]#

BufferRegion node.

Parameters#

bufferBuffer

The buffer of the buffer region

regionList[Range]

The region array of the buffer region

参数:
class tvm.tir.BufferStore(buffer, value, indices, span=None)[源代码]#

Buffer store node.

Parameters#

bufferBuffer

The buffer.

valuePrimExpr

The value to be stored.

indicesList[PrimExpr]

The indices of the location to be stored.

spanOptional[Span]

The location of the stmt in the source code.

参数:
class tvm.tir.Call(dtype, op, args, span=None)[源代码]#

Call node.

Parameters#

dtypestr

The return data type

opUnion[Op, str]

The function to be called, or the name of the global tvm.Op

argslist of Expr

The input arguments to the call

spanOptional[Span]

The location of this expression in the source code.

参数:
class tvm.tir.CallEffectKind[源代码]#

Possible kinds of Call effects.

class tvm.tir.Cast(dtype, value, span=None)[源代码]#

Cast expression.

Parameters#

dtypestr

The data type

valuePrimExpr

The value to be cast.

spanOptional[Span]

The location of this expression in the source code.

参数:
class tvm.tir.CommReducer(lhs, rhs, result, identity_element, span=None)[源代码]#

Commutative reduce operator

Parameters#

lhsList[Var]

The left arguments of the reducer.

rhsList[Var]

The right arguments of the reducer.

resultList[PrimExpr]

The reduction results.

identity_elementList[PrimExpr]

The identity elements.

spanOptional[Span]

The location of this expression in the source code.

参数:
class tvm.tir.DataProducer[源代码]#
class tvm.tir.DeclBuffer(buffer, body, span=None)[源代码]#

DeclBuffer node.

Parameters#

buffer: Buffer

The buffer being declared.

body: Stmt

The body statement to be executed.

span: Optional[Span]

The location of this DeclBuffer in the source code.

参数:
class tvm.tir.Div(a, b, span=None)[源代码]#

Div node.

Parameters#

aPrimExpr

The left hand operand.

bPrimExpr

The right hand operand.

spanOptional[Span]

The location of this expression in the source code.

参数:
class tvm.tir.EQ(a, b, span=None)[源代码]#

EQ node.

Parameters#

aPrimExpr

The left hand operand.

bPrimExpr

The right hand operand.

spanOptional[Span]

The location of this expression in the source code.

参数:
class tvm.tir.Evaluate(value, span=None)[源代码]#

Evaluate node.

Parameters#

valuePrimExpr

The expression to be evaluated.

spanOptional[Span]

The location of the stmt in the source code.

参数:
class tvm.tir.FloatImm(dtype, value, span=None)[源代码]#

Float constant.

Parameters#

dtypestr

The data type

valuefloat

The constant value.

spanOptional[Span]

The location of this expression in the source code.

参数:
class tvm.tir.FloorDiv(a, b, span=None)[源代码]#

FloorDiv node.

Parameters#

aPrimExpr

The left hand operand.

bPrimExpr

The right hand operand.

spanOptional[Span]

The location of this expression in the source code.

参数:
class tvm.tir.FloorMod(a, b, span=None)[源代码]#

FloorMod node.

Parameters#

aPrimExpr

The left hand operand.

bPrimExpr

The right hand operand.

spanOptional[Span]

The location of this expression in the source code.

参数:
class tvm.tir.For(loop_var, min, extent, kind, body, thread_binding=None, annotations=None, span=None)[源代码]#

For node.

Parameters#

loop_varVar

The loop variable.

minPrimExpr

The beginning value.

extentPrimExpr

The length of the loop.

kindForKind

The type of the for.

bodyStmt

The body statement.

thread_binding: Optional[tir.IterVar]

The thread this loop binds to. Only valid if kind is ThreadBinding

annotations: Optional[Mapping[str, Object]]

Additional annotation hints.

spanOptional[Span]

The location of the stmt in the source code.

参数:
class tvm.tir.ForKind(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[源代码]#

The kind of the for loop.

Note#

ForKind can change the control flow semantics of the loop and needs to be considered in all TIR passes.
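
As a hedged sketch of how the kind is used when constructing loops programmatically (the body here is a placeholder Evaluate statement):

from tvm import tir

i = tir.Var("i", "int32")
body = tir.Evaluate(0)                              # placeholder loop body
loop = tir.For(i, 0, 16, tir.ForKind.SERIAL, body)  # a serial loop of extent 16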

class tvm.tir.GE(a, b, span=None)[源代码]#

GE node.

Parameters#

aPrimExpr

The left hand operand.

bPrimExpr

The right hand operand.

spanOptional[Span]

The location of this expression in the source code.

参数:
class tvm.tir.GT(a, b, span=None)[源代码]#

GT node.

Parameters#

aPrimExpr

The left hand operand.

bPrimExpr

The right hand operand.

spanOptional[Span]

The location of this expression in the source code.

参数:
class tvm.tir.IfThenElse(condition, then_case, else_case, span=None)[源代码]#

IfThenElse node.

Parameters#

conditionPrimExpr

The expression

then_caseStmt

The statement to execute if condition is true.

else_caseOptional[Stmt]

The statement to execute if condition is false.

spanOptional[Span]

The location of the stmt in the source code.

参数:
class tvm.tir.IndexMap(initial_indices, final_indices, inverse_index_map)[源代码]#

A mapping from multi-dimensional indices to another set of multi-dimensional indices

Parameters#

initial_indicesList[Var]

Variables representing the indices prior to remapping.

final_indicesList[PrimExpr]

Expressions defining the indices after remapping.

inverse_index_mapUnion[Callable, Optional[IndexMap]]

The optional pre-defined inverse index map. When this is defined, IndexMap::Inverse will return the pre-defined inverse index map. Otherwise, the inverse index map will be computed on the fly. It is the user’s responsibility to ensure the correctness of the pre-defined inverse index map.

Methods:

from_func(mapping_function[, ndim, ...])

Create an index map from a function

from_func_with_separators(mapping_function)

Create an index map from a function

inverse(shape)

Return the inverse of the map

is_equivalent_to(other_map)

Return if the index maps are equivalent.

map_indices(indices)

Apply the index map to a set of indices

map_ndarray(arr_src)

Apply this index map to transform the layout of the input NDArray

map_shape(shape)

Apply the index map to a buffer shape

non_surjective_inverse(shape)

Return the inverse of the map

static from_func(mapping_function, ndim=None, inverse_index_map=None, *, index_dtype='int64')[源代码]#

Create an index map from a function

Parameters#

mapping_function : Callable

The function to map from source indices to target indices. The function should accept tir.Var parameters and return either a tir.PrimExpr or a list of tir.PrimExpr. Returning a tir.PrimExpr is equivalent to returning a list of length 1 containing that tir.PrimExpr.

ndim: Optional[int]

The dimensionality of the buffer to which this transformation should be applied. If mapping_function uses variadic argument *args, ndim must be specified. If mapping_function does not use variadic arguments, ndim is optional.

inverse_index_mapUnion[Callable, Optional[IndexMap]]

The optional pre-defined inverse index map. When this is defined, IndexMap::Inverse will return the pre-defined inverse index map. Otherwise, the inverse index map will be computed on the fly. It is the user’s responsibility to ensure the correctness of the pre-defined inverse index map.

Returns#

index_map: IndexMap

Returns an IndexMap representing the mapping_function.
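
A short sketch (the mapping and shapes are illustrative only):

from tvm import tir

index_map = tir.IndexMap.from_func(lambda i, j: [i // 4, j, i % 4])
print(index_map.map_shape([16, 32]))    # [4, 32, 4]
print(index_map.map_indices([5, 7]))    # [1, 7, 1]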

参数:
static from_func_with_separators(mapping_function, ndim=None, inverse_index_map=None, *, index_dtype='int64')[源代码]#

Create an index map from a function

Parameters#

mapping_function : Callable

The function to map from source indices to target indices. The function should accept tir.Var parameters and return either a tir.PrimExpr or a list. Each element of the returned list should be either a tir.PrimExpr or the object IndexMap.AXIS_SEPARATOR. Returning a tir.PrimExpr is equivalent to returning a list of length 1 containing that tir.PrimExpr.

ndim: Optional[int]

The dimensionality of the buffer to which this transformation should be applied. If mapping_function uses variadic argument *args, ndim must be specified. If mapping_function does not use variadic arguments, ndim is optional.

inverse_index_mapUnion[Callable, Optional[IndexMap]]

The optional pre-defined inverse index map. When this is defined, IndexMap::Inverse will return the pre-defined inverse index map. Otherwise, the inverse index map will be computed on the fly. It is the user’s responsibility to ensure the correctness of the pre-defined inverse index map.

index_dtypestr

The default index dtype to use for input iters in the mapping function.

Returns#

ret: Tuple[IndexMap, List[int]]

Returns a tuple whose first element is an IndexMap representing the mapping_function, and whose second index is a list of indices at which IndexMap.AXIS_SEPARATOR occurred.

参数:
inverse(shape)[源代码]#

Return the inverse of the map

Throws an error if the function is not bijective.

Parameters#

shape: List[Union[Range,PrimExpr]]

The region over which the inverse should be determined. Used for validating that the mapping is bijective over this range.

Returns#

inverse : IndexMap

The inverse

Parameters:

shape (List[Range | PrimExpr])

Return type:

IndexMap

is_equivalent_to(other_map)[源代码]#

Return if the index maps are equivalent.

Parameters#

other_map: IndexMap

The IndexMap to which the comparison should be made.

Returns#

is_equivalent: bool

True if the two mappings represent the same transformation, otherwise False

Parameters:

other_map (IndexMap)

Return type:

bool

map_indices(indices)[源代码]#

Apply the index map to a set of indices

Parameters#

indicesList[PrimExpr]

The indices to be mapped

Returns#

resultList[PrimExpr]

The mapped indices

Parameters:

indices (List[PrimExpr])

Return type:

List[PrimExpr]

map_ndarray(arr_src)[源代码]#

Apply this index map to transform the layout of the input NDArray

Parameters#

arr_srcruntime.NDArray

The NDArray to be transformed

Returns#

arr_dstruntime.NDArray

The transformed NDArray

Parameters:

arr_src (NDArray)

Return type:

NDArray

map_shape(shape)[源代码]#

Apply the index map to a buffer shape

Parameters#

shapeList[PrimExpr]

The buffer shape to be mapped

Returns#

resultList[PrimExpr]

The mapped shape

Parameters:

shape (List[PrimExpr])

Return type:

List[PrimExpr]

non_surjective_inverse(shape)[源代码]#

Return the inverse of the map

Can be applied to transformations that introduce padding.

Parameters#

shape: List[Union[Range,PrimExpr]]

The region over which the inverse should be determined. Used for determining the predicate.

Returns#

result : Tuple[IndexMap, PrimExpr]

The inverse, and a predicate for which the inverse maps to a valid index in the input range.

Examples#

index_map = IndexMap.from_func(lambda i: [i//4, i%4])
inverse_map, predicate = index_map.non_surjective_inverse([14])
assert inverse_map.is_equivalent_to(IndexMap.from_func(lambda j, k: [4 * j + k]))
print(predicate) # Prints "(axis0==3) && (axis2 >= 2)"
Parameters:

shape (List[Range | PrimExpr])

Return type:

Tuple[IndexMap, PrimExpr]

参数:
class tvm.tir.IntImm(dtype, value, span=None)[源代码]#

Int constant.

Parameters#

dtypestr

The data type

valueint

The constant value.

spanOptional[Span]

The location of this expression in the source code.

参数:
class tvm.tir.IterVar(dom, var, iter_type, thread_tag='', span=None)[源代码]#

Represent iteration variable.

IterVar represents axis iterations in the computation.

Parameters#

domRange

The domain of the iteration.

varUnion[Var, str]

The internal variable that is used for iteration.

iter_typeint

The iteration type.

thread_tagstr

The thread type tag.

spanOptional[Span]

The location of this expression in the source code.

See Also#

te.thread_axis: Create thread axis IterVar. te.reduce_axis: Create reduce axis IterVar.
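
A hedged construction sketch (it assumes the iter_type constants such as IterVar.DataPar and IterVar.ThreadIndex, which are class attributes of IterVar):

from tvm import tir

# A data-parallel axis over [0, 16)
i = tir.IterVar((0, 16), "i", tir.IterVar.DataPar)
# A thread-bound axis, comparable to te.thread_axis("threadIdx.x")
tx = tir.IterVar((0, 128), "tx", tir.IterVar.ThreadIndex, "threadIdx.x")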

参数:
class tvm.tir.LE(a, b, span=None)[源代码]#

LE node.

Parameters#

aPrimExpr

The left hand operand.

bPrimExpr

The right hand operand.

spanOptional[Span]

The location of this expression in the source code.

参数:
class tvm.tir.LT(a, b, span=None)[源代码]#

LT node.

Parameters#

aPrimExpr

The left hand operand.

bPrimExpr

The right hand operand.

spanOptional[Span]

The location of this expression in the source code.

参数:
class tvm.tir.Layout[源代码]#

Layout is composed of upper cases, lower cases and numbers, where upper case indicates a primal axis and the corresponding lower case with factor size indicates the subordinate axis. For example, NCHW16c can describe a 5-D tensor of [batch_size, channel, height, width, channel_block]. Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel).

See Also#

layout : Declare a layout

Methods:

factor_of(axis)

Get the factor size of the subordinate axis.

index_of(axis)

Get the index of an axis

factor_of(axis)[源代码]#

Get the factor size of the subordinate axis.

Parameters#

axisstr

The axis name; it needs to be in [a-z, A-Z]

Returns#

factorint

the size of the subordinate-axis of axis (if axis is a primal-axis), or the size of axis itself (if axis is a subordinate-axis). Return -1 if axis is not in the layout.

index_of(axis)[源代码]#

Get the index of an axis

Parameters#

axisstr

The axis name; it needs to be in [a-z, A-Z]

Returns#

indexint

The index of the axis, -1 if not found.
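
A small usage sketch (the layout string is illustrative):

from tvm import tir

layout = tir.layout("NCHW16c")
print(layout.index_of("C"))    # 1: position of the primal axis C
print(layout.index_of("W"))    # 3
print(layout.factor_of("c"))   # 16: factor of the subordinate axis c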

class tvm.tir.Let(var, value, body, span=None)[源代码]#

Let node.

Parameters#

varVar

The variable in the binding.

valuePrimExpr

The value to be bound.

bodyPrimExpr

The body expression.

spanOptional[Span]

The location of this expression in the source code.

参数:
class tvm.tir.LetStmt(var, value, body, span=None)[源代码]#

LetStmt node.

Parameters#

varVar

The variable in the binding.

valuePrimExpr

The value to be bound.

bodyStmt

The body statement.

spanOptional[Span]

The location of the stmt in the source code.

参数:
class tvm.tir.MatchBufferRegion(buffer, source)[源代码]#

MatchBufferRegion node.

Parameters#

bufferBuffer

The target buffer

sourceBufferRegion

The region of source buffer

参数:
class tvm.tir.Max(a, b, span=None)[源代码]#

Max node.

Parameters#

aPrimExpr

The left hand operand.

bPrimExpr

The right hand operand.

spanOptional[Span]

The location of this expression in the source code.

参数:
class tvm.tir.Min(a, b, span=None)[源代码]#

Min node.

Parameters#

aPrimExpr

The left hand operand.

bPrimExpr

The right hand operand.

spanOptional[Span]

The location of this expression in the source code.

参数:
class tvm.tir.Mod(a, b, span=None)[源代码]#

Mod node.

Parameters#

aPrimExpr

The left hand operand.

bPrimExpr

The right hand operand.

spanOptional[Span]

The location of this expression in the source code.

参数:
class tvm.tir.Mul(a, b, span=None)[源代码]#

Mul node.

Parameters#

aPrimExpr

The left hand operand.

bPrimExpr

The right hand operand.

spanOptional[Span]

The location of this expression in the source code.

参数:
class tvm.tir.NE(a, b, span=None)[源代码]#

NE node.

Parameters#

aPrimExpr

The left hand operand.

bPrimExpr

The right hand operand.

spanOptional[Span]

The location of this expression in the source code.

参数:
class tvm.tir.Not(a, span=None)[源代码]#

Not node.

Parameters#

aPrimExpr

The input value

spanOptional[Span]

The location of this expression in the source code.

参数:
class tvm.tir.Or(a, b, span=None)[源代码]#

Or node.

Parameters#

aPrimExpr

The left hand operand.

bPrimExpr

The right hand operand.

spanOptional[Span]

The location of this expression in the source code.

参数:
class tvm.tir.Prefetch(buffer, bounds, span=None)[源代码]#

Prefetch node.

Parameters#

bufferBuffer

The buffer to be prefetched.

boundsList[Range]

The bounds to be prefetched.

spanOptional[Span]

The location of the stmt in the source code.

参数:
class tvm.tir.PrimFunc(params, body, ret_type=None, buffer_map=None, attrs=None, span=None)[源代码]#

A function declaration expression.

Parameters#

params: List[Union[tvm.tir.Var, tvm.tir.Buffer]]

List of input parameters to the function.

body: tvm.tir.Stmt

The body of the function.

ret_type: tvm.ir.Type

The return type annotation of the function.

buffer_mapMap[tvm.tir.Var, tvm.tir.Buffer]

The buffer binding map.

attrs: Optional[tvm.Attrs]

Attributes of the function, can be None

spanOptional[Span]

The location of this PrimFunc in the source code.

Methods:

specialize(param_map)

Specialize parameters of PrimFunc

with_body(new_body[, span])

Create a new PrimFunc with the same signature but a new body.

specialize(param_map)[源代码]#

Specialize parameters of PrimFunc

Parameters#

param_mapMapping[Var, Union[PrimExpr, Buffer]]

The mapping from function params to the instance

Examples#

We can define a Meta TIR function with symbolic shape:

@T.prim_func
def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32) -> None:
    A = T.match_buffer(a, (m, n), "float32")
    B = T.match_buffer(b, (m, n), "float32")

    for i, j in T.grid(m, n):
        with T.block():
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj]

Then we can make it specialized with given shapes or buffers.

a, _, m, n = mem_copy.params
func = mem_copy.specialize({a: tir.decl_buffer((16, 16))})
# or
func = mem_copy.specialize({n: 16, m: 16})

The specialized function:

@T.prim_func
def mem_copy_16_16(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float32")
    B = T.match_buffer(b, (16, 16), "float32")

    for i, j in T.grid(16, 16):
        with T.block():
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj]

Returns#

funcPrimFunc

The new function with the parameters specialized

Parameters:

param_map (Mapping[Var, PrimExpr | Buffer])

with_body(new_body, span=None)[源代码]#

Create a new PrimFunc with the same signature but a new body.

Parameters#

new_bodyStmt

The new body.

spanOptional[Span]

The location of this PrimFunc in the source code.

Returns#

new_funcPrimFunc

The created new function.

Parameters:

span (Span | None)

class tvm.tir.ProducerLoad(producer, indices, span=None)[源代码]#

Producer load node.

Parameters#

producerDataProducer

The buffer to be loaded.

indicesList[PrimExpr]

The buffer indices.

spanOptional[Span]

The location of this expression in the source code.

参数:
class tvm.tir.ProducerRealize(producer, bounds, condition, body, storage_scope='', span=None)[源代码]#

ProducerRealize node.

Parameters#

producerDataProducer

The data producer.

boundsList[Range]

The bound of realize

conditionPrimExpr

The realize condition.

bodyStmt

The realize body

storage_scopestr

The storage scope associated with this realization

spanOptional[Span]

The location of the stmt in the source code.

参数:
class tvm.tir.ProducerStore(producer, value, indices, span=None)[源代码]#

ProducerStore node.

Parameters#

producerDataProducer

The data producer.

valuePrimExpr

The value to be stored.

indiceslist of Expr

The index arguments of the store.

spanOptional[Span]

The location of the stmt in the source code.

参数:
class tvm.tir.Ramp(base, stride, lanes, span=None)[源代码]#

Ramp node.

Parameters#

basePrimExpr

The base expression.

stridePrimExpr

The stride of the ramp.

lanesPrimExpr

The lanes of the expression.

spanOptional[Span]

The location of this expression in the source code.

参数:
class tvm.tir.Reduce(combiner, src, rdom, condition, value_index, init=None, span=None)[源代码]#

Reduce node.

Parameters#

combinerCommReducer

The combiner.

srclist of Expr

The source expression.

rdomlist of IterVar

The iteration domain

conditionPrimExpr

The reduce condition.

value_indexint

The value index.

initlist of Expr

The initial value for output. This can be an int, float or ProducerLoad

spanOptional[Span]

The location of this expression in the source code.

参数:
class tvm.tir.Schedule(mod, *, seed=None, debug_mask='none', error_render_level='detail', enable_check=True)[源代码]#

The user-facing schedule class

A schedule is a set of transformations that change the order of computation but preserve the semantics of computation. Some examples of schedules: 1) Split a loop into two; 2) Reorder two loops; 3) Inline the computation of a specific buffer into its consumer.

The schedule class stores auxiliary information to schedule correctly and efficiently.

Link to tutorial: https://tvm.apache.org/docs/tutorials/language/schedule_primitives.html
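
For orientation, a hedged sketch of a typical interaction (my_func is an assumption standing in for any @T.prim_func or IRModule you want to schedule, and the block name "B" is illustrative):

from tvm import tir

sch = tir.Schedule(my_func)                  # my_func: a PrimFunc or IRModule (assumption)
block = sch.get_block("B")                   # look up a block by name
i, j = sch.get_loops(block)
i_outer, i_inner = sch.split(i, factors=[None, 16])
sch.fuse(i_outer, i_inner)
print(sch.mod["main"].script())              # inspect the transformed TVMScript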

Methods:

__init__(mod, *[, seed, debug_mask, ...])

Construct a TensorIR schedule class from an IRModule

_create_non_traced(mod, *[, seed, ...])

Construct a non-traced TensorIR schedule class from an IRModule.

add_unit_loop(block_or_loop)

Create a new unit loop on top of the specific block or loop.

annotate(block_or_loop, ann_key, ann_val)

Annotate a block/loop with a key value pair

bind(loop, thread_axis)

Bind the input loop to the given thread axis.

blockize(target[, preserve_unit_iters])

Convert multiple blocks or the subtree rooted at a specific loop into a block.

cache_index(block, storage_scope[, cse_thresh])

Create a block to cache precomputed index for later use.

cache_inplace(block, read_buffer_index, ...)

Create blocks that reads & write a buffer region into a cache block.

cache_read(block, read_buffer_index, ...[, ...])

Create a block that reads a buffer region into a read cache.

cache_write(block, write_buffer_index, ...)

Create a block that reads a buffer region into a write cache.

can_decompose_padding(block, loop)

Check whether the block matches the padding pattern and can be decomposed.

compute_at(block, loop[, ...])

Compute-At.

compute_inline(block)

Inline a block into its consumer(s).

copy()

Returns a copy of the schedule, including both the state and the symbol table, guaranteeing that: 1) the SRef tree is completely reconstructed; 2) the IRModule being scheduled is untouched; 3) all the random variables are valid in the copy, pointing to the corresponding reconstructed srefs.

decompose_padding(block, loop)

Decompose a block of padding computation pattern into two separate blocks.

decompose_reduction(block, loop)

Decompose a reduction block into two separate blocks.

enter_postproc()

A no-op that marks the start of postprocessing phase of scheduling

fork_seed()

Returns a forked random state as seed for new schedules

fuse(*loops[, preserve_unit_iters])

Fuse a list of consecutive loops into one.

get(rand_var_or_sref)

Returns the corresponding Block that a BlockRV evaluates to, the corresponding For that a LoopRV evaluates to, the corresponding integer that an ExprRV evaluates to, the corresponding Block that a block sref points to, or the corresponding For that a loop sref points to.

get_block(name[, func_name])

Retrieve a block in a specific function with its name

get_child_blocks(block_or_loop)

Get the leaf blocks of a specific block/loop

get_consumers(block)

Get the consumers of a specific block

get_loops(block)

Get the parent loops of the block in its scope, from outer to inner

get_output_blocks(scope_block)

Get the list of output blocks within the given scope. An output block is a block that has at least one buffer being written to, but is not allocated within the PrimFunc.

get_producers(block)

Get the producers of a specific block

get_sref(rand_var_or_stmt)

Returns the corresponding sref to the given 1) LoopRV 2) BlockRV 3) Block 4) For

loop_partition(loop, factors[, ...])

Partition a loop into a list of consecutive loops.

merge(*loops)

Merge a list of loops into one.

pad_einsum(block, padding)

Pad the computation of Einsum.

parallel(loop)

Parallelize the input loop.

reindex(block, buffer)

Create a block that reads/writes a buffer region into a read/write cache with reindexing.

reindex_cache_read(block, read_buffer_index, ...)

Create a block that reads a buffer region into a read cache using customized indices specified by index map.

reindex_cache_write(block, ...)

Create a block that reads a buffer region into a write cache using customized indices specified by index map.

remove_rv(rand_var)

Remove a random variable from the symbol table

reorder(*ordered_loops)

Reorder a list of loops.

reorder_block_iter_var(block, new_order)

Reorder the itervars inside a given block.

reverse_compute_at(block, loop[, ...])

Reverse-Compute-At.

reverse_compute_inline(block)

Inline a block into its only producer.

rfactor(loop, factor_axis)

Factorize an associative reduction block by the specified loop.

rolling_buffer(block, write_buffer_index)

Compute the target buffer via rolling buffering: select the outermost rollable axis with a positive bound overlap that appears in the block's ancestor loops as the rolling axis, fold and circularize the buffer along the rolling dimension, and append a block predicate to avoid recomputing overlapping elements.

sample_categorical(candidates, probs[, decision])

Sample an integer given the probability distribution

sample_compute_location(block[, decision])

Sample a compute-at location of the given block

sample_partitioned_tile(loop, n[, ...])

Sample the factors to a partitioned tile for a specific loop

sample_perfect_tile(loop, n[, ...])

Sample the factors to perfect tile a specific loop

seed(seed)

Seed the randomness

set_axis_separator(block, buffer, ...)

Set the axis separator of a buffer, where the buffer is specified by a block and a read or write index.

set_scope(block, buffer_index, storage_scope)

Set the storage scope of a buffer, where the buffer is specified by a block and a write-index.

show(*args, **kwargs)

A sugar for printing highlighted TVM script.

split(loop, factors[, preserve_unit_iters, ...])

Split a loop into a list of consecutive loops.

storage_align(block, buffer_index, axis, ...)

Set alignment requirement for specific dimension such that stride[axis] == k * factor + offset for some k.

tensorize(block_or_loop, tensor_intrin[, ...])

Tensorize the computation enclosed by loop with the tensor intrinsic.

transform_block_layout(block, index_map)

Apply a transformation represented by IndexMap to block

transform_layout(block, buffer, index_map[, ...])

Apply a transformation represented by IndexMap to buffer

unannotate(block_or_loop, ann_key)

Unannotate a block/loop's annotation with key ann_key

unroll(loop)

Unroll the input loop.

unsafe_hide_buffer_access(block, buf_type, ...)

Hide some buffer access in a given block.

unsafe_set_dtype(block, buffer_index, dtype)

Set the data type of a buffer, where the buffer is specified by a block and a write-index.

vectorize(loop)

Vectorize the input loop.

work_on(func_name)

Instruct the schedule to work on a function in the IRModule.

Attributes:

func_working_on

Returns the GlobalVar of the func that the schedule is currently working on

mod

Returns the AST of the module being scheduled

state

Returns the ScheduleState in the current schedule class

trace

Returns the internally maintained trace of scheduling program execution

参数:
__init__(mod, *, seed=None, debug_mask='none', error_render_level='detail', enable_check=True)[源代码]#

Construct a TensorIR schedule class from an IRModule

Parameters#

modUnion[PrimFunc, IRModule]

The IRModule or PrimFunc to be scheduled

seed: Optional[int]

The seed value for schedule’s random state Note that None and -1 means use device random, otherwise only integer between 1 and 2147483647 is allowed.

debug_maskUnion[str, int]

Do extra correctness checking after the class creation and each time after calling the Replace method. Possible choices of debug_mask: 1) “all” - Turn on all the checks 2) “none” - Turn off all the checks 3) An integer - Turn on checks according to the bitmasks provided in ScheduleDebugMask

error_render_levelstr = “detail”

The level of error rendering. Choices: “detail”, “fast”, “none”. - “detail”: Render a detailed error message, with the TIR and error locations printed - “fast”: Show a simple error message without rendering or string manipulation - “none”: Do not show any error message.

enable_checkbool = True

The default schedule checks are too strict and might prevent us from performing some valid schedules. enable_check is an argument to control whether we enable prerequisite checks for some schedule primitives or not: - true: perform prerequisite checks before applying some schedules. - false: do not perform some checks before applying schedules, but still raise an error if a schedule fails.

It is the user's duty to guarantee schedule correctness if enable_check is set to False.

Note#

The checks performed include: 1) VerifySRefTree 2) VerifyCachedFlags

Return type:

None

static _create_non_traced(mod, *, seed=None, debug_mask='none', error_render_level='detail', enable_check=True)[源代码]#

Construct a non-traced TensorIR schedule class from an IRModule.

Return type:

Schedule

add_unit_loop(block_or_loop)[源代码]#

Create a new unit loop on top of the specific block or loop.

Parameters#

block_or_loopUnion[LoopRV, BlockRV]

The block above which the new loop is created

Returns#

new_loopLoopRV

The new unit loop

Examples#

Before add_unit_loop, in TensorIR, the IR is:

@T.prim_func
def before_add_unit_loop(
    A: T.Buffer((), "int32"),
    B: T.Buffer((), "int32"),
    C: T.Buffer((), "int32"),
) -> None:
    with T.block("C"):
        vi = T.axis.spatial(1, 0)
        C[()] = A[()] + B[()]

Create the schedule and do add-unit-loop:

sch = tir.Schedule(before_add_unit_loop)
sch.add_unit_loop(sch.get_block("C"))
print(sch.mod["main"].script())

After applying add-unit-loop, the IR becomes:

@T.prim_func
def after_add_unit_loop(
    A: T.Buffer((), "int32"),
    B: T.Buffer((), "int32"),
    C: T.Buffer((), "int32"),
) -> None:
    for u in T.serial(1):
        with T.block("C"):
            vi = T.axis.spatial(1, 0)
            C[()] = A[()] + B[()]
Parameters:

block_or_loop (LoopRV | BlockRV)

Return type:

LoopRV

annotate(block_or_loop, ann_key, ann_val)[源代码]#

Annotate a block/loop with a key value pair

Parameters#

block_or_loop: Union[BlockRV, LoopRV]

The block/loop to be annotated

ann_keystr

The annotation key

ann_valAnnotationValueT

The annotation value

Examples#

Before annotate, in TensorIR, the IR is:

@T.prim_func
def before_annotate(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and do annotate:

sch = tir.Schedule(before_annotate)
sch.annotate(sch.get_block("B"), "ann_key", "ann_value")
print(sch.mod["main"].script())

After applying annotate, the IR becomes:

@T.prim_func
def after_annotate(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.block_attr({"ann_key", "ann_value"})
            B[vi, vj] = A[vi, vj] * 2.0
参数:
返回类型:

None

bind(loop, thread_axis)[源代码]#

Bind the input loop to the given thread axis. It requires: 1) The scope block that the loop is in should have the stage-pipeline property. 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine bindings. 3) For each block under the loop, if the thread axis starts with "threadIdx", the loop can only be contained in data-parallel block iters' and reduction block iters' bindings. Otherwise the loop can only be contained in data-parallel block iters' bindings.

Parameters#

loopLoopRV

The loop to be bound to the thread axis

thread_axisstr

The thread axis to be bound to the loop. Possible candidates: blockIdx.x/y/z, threadIdx.x/y/z, vthread.x/y/z, and vthread (a legacy behavior that will be deprecated; please use vthread.x/y/z instead).

Examples#

Before bind, in TensorIR, the IR is:

@T.prim_func
def before_bind(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and do bind:

sch = tir.Schedule(before_bind)
i, j = sch.get_loops(sch.get_block("B"))
sch.bind(i, "blockIdx.x")
sch.bind(j, "threadIdx.x")

After applying bind, the IR becomes:

@T.prim_func
def after_bind(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i in T.thread_binding(0, 128, thread = "blockIdx.x"):
        for j in T.thread_binding(0, 128, thread = "threadIdx.x"):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] * 2.0
参数:
  • loop (LoopRV)

  • thread_axis (str)

返回类型:

None

blockize(target, preserve_unit_iters=True)[源代码]#

Convert multiple blocks or the subtree rooted at a specific loop into a block.

Parameters#

targetLoopRV or List[BlockRV]

The root of the subtree or the specified blocks.

preserve_unit_itersbool

Whether or not to preserve unit iterators in block bindings

Returns#

resultBlockRV

The new block.

Examples#

Before blockize, in TensorIR, the IR is:

@T.prim_func
def before_blockize(
    A: T.Buffer((128, 128), "float32"),
    B: T.Buffer((128, 128), "float32")
) -> None:
    for i_0, j_0, i_1, j_1 in T.grid(8, 8, 16, 16):
        with T.block("B"):
            vi = T.axis.spatial(128, i_0 * 16 + i_1)
            vj = T.axis.spatial(128, j_0 * 16 + j_1)
            T.reads(A[vi, vj])
            T.writes(B[vi, vj])
            B[vi, vj] = A[vi, vj] * T.float32(2)

Create the schedule and do blockize:

sch = tir.Schedule(before_blockize)
B = sch.get_block("B")
_, _, i1, _ = sch.get_loops(B)
sch.blockize(i1)
print(sch.mod["main"].script())

After applying blockize, the IR becomes:

@T.prim_func
def after_blockize(
    A: T.Buffer((128, 128), "float32"),
    B: T.Buffer((128, 128), "float32")
)-> None:
    for i_0, j_0 in T.grid(8, 8):
        with T.block("B_o"):
            vio, vjo = T.axis.remap("SS", [i_0, j_0])
            T.reads(A[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16])
            T.writes(B[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16])
            for i_1, j_1 in T.grid(16, 16):
                with T.block("B"):
                    vi, vj = T.axis.remap("SS", [i_1, j_1])
                    T.reads(A[vio * 16 + vi, vjo * 16 + vj])
                    T.writes(B[vio * 16 + vi, vjo * 16 + vj])
                    B[vio * 16 + vi, vjo * 16 + vj] = (
                        A[vio * 16 + vi, vjo * 16 + vj] * T.float32(2)
                    )

Note#

blockize requires that there is exactly one block under the given loop, and that the bindings of the block are divisible by the subspace represented by the loops starting at the given loop.

参数:
  • target (LoopRV | List[BlockRV])

  • preserve_unit_iters (bool)

返回类型:

BlockRV

cache_index(block, storage_scope, cse_thresh=0)[源代码]#

Create a block to cache precomputed indices for later use. If there is no index computation, the IR is kept unchanged.

Parameters#

blockUnion[BlockRV, str]

The target block operates on the target buffer.

storage_scope: str

The storage scope of cached block.

cse_thresh: int

The repeat threshold that determines a common subexpression; the default 0 means caching all index computations.

Returns#

cached_blocksList[BlockRV]

The blocks of the stage writing the cache buffers

Examples#

Before cache_index, in TensorIR, the IR is:

@T.prim_func
def resize(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (1, 3, 40, 40))
    B = T.match_buffer(b, (1, 3, 80, 80))
    for i0, i1, i2, i3 in T.grid(1, 3, 80, 80):
        with T.block("A"):
            n, c, vi, vj = T.axis.remap("SSSS", [i0, i1, i2, i3])
            B[n, c, vi, vj] = A[n, c, vi//4 + vj//4, vj//2]

Create the schedule and cache_index:

sch = tir.Schedule(resize)
block_a = sch.get_block("A")
sch.cache_index(block_a, "global", 1)
print(sch.mod["main"].script())

After applying cache_index, the IR becomes:

@T.prim_func
def resize_cache_index(
    A: T.Buffer((1, 3, 40, 40), "float32"), B: T.Buffer((1, 3, 80, 80), "float32")
) -> None:
    index_var_0 = T.alloc_buffer([80, 80], dtype="int32", strides=[1])
    index_var_1 = T.alloc_buffer([80], dtype="int32", strides=[1])
    for ax0, ax1 in T.grid(80, 80):
        with T.block("index_0"):
            v0 = T.axis.spatial(80, ax0)
            v1 = T.axis.spatial(80, ax1)
            T.reads()
            T.writes(index_var_0[v0, v1])
            index_var_0[v0, v1] = v0 // 4 + v1 // 4
    for ax0 in T.serial(80):
        with T.block("index_1"):
            v0 = T.axis.spatial(80, ax0)
            T.reads()
            T.writes(index_var_1[v0])
            index_var_1[v0] = v0 // 2
    for i0, i1, i2, i3 in T.grid(1, 3, 80, 80):
        with T.block("A"):
            n, c, vi, vj = T.axis.remap("SSSS", [i0, i1, i2, i3])
            T.reads(A[n, c, vi // 4 + vj // 4, vj // 2])
            T.writes(B[n, c, vi, vj])
            B[n, c, vi, vj] = A[n, c, index_var_0[vi, vj], index_var_1[vj]]
参数:
  • block (BlockRV | str)

  • storage_scope (str)

  • cse_thresh (int)

返回类型:

List[BlockRV]

cache_inplace(block, read_buffer_index, storage_scope)[源代码]#

Create blocks that reads & write a buffer region into a cache block. It requires the target block both read & write the target buffer. Mainly for inplace operation.

Parameters#

blockUnion[BlockRV, str]

The target block operates on the target buffer.

read_buffer_index: int

The index of the buffer in block's read region, the unique name of a read buffer in the block, or a Buffer object that is within the block's read region.

storage_scope: str

The target storage scope.

Returns#

cached_blocksList[BlockRV]

The blocks of the cache stage, read cache first, write cache second

Examples#

Before cache_inplace, in TensorIR, the IR is:

@T.prim_func
def before_cache_inplace(data_io: T.Buffer((64), "int32")):
    for i0 in T.serial(1):
        with T.block("A"):
            T.reads(data_io[:64])
            T.writes(data_io[:64])
            T.evaluate(T.call_extern("call_impl", data_io.data, dtype=""))

Create the schedule and cache_inplace:

sch = tir.Schedule(before_cache_inplace)
block_a = sch.get_block("A")
sch.cache_inplace(block_a, 0, "local")
print(sch.mod["main"].script())

After applying cache_inplace, the IR becomes:

@T.prim_func
def cache_inplace(data_io: T.Buffer(64, "int32")) -> None:
    data_io_local = T.alloc_buffer([64], dtype="int32", scope="local")
    for i0 in T.serial(1):
        for ax0 in T.serial(64):
            with T.block("data_io_local"):
                v0 = T.axis.spatial(64, ax0)
                T.reads(data_io[v0])
                T.writes(data_io_local[v0])
                data_io_local[v0] = data_io[v0]
        with T.block("A"):
            T.reads(data_io_local[0 : 64])
            T.writes(data_io_local[0 : 64])
            T.evaluate(T.call_extern("call_impl", data_io_local.data, dtype=""))
        for ax0 in T.serial(64):
            with T.block("data_io_local"):
                v0 = T.axis.spatial(64, ax0)
                T.reads(data_io_local[v0])
                T.writes(data_io[v0])
                data_io[v0] = data_io_local[v0]
参数:
返回类型:

List[BlockRV]

cache_read(block, read_buffer_index, storage_scope, consumer_blocks=None)[源代码]#

Create a block that reads a buffer region into a read cache. It requires:

  1. There is at most one block that writes the buffer in the scope.

  2. The scope block has the stage-pipeline property.

Parameters#

blockUnion[BlockRV, str]

The consumer block of the target buffer.

read_buffer_index: Union[int, str, Buffer]

The index of the buffer in block's read region, the unique name of a read buffer in the block, or a Buffer object that is within the block's read region.

storage_scope: str

The target storage scope.

consumer_blocks: Optional[List[Union[BlockRV, str]]]

An optional list of consumers that should read from the cache. If not specified, all consumers will use the cache.

Returns#

cached_blockBlockRV

The block of the cache stage

Examples#

Before cache_read, in TensorIR, the IR is:

@T.prim_func
def before_cache_read(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and cache_read:

sch = tir.Schedule(before_cache_read)
block_b = sch.get_block("B")
sch.cache_read(block_b, 0, "local")
print(sch.mod["main"].script())

After applying cache_read, the IR becomes:

@T.prim_func
def after_cache_read(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    A_local = T.alloc_buffer((128, 128), scope="local")
    for i, j in T.grid(128, 128):
        with T.block("A_local"):
            vi, vj = T.axis.remap("SS", [i, j])
            A_local[vi, vj] = A[vi, vj]
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A_local[vi, vj] * 2.0
参数:
  • block (BlockRV | str)

  • read_buffer_index (int | str | Buffer)

  • storage_scope (str)

  • consumer_blocks (List[BlockRV | str] | None)

返回类型:

BlockRV

cache_write(block, write_buffer_index, storage_scope, consumer_blocks=None)[源代码]#

Create a block that reads a buffer region into a write cache. It requires:

  1. There is only one block that writes the buffer in the scope.

  2. The scope block has the stage-pipeline property.

Parameters#

blockUnion[BlockRV, str]

The producer block of the target buffer.

write_buffer_index: int

The index of the buffer in block's write region, the unique name of a write buffer in the block, or a Buffer object that is within the block's write region.

storage_scope: str

The target storage scope.

consumer_blocks: Optional[List[Union[BlockRV, str]]]

An optional list of consumers that should read directly from the cache. If not specified, all consumers will read from the original buffer.

Returns#

cached_blockBlockRV

The block of the cache stage

Examples#

Before cache_write, in TensorIR, the IR is:

@T.prim_func
def before_cache_write(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and cache_write:

sch = tir.Schedule(before_cache_write)
block_b = sch.get_block("B")
sch.cache_write(block_b, 0, "local")
print(sch.mod["main"].script())

After applying cache_write, the IR becomes:

@T.prim_func
def after_cache_write(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    B_local = T.alloc_buffer((128, 128), scope="local")
    for i, j in T.grid(128, 128):
        with T.block("A_local"):
            vi, vj = T.axis.remap("SS", [i, j])
            B_local[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = B_local[vi, vj]
参数:
  • block (BlockRV | str)

  • write_buffer_index (int | str | Buffer)

  • storage_scope (str)

  • consumer_blocks (List[BlockRV | str] | None)

返回类型:

BlockRV

can_decompose_padding(block, loop)[源代码]#

Check whether the block matches the padding pattern and can be decomposed.
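
Examples#

A small sketch (not part of the upstream docstring), mirroring the padding workload used in the decompose_padding example further below, to guard decompose_padding behind this check:

from tvm import tir
from tvm.script import tir as T

@T.prim_func
def pad_func(x: T.Buffer(128, "int32"), y: T.Buffer(140, "int32")):
    # same padding pattern as in the decompose_padding example below
    for i in range(140):
        with T.block("block"):
            vi = T.axis.remap("S", [i])
            y[vi] = T.if_then_else(vi >= 6 and vi < 134, x[vi - 6], 0, dtype="int32")

sch = tir.Schedule(pad_func)
block = sch.get_block("block")
loop = sch.get_loops(block)[0]
if sch.can_decompose_padding(block, loop):  # only decompose when the pattern matches
    sch.decompose_padding(block, loop)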

参数:
  • block (BlockRV | str)

  • loop (LoopRV)

返回类型:

bool

compute_at(block, loop, preserve_unit_loops=False, index=-1)[源代码]#

Compute-At. Move a producer block under the specific loop, and regenerate the loops induced by the block so that the buffer region produced by the producer block could cover those regions consumed by its consumer blocks under the given loop. It requires:

  1. block and loop are under the same scope, loop is not the ancestor of block

  2. The scope block has stage-pipeline property

  3. The subtree of the scope block, where the given block is in, satisfies the compact dataflow condition, i.e. all the blocks in the scope block's subtree must be either complete blocks or reduction blocks

  4. The block is not an output block with regard to the scope block, i.e. the buffers written by the block are allocated under the scope block

  5. All the consumers of the block are under the given loop

Parameters#

blockUnion[BlockRV, str]

The block to be moved

loop: LoopRV

The loop where the block to be moved under

preserve_unit_loops: bool

Whether to keep the trivial loops whose extents are 1

index: int

The block index of the loop body subtree blocks: - index = -1 means inserted into the last possible insertion point; - index = -2 means inserted into the first possible insertion point; - Otherwise, index is a nonnegative number that indicates the insertion point

Examples#

Before compute-at, in TensorIR, the IR is:

@T.prim_func
def before_compute_at(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128), "float32")
    B = T.alloc_buffer((128, 128), "float32")
    C = T.match_buffer(c, (128, 128), "float32")
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0

Create the schedule and do compute-at:

sch = tir.Schedule(before_compute_at)
block = sch.get_block("B")
loop, _ = sch.get_loops(sch.get_block("C"))
sch.compute_at(block, loop, preserve_unit_loops=False)
print(sch.mod["main"].script())

After applying compute-at, the IR becomes:

@T.prim_func
def after_compute_at(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128), "float32")
    B = T.alloc_buffer((128, 128), "float32")
    C = T.match_buffer(c, (128, 128), "float32")
    for i in T.serial(0, 128):
        for j in T.serial(0, 128):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] * 2.0
        for j in T.serial(0, 128):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = B[vi, vj] + 1.0
参数:
  • block (BlockRV | str)

  • loop (LoopRV)

  • preserve_unit_loops (bool)

  • index (int)

返回类型:

None

compute_inline(block)[源代码]#

Inline a block into its consumer(s). It requires:

  1. The block is a complete non-root block, which only produces one buffer

  2. The block must not be the only leaf in the scope.

  3. The body of the block must be a BufferStore statement in the form of, A[i, j, k, ...] = ... where the indices of the LHS are all distinct atomic variables, and no variables other than those indexing variables are allowed in the statement.

Parameters#

blockUnion[BlockRV, str]

The block to be inlined to its consumer(s)

Examples#

Before compute-inline, in TensorIR, the IR is:

@T.prim_func
def before_inline(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.alloc_buffer((128, 128))
    C = T.match_buffer(c, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0

Create the schedule and do compute-inline:

sch = tir.Schedule(before_inline)
sch.compute_inline(sch.get_block("B"))
print(sch.mod["main"].script())

After applying compute-inline, the IR becomes:

@T.prim_func
def after_inline(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    C = T.match_buffer(c, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = A[vi, vj] * 2.0 + 1.0
参数:

block (BlockRV | str)

返回类型:

None

copy()[源代码]#

Returns a copy of the schedule, including both the state and the symbol table, guaranteeing that: 1) the SRef tree is completely reconstructed; 2) the IRModule being scheduled is untouched; 3) all the random variables are valid in the copy, pointing to the corresponding reconstructed srefs.

Returns#

copySchedule

A new copy of the schedule
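
Examples#

An illustrative sketch (not part of the upstream docstring), reusing the before_fuse PrimFunc defined in the fuse example further below; mutating the original schedule leaves the copy untouched:

import tvm
from tvm import tir

sch = tir.Schedule(before_fuse)   # before_fuse: element-wise PrimFunc from the fuse example
snapshot = sch.copy()             # independent state and symbol table
i, j = sch.get_loops(sch.get_block("B"))
sch.fuse(i, j)                    # mutates only the original schedule
assert not tvm.ir.structural_equal(sch.mod, snapshot.mod)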

返回类型:

Schedule

decompose_padding(block, loop)[源代码]#

Decompose a block of padding computation pattern into two separate blocks.

  1. The block which fill const pad values into full write region;

  2. The block which fill in-bound values into region where pad predicate is true.

The pad value filling block is inserted right before the given loop.

The schedule primitive requires:

  1. The input block is a complete block.

  2. The input loop is the ancestor of the block.

  3. The input block is a block which match padding pattern.

Parameters#

blockUnion[BlockRV, str]

The padding block to be decomposed.

loopLoopRV

The loop before which the pad value filling block is inserted.

Returns#

pad_value_blockBlockRV

The block filling const pad values.

Examples#

Before decompose-padding, in TensorIR, the IR is:

@T.prim_func
def before_decompose(x: T.Buffer(128, "int32"), y: T.Buffer(140, "int32")):
    for i in range(140):
        with T.block("block"):
            vi = T.axis.remap("S", [i])
            y[vi] = T.if_then_else(vi >= 6 and vi < 134, x[vi - 6], 0, dtype="int32")

Create the schedule and do decompose-padding with specified loop:

sch = tir.Schedule(before_decompose, debug_mask="all")
block = sch.get_block("block")
sch.decompose_padding(block, sch.get_loops(block)[0])
print(sch.mod["main].script())

After applying decompose-padding, the IR becomes:

@T.prim_func
def after_decompose(x: T.Buffer(128, "int32"), y: T.Buffer(140, "int32")):
    for i in T.serial(140):
        with T.block("block_pad_const"):
            vi = T.axis.spatial(140, i)
            y[vi] = 0
    for i in T.serial(128):
        with T.block("block"):
            vi = T.axis.spatial(128, i)
            y[vi + 6] = x[vi]
参数:
  • block (BlockRV | str)

  • loop (LoopRV)

返回类型:

BlockRV

decompose_reduction(block, loop)[源代码]#

Decompose a reduction block into two separate blocks.

  1. The init block, which is translated from the init statement of the reduction block;

  2. The update block, which is the original block without init statement.

The init block is inserted right before the given loop.

The schedule primitive requires:

  1. The input block is a reduction block.

  2. The input loop is the ancestor of the block.

  3. The input loop is not lower than all the loops related to reduce block var.

Parameters#

blockUnion[BlockRV, str]

The reduction block to be decomposed

loopLoopRV

The loop before which the init block is inserted.

Returns#

init_blockBlockRV

The init block

Examples#

Before decompose-reduction, in TensorIR, the IR is:

@T.prim_func
def before_decompose(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, [128, 128])
    B = tir.match_buffer(b, [128, 128])
    C = tir.match_buffer(c, [128, 128])
    for i, j, k in tir.grid(128, 128, 128):
        with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]:
            with tir.init():
                C[vi, vj] = 0.0
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]

Create the schedule and do decompose-reduction with specified loop:

sch = tir.Schedule(before_decompose)
C = sch.get_block("C")
i, j, k = sch.get_loops(C)
sch.decompose_reduction(C, i)
print(sch.mod["main"].script())

After applying decompose-reduction, the IR becomes:

@T.prim_func
def after_decompose(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, [128, 128])
    B = tir.match_buffer(b, [128, 128])
    C = tir.match_buffer(c, [128, 128])
    for i in tir.serial(128):
        for j in tir.serial(128):
            with tir.block([128, 128]) as [vi, vj]:
                C[vi, vj] = 0.0
    for i, j, k in tir.grid(128, 128, 128):
        with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]:
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
参数:
  • block (BlockRV | str)

  • loop (LoopRV)

返回类型:

BlockRV

enter_postproc()[源代码]#

A no-op that marks the start of the post-processing phase of scheduling

返回类型:

None

fork_seed()[源代码]#

Returns a forked random state as seed for new schedules

Returns#

seedint

The forked random state, not the same as the current random state
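
Examples#

A brief sketch (not part of the upstream docstring): use the forked seed to create a second schedule whose sampling is decorrelated from the first; mod stands for any IRModule or PrimFunc to be scheduled:

from tvm import tir

sch_a = tir.Schedule(mod, seed=42)                 # `mod` is an assumed input module
sch_b = tir.Schedule(mod, seed=sch_a.fork_seed())  # independent random state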

返回类型:

int

fuse(*loops, preserve_unit_iters=True)[源代码]#

Fuse a list of consecutive loops into one. It requires: 1) The loops can’t have annotations or thread bindings. 2) The (i+1)-th loop must be the only child of the i-th loop. 3) All loops must start with 0. 4) The domain of a loop to be fused cannot depend on another loop to be fused.

Parameters#

*loopsList[LoopRV]

The loops to be fused

Returns#

fused_loopLoopRV

The new loop after fusion

Examples#

Before applying fuse, in TensorIR, the IR is:

@T.prim_func
def before_fuse(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and do fuse:

sch = tir.Schedule(before_fuse)
i, j = sch.get_loops(sch.get_block("B"))
sch.fuse(i, j)
print(sch.mod["main"].script())

After applying fuse, the IR becomes:

@T.prim_func
def after_fuse(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    # the 2 loops are fused into 1
    for i_j_fused in T.serial(0, 16384):
        with T.block("B"):
            vi = T.axis.S(128, T.floordiv(i_j_fused, 128))
            vj = T.axis.S(128, T.floormod(i_j_fused, 128))
            B[vi, vj] = A[vi, vj] * 2.0
参数:
  • loops (List[LoopRV])

  • preserve_unit_iters (bool)

返回类型:

LoopRV

get(rand_var_or_sref)[源代码]#

Returns the object that the given random variable or sref refers to: the corresponding Block that a BlockRV evaluates to; the corresponding For that a LoopRV evaluates to; the corresponding integer that an ExprRV evaluates to; the corresponding Block that a block sref points to; or the corresponding For that a loop sref points to.

Parameters#

rand_var_or_srefUnion[ExprRV, BlockRV, LoopRV, StmtSRef]

The random variable / sref to be evaluated

Returns#

resultOptional[Union[int, Block, For]]

The corresponding result
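
Examples#

A short sketch (not part of the upstream docstring), reusing the before_fuse PrimFunc from the fuse example above, showing how random variables evaluate back to IR nodes:

from tvm import tir

sch = tir.Schedule(before_fuse)
block_rv = sch.get_block("B")
loop_rv = sch.get_loops(block_rv)[0]
assert isinstance(sch.get(block_rv), tir.Block)  # BlockRV -> Block
assert isinstance(sch.get(loop_rv), tir.For)     # LoopRV  -> For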

参数:

rand_var_or_sref (PrimExpr | BlockRV | LoopRV | StmtSRef)

返回类型:

int | Block | For | None

get_block(name, func_name=None)[源代码]#

Retrieve a block in a specific function with its name

By default, if func_name is not specified, the schedule will search for the block in the function that is currently being “worked on”. To switch the function to be worked on, use work_on before calling this method.

Parameters#

namestr

The name of the block

func_nameOptional[str] = None

The name of the function

Returns#

blockBlockRV

The block retrieved. An IndexError is raised if zero or multiple blocks exist with the specified name.

参数:
  • name (str)

  • func_name (str | None)

返回类型:

BlockRV

get_child_blocks(block_or_loop)[源代码]#

Get the leaf blocks of a specific block/loop

Parameters#

block_or_loopUnion[BlockRV, LoopRV]

The query block/loop

Returns#

blocksList[BlockRV]

A list of leaf blocks inside a specific block/loop
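
Examples#

A short sketch (not part of the upstream docstring), reusing the before_compute_at PrimFunc from the compute_at example above; the root block has the two leaf blocks "B" and "C" as children:

from tvm import tir

sch = tir.Schedule(before_compute_at)
leaf_blocks = sch.get_child_blocks(sch.get_block("root"))  # BlockRVs for "B" and "C"
i, _ = sch.get_loops(sch.get_block("B"))
assert len(sch.get_child_blocks(i)) == 1                   # only "B" sits under loop i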

参数:

block_or_loop (BlockRV | LoopRV)

返回类型:

List[BlockRV]

get_consumers(block)[源代码]#

Get the consumers of a specific block

Parameters#

blockUnion[BlockRV, str]

The block in the query

Returns#

consumersList[BlockRV]

A list of consumers of the given block
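
Examples#

A short sketch (not part of the upstream docstring), reusing the before_compute_at PrimFunc from the compute_at example above, where block "C" consumes the buffer produced by block "B":

from tvm import tir

sch = tir.Schedule(before_compute_at)
(consumer,) = sch.get_consumers(sch.get_block("B"))
assert sch.get(consumer).name_hint == "C"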

参数:

block (BlockRV | str)

返回类型:

List[BlockRV]

get_loops(block)[源代码]#

Get the parent loops of the block in its scope, from outer to inner

Parameters#

blockUnion[BlockRV, str]

The query block

Returns#

loopsList[LoopRV]

A list of loops above the given block in its scope, from outer to inner

参数:

block (BlockRV | str)

返回类型:

List[LoopRV]

get_output_blocks(scope_block)[源代码]#

Get the list of output blocks within the given scope. An output block is a block that writes to at least one buffer that is not allocated within the PrimFunc.

Parameters#

scope_blockUnion[BlockRV, str],

The scope block from which output blocks are collected

Returns#

output_blocksList[BlockRV]

A list of all blocks that write to some output buffer
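
Examples#

A short sketch (not part of the upstream docstring), reusing the before_compute_at PrimFunc from the compute_at example above; buffer B is allocated inside the PrimFunc, so only block "C" writes to an output buffer:

from tvm import tir

sch = tir.Schedule(before_compute_at)
outputs = sch.get_output_blocks(sch.get_block("root"))
assert [sch.get(b).name_hint for b in outputs] == ["C"]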

参数:

scope_block (BlockRV | str)

返回类型:

List[BlockRV]

get_producers(block)[源代码]#

Get the producers of a specific block

Parameters#

blockUnion[BlockRV, str]

The block in the query

Returns#

producersList[BlockRV]

A list of producers of the given block

参数:

block (BlockRV | str)

返回类型:

List[BlockRV]

get_sref(rand_var_or_stmt)[源代码]#

Returns the corresponding sref to the given 1) LoopRV 2) BlockRV 3) Block 4) For

Parameters#

rand_var_or_stmtUnion[BlockRV, LoopRV, Block, For]

The random variable / sref to be evaluated

Returns#

resultOptional[StmtSRef]

The corresponding result
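
Examples#

A short sketch (not part of the upstream docstring), reusing the before_fuse PrimFunc from the fuse example above; the returned sref points at the same Block node that the BlockRV evaluates to:

from tvm import tir

sch = tir.Schedule(before_fuse)
block_rv = sch.get_block("B")
sref = sch.get_sref(block_rv)
assert sref.stmt.same_as(sch.get(block_rv))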

参数:

rand_var_or_stmt (BlockRV | LoopRV | Block | For)

返回类型:

StmtSRef | None

loop_partition(loop, factors, preserve_unit_iters=True)[源代码]#

Partition a loop into a list of consecutive loops. It requires: 1) The loop can't have annotations or thread bindings. Predicates may be added to ensure the total number of iterations stays unchanged. In factors, at most one of the factors can be None, and it will be automatically inferred.

Parameters#

loopLoopRV

The loop to be partitioned

factors: List[Union[int, ExprRV, None]]

The partitioning factors. Potential inputs are: None, an ExprRV, or positive constant integers.

preserve_unit_itersbool

Whether or not to preserve unit iterators in block bindings

Returns#

partition_loopsList[LoopRV]

The new loops after partition

Examples#

Before partition, in TensorIR, the IR is:

@T.prim_func
def before_partition(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and do partition:

sch = tir.Schedule(before_partition)
i, j = sch.get_loops(sch.get_block("B"))
sch.loop_partition(i, factors=[2, 64])
print(sch.mod["main"].script())

After applying partition, the IR becomes:

@T.prim_func
def after_partition(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    # the original loop is partitioned into 3 loops
    with T.block("root"):
        T.reads()
        T.writes()
        with T.block("B_i_common"):
            T.reads()
            T.writes()
            with T.block("B_i0_partition"):
                T.reads()
                T.writes()
                for i0, j in T.grid(2, 128):
                    with T.block("B_i0"):
                        vi, vj = T.axis.remap("SS", [i0, j])
                        T.reads(A[0:2, 0:128])
                        T.writes(B[0:2, 0:128])
                        B[vi, vj] = A[vi, vj] * T.float32(2)
            with T.block("B_i1_partition"):
                T.reads()
                T.writes()
                for i1 in range(2, 66):
                    for j in range(128):
                        with T.block("B_i1"):
                            vi, vj = T.axis.remap("SS", [i1, j])
                            T.reads(A[2:66, 0:128])
                            T.writes(B[2:66, 0:128])
                            B[vi, vj] = A[vi, vj] * T.float32(2)
            with T.block("B_partition_2"):
                T.reads()
                T.writes()
                for i2 in range(66, 128):
                    for j in range(128):
                        with T.block("B_i2"):
                            vi, vj = T.axis.remap("SS", [i2, j])
                            T.reads(A[66:128, 0:128])
                            T.writes(B[66:128, 0:128])
                            B[vi, vj] = A[vi, vj] * T.float32(2)
参数:
返回类型:

List[LoopRV]

merge(*loops)[源代码]#

Merge a list of loops into one. The loops and their LCA are required to satisfy: 1) the loops are under the same scope; 2) they can't have annotations or thread bindings; 3) they start with 0, and have the same extent and the same nesting depth; 4) from each target loop up to their LCA, the inner loop must be the only child of the outer loop.

Parameters#

*loopsList[LoopRV]

The loops to be merged

Returns#

fused_loopLoopRV

The new loop after merge

Examples#

Before applying merge, in TensorIR, the IR is:

@T.prim_func
def before_merge(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    C = T.match_buffer(c, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = A[vi, vj] * 2.0

Create the schedule and do merge:

sch = tir.Schedule(before_merge)
i1, _ = sch.get_loops(sch.get_block("B"))
i2, _ = sch.get_loops(sch.get_block("C"))
sch.merge(i1, i2)
print(sch.mod["main"].script())

After applying merge, the IR becomes:

@T.prim_func
def after_merge(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    C = T.match_buffer(c, (128, 128))
    # the 2 loops are merged into 1
    for i_m in range(128):
        for j in range(128):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i_m, j])
                T.reads(A[vi, vj])
                T.writes(B[vi, vj])
                B[vi, vj] = A[vi, vj] * T.float32(2)
        for j in range(128):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i_m, j])
                T.reads(A[vi, vj])
                T.writes(C[vi, vj])
                C[vi, vj] = A[vi, vj] * T.float32(2)
参数:

loops (List[LoopRV])

返回类型:

LoopRV

pad_einsum(block, padding)[源代码]#

Pad the computation of Einsum.

On a block with trivial binding, this primitive pads the iteration domain of the block by the given padding factors, for example, 127 -> 128, 132 -> 144 when padding factor is 16. Extra producer and consumer padding blocks will be generated to avoid out-of-bound buffer access.

Einsum pattern means all the indices on the buffer access are either by constants (e.g. B[0]) or by variables (e.g. B[i]), but not by composite expressions (e.g. B[i + 1]).

Parameters#

blockUnion[BlockRV, str]

The block that matches the Einsum pattern.

paddingList[int]

The padding for each block iter.

Examples#

Before applying pad-einsum, in TensorIR, the IR is:

@T.prim_func
def before_pad_einsum(
    A: T.Buffer((127, 127), "float32"),
    B: T.Buffer((127, 127), "float32"),
    C: T.Buffer((127, 127), "float32"),
) -> None:
    for i0, i1, i2 in T.grid(127, 127, 127):
        with T.block("C_shared"):
            i, j, k = T.axis.remap("SSR", [i0, i1, i2])
            with T.init():
                C[i, j] = T.float32(0)
            C[i, j] = C[i, j] + A[i, k] * B[k, j]

Create the schedule and do pad-einsum with specified block:

sch = tir.Schedule(before_pad_einsum, debug_mask="all")
block = sch.get_block("C_shared")
sch.pad_einsum(block, [32, 32, 32])
print(sch.mod["main"].script())

After applying pad-einsum, the IR becomes:

@T.prim_func
def main(
    A: T.Buffer((127, 127), "float32"),
    B: T.Buffer((127, 127), "float32"),
    C: T.Buffer((127, 127), "float32"),
):
    # with T.block("root"):
    A_pad = T.alloc_buffer((128, 128))
    B_pad = T.alloc_buffer((128, 128))
    C_pad = T.alloc_buffer((128, 128))
    for i0, i1 in T.grid(128, 128):
        with T.block("A_pad"):
            v0, v1 = T.axis.remap("SS", [i0, i1])
            A_pad[v0, v1] = T.if_then_else(
                v0 < 127 and v1 < 127,
                A[v0, v1],
                T.float32(0),
            )
    for i0, i1 in T.grid(128, 128):
        with T.block("B_pad"):
            v0, v1 = T.axis.remap("SS", [i0, i1])
            B_pad[v0, v1] = T.if_then_else(
                v0 < 127 and v1 < 127,
                B[v0, v1],
                T.float32(0),
            )
    for i0, i1, i2 in T.grid(128, 128, 128):
        with T.block("C_shared"):
            i, j, k = T.axis.remap("SSR", [i0, i1, i2])
            with T.init():
                C_pad[i, j] = T.float32(0)
            C_pad[i, j] = C_pad[i, j] + A_pad[i, k] * B_pad[k, j]
    for i0, i1 in T.grid(127, 127):
        with T.block("C_pad"):
            v0, v1 = T.axis.remap("SS", [i0, i1])
            C[v0, v1] = C_pad[v0, v1]
参数:
返回类型:

None

parallel(loop)[源代码]#

Parallelize the input loop. It requires: 1) The scope block that the loop is in should have stage-pipeline property 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine bindings 3) For each block under the loop, the loop can only be contained in data-parallel block iters’ bindings

Parameters#

loopLoopRV

The loop to be parallelized

Examples#

Before parallel, in TensorIR, the IR is:

@T.prim_func
def before_parallel(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and do parallel:

sch = tir.Schedule(before_parallel)
i, j = sch.get_loops(sch.get_block("B"))
sch.parallel(i)

After applying parallel, the IR becomes:

@T.prim_func
def after_parallel(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i in T.parallel(0, 128):
        for j in T.serial(0, 128):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] * 2.0
参数:

loop (LoopRV)

返回类型:

None

reindex(block, buffer)[源代码]#

Create a block that reads/writes a buffer region into a read/write cache with reindexing. The layout of the cache will be determined by the iterators of the block that reads/writes the buffer. It requires: 1) There is only one block that reads/writes the target buffer. 2) There is only one buffer load/store of this buffer in the block.

Parameters#

block : Union[BlockRV, str]

The block that accesses the target buffer. If a string, this must uniquely identify a block.

buffer: Union[Tuple[str,int], Buffer, str]

The buffer to be transformed, or a specification of how to identify the buffer to be transformed.

If buffer is a tuple of (str, int), the first item should be either "read" or "write", and the second item is an index into the block's read or write regions.

If buffer is a string, it is the name of the buffer, which must exist within the reads/writes of the block. In addition, the reads/writes of the block may not contain more than one buffer with this name.

If buffer is a Buffer object, it must exist within the reads/writes of the block.

Returns#

reindex_blockBlockRV

The block of the reindex stage

Examples#

Before reindex, in TensorIR, the IR is:

@T.prim_func
def before_reindex(
    A: T.Buffer((128, 128), "float32"),
    B: T.Buffer((128, 128), "float32")
) -> None:
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vj, vi] * 2.0

Create the schedule and do reindex:

sch = tir.Schedule(before_reindex)
block = sch.get_block("B")
sch.reindex(block, ("read", 0))

After applying reindex, the IR becomes:

@T.prim_func
def after_reindex(
    A: T.Buffer((128, 128), "float32"),
    B: T.Buffer((128, 128), "float32")
) -> None:
    A_reindex = T.alloc_buffer((128, 128), "float32")
    for i, j in T.grid(128, 128):
        with T.block("A_reindex"):
            vi, vj = T.axis.remap("SS", [i, j])
            A_reindex[vi, vj] = A[vj, vi]
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A_reindex[vi, vj] * 2.0
参数:
返回类型:

BlockRV

reindex_cache_read(block, read_buffer_index, storage_scope, index_map)[源代码]#

Create a block that reads a buffer region into a read cache using customized indices specified by index map. The read region of the buffer must be a single point.

The cache stage block follows the original order of loops and block itervars in the block. If a block itervar does not appear in the buffer access region, it and its corresponding loop variables will be omitted. The user can then use the transform_block_layout primitive to reorder the block itervars and surrounding loops of the cache read/write block.

Unlike cache_read, reindex_cache_read only supports a single consumer; please use cache_read when there are multiple consumers.

Parameters#

blockBlockRV

The consumer block of the target buffer.

read_buffer_index: int

The index of the buffer in block’s read region.

storage_scope: str

The target storage scope.

index_map: Union[IndexMap, Callable]

User defined indices to access allocated cache buffer, maps from block iter vars.

Returns#

cached_blockBlockRV

The block of the cache stage

Examples#

Before reindex_cache_read, in TensorIR, the IR is:

@T.prim_func
def before_reindex_cache_read(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and reindex_cache_read:

sch = tir.Schedule(before_reindex_cache_read)
block_b = sch.get_block("B")
sch.reindex_cache_read(block_b, 0, "local", lambda vi, vj: (vj, vi))
print(sch.mod["main"].script())

After applying reindex_cache_read, the IR becomes:

@T.prim_func
def after_reindex_cache_read(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    A_local = T.alloc_buffer((128, 128), scope="local")
    for i, j in T.grid(128, 128):
        with T.block("A_local"):
            vi, vj = T.axis.remap("SS", [i, j])
            A_local[vj, vi] = A[vi, vj]
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A_local[vj, vi] * 2.0

See Also#

reindex_cache_write transform_block_layout transform_layout cache_read reindex

参数:
返回类型:

BlockRV

reindex_cache_write(block, write_buffer_index, storage_scope, index_map)[源代码]#

Create a block that reads a buffer region into a write cache using customized indices specified by index map. The write region of the buffer must be a single point.

The cache stage block follows the original order of loops and block itervars in the block. If a block itervar does not appear in the buffer access region, it and its corresponding loop variables will be omitted. The user can then use the transform_block_layout primitive to reorder the block itervars and surrounding loops of the cache read/write block.

Unlike cache_write, reindex_cache_write only supports a single consumer; please use cache_write when there are multiple consumers.

Parameters#

blockUnion[BlockRV, str]

The consumer block of the target buffer.

write_buffer_index: int

The index of the buffer in block’s write region.

storage_scope: str

The target storage scope.

index_map: Union[Callable, IndexMap]

User defined indices to access allocated cache buffer, maps from block iter vars.

consumer_blocks: Optional[List[Union[BlockRV, str]]]

An optional list of consumers that should read directly from the cache. If not specified, all consumers will read from the original buffer.

Returns#

cached_blockBlockRV

The block of the cache stage

Examples#

Before reindex_cache_write, in TensorIR, the IR is:

@T.prim_func
def before_reindex_cache_write(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and reindex_cache_write:

sch = tir.Schedule(before_reindex_cache_write)
block_b = sch.get_block("B")
sch.reindex_cache_write(block_b, 0, "local", lambda vi, vj: (vi // 2, vi % 2, vj))
print(sch.mod["main"].script())

After applying reindex_cache_write, the IR becomes:

@T.prim_func
def after_reindex_cache_write(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    B_local = T.alloc_buffer((64, 2, 128), scope="local")
    for i, j in T.grid(128, 128):
        with T.block("B_local"):
            vi, vj = T.axis.remap("SS", [i, j])
            B_local[vi // 2, vi % 2, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = B_local[vi // 2, vi % 2, vj]

See Also#

reindex_cache_read transform_block_layout transform_layout cache_write reindex

参数:
返回类型:

BlockRV

remove_rv(rand_var)[源代码]#

Remove a random variable from the symbol table

Parameters#

rand_varUnion[BlockRV, LoopRV, ExprRV]

The random variable to be removed
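
Examples#

A short sketch (not part of the upstream docstring), reusing the before_fuse PrimFunc from the fuse example above; once removed from the symbol table, the random variable can no longer be passed to schedule primitives or to get:

from tvm import tir

sch = tir.Schedule(before_fuse)
block_rv = sch.get_block("B")
sch.remove_rv(block_rv)
# using block_rv afterwards (e.g. sch.get(block_rv)) raises an error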

参数:

rand_var (PrimExpr | BlockRV | LoopRV)

返回类型:

None

reorder(*ordered_loops)[源代码]#

Reorder a list of loops. It doesn’t require the loops to be consecutive. It requires: 1) The loops are in the same chain. That means: the loops can be ordered to [l_1, l_2, … , l_n] where l_i is an ancestor of l_{i+1} and there are only single-branch loops between l_1 and l_n (which also indicates they are under the same scope). 2) After reordering, the domain of an outer loop cannot depend on any of the inner loops. 3) For every block under the loop nests, its block binding must be affine, and the block variables must be either data parallel or reduction. 4) No duplicated loops are allowed in the arguments.

Parameters#

*ordered_loopsList[LoopRV]

The loops in the new order

Examples#

Before reorder, in TensorIR, the IR is:

@T.prim_func
def before_reorder(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and do reorder:

sch = tir.Schedule(before_reorder)
i, j = sch.get_loops(sch.get_block("B"))
sch.reorder(j, i)
print(sch.mod["main"].script())

After applying reorder, the IR becomes:

@T.prim_func
def after_reorder(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    # Here j and i are reordered
    for j, i in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
参数:

ordered_loops (List[LoopRV])

返回类型:

None

reorder_block_iter_var(block, new_order)[源代码]#

Reorder the itervars inside a given block.

Parameters#

blockBlockRV

The block to be transformed.

new_orderList[int]

The new block itervar order.

Examples#

Before reorder_block_iter_var, in TensorIR, the IR is:

@T.prim_func
def matmul(
    A: T.Buffer((128, 128), "float32"),
    B: T.Buffer((128, 128), "float32"),
    C: T.Buffer((128, 128), "float32"),
) -> None:
    for i, j, k in T.grid(128, 128, 128):
        with T.block("C"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.0
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]

Create the schedule and do reorder_block_iter_var:

sch = tir.Schedule(matmul)
C = sch.get_block("C")
sch.reorder_block_iter_var(C, [2, 1, 0])

After applying reorder_block_iter_var, the IR becomes:

@T.prim_func
def matmul_after_reorder_block_iter_var(
    A: T.Buffer((128, 128), "float32"),
    B: T.Buffer((128, 128), "float32"),
    C: T.Buffer((128, 128), "float32"),
):
    for i, j, k in T.grid(128, 128, 128):
        with T.block("C"):
            vk, vj, vi = T.axis.remap("RSS", [k, j, i])
            T.reads(A[vi, vk], B[vj, vk])
            T.writes(C[vi, vj])
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]

See Also#

reorder

参数:
  • block (BlockRV)

  • new_order (List[int])

返回类型:

None

reverse_compute_at(block, loop, preserve_unit_loops=False, index=-1)[源代码]#

Reverse-Compute-At. Move a consumer block under the specific loop, and regenerate the loops induced by the block so that the buffer region consumed by the consumer block could cover those regions produced by its producer blocks under the given loop. It requires:

  1. block and loop are under the same scope, loop is not the ancestor of block

  2. The scope block has stage-pipeline property

  3. The subtree of the scope block, where the given block is in, satisfies the compact dataflow condition, i.e. all the blocks in the scope block's subtree must be either complete blocks or reduction blocks

  4. All the producers of the block are under the given loop

Parameters#

blockUnion[BlockRV, str]

The block to be moved

loop: LoopRV

The loop where the block to be moved under

preserve_unit_loops: bool

Whether to keep the trivial loops whose extents are 1

index: int

The block index of the loop body subtree blocks: - index = -1 means inserted into the last possible insertion point; - index = -2 means inserted into the first possible insertion point; - Otherwise, index is a nonnegative number that indicates the insertion point

Examples#

Before reverse-compute-at, in TensorIR, the IR is:

@T.prim_func
def before_reverse_compute_at(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128), "float32")
    B = T.alloc_buffer((128, 128), "float32")
    C = T.match_buffer(c, (128, 128), "float32")
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0

Create the schedule and do reverse-compute-at:

sch = tir.Schedule(before_reverse_compute_at)
block = sch.get_block("C")
loop, _ = sch.get_loops(sch.get_block("B"))
sch.reverse_compute_at(block, loop, preserve_unit_loops=False)
print(sch.mod["main"].script())

After applying reverse-compute-at, the IR becomes:

@T.prim_func
def after_reverse_compute_at(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128), "float32")
    B = T.alloc_buffer((128, 128), "float32")
    C = T.match_buffer(c, (128, 128), "float32")
    for i in T.serial(0, 128):
        for j in T.serial(0, 128):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] * 2.0
        for j in T.serial(0, 128):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = B[vi, vj] + 1.0
参数:
  • block (BlockRV | str)

  • loop (LoopRV)

  • preserve_unit_loops (bool)

  • index (int)

返回类型:

None

reverse_compute_inline(block)[源代码]#

Inline a block into its only producer. It requires:

  1. The block is a complete non-root block, which only produces and consumes one buffer

  2. The block must not be the only leaf in the scope.

  3. The only producer of the block is a read-after-write producer and a complete non-root block

  4. The body of the block must be a BufferStore statement in the form of, B[f(i, j, k, ...)] = g(i, j, k, A[i, j, k, ...] ...) where the indices of each BufferLoad on the RHS are all distinct atomic variables, and no variables other than those indexing variables are allowed in the statement.

Parameters#

blockUnion[BlockRV, str]

The block to be inlined to its producer

Examples#

Before reverse-compute-inline, in TensorIR, the IR is:

@T.prim_func
def before_inline(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.alloc_buffer((128, 128))
    C = T.match_buffer(c, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0

Create the schedule and do reverse-compute-inline:

sch = tir.Schedule(before_inline)
sch.reverse_compute_inline(sch.get_block("C"))
print(sch.mod["main"].script())

After applying reverse-compute-inline, the IR becomes:

@T.prim_func
def after_inline(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    C = T.match_buffer(c, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = A[vi, vj] * 2.0 + 1.0
参数:

block (BlockRV | str)

返回类型:

None

rfactor(loop, factor_axis)[源代码]#

Factorize an associative reduction block by the specified loop.

An associative reduction cannot be parallelized directly, because it leads to potential race conditions during accumulation. Alternatively, the reduction could be factorized on a loop with the following steps: Step 1: evenly slice the reduction into n separate chunks, where n is the loop extent; Step 2: compute the chunks separately and write the results into n intermediate buffers; Step 3: accumulate the n separate buffers into the result buffer. Note that Step 2 above introduces opportunities for parallelization.

RFactor is a schedule primitive that implements the transformation described above: Given a block that writes to buffer B, it factorizes a loop of extent n.

For example, the pseudocode below accumulates B[i] = sum(A[i, : , : ]):

for i in range(128):                    # loop i is a data parallel loop
    for j in range(128):                # loop j is a reduction loop
        for k in range(128):            # loop k is a reduction loop
            B[i] = B[i] + A[i, j, k]

Suppose RFactor is applied on the innermost loop k and factor_axis = 1. RFactor then creates an intermediate buffer and two blocks.

1. The intermediate buffer, or “rf-buffer” is a buffer of rank ndim(B) + 1 and size size(B) * n, whose shape expands from shape(B) by adding an axis of n at the position specified by factor_axis. For example,

  • shape(B) = [1, 2, 3], factor_axis = 0 => shape(B_rf) = [n, 1, 2, 3]

  • shape(B) = [1, 2, 3], factor_axis = 1 => shape(B_rf) = [1, n, 2, 3]

  • shape(B) = [1, 2, 3], factor_axis = 2 => shape(B_rf) = [1, 2, n, 3]

  • shape(B) = [1, 2, 3], factor_axis = 3 => shape(B_rf) = [1, 2, 3, n]

2. The rfactor block, or “rf-block”, is a block that writes to the rf-buffer without accumulating over the loop k, i.e. the loop k is converted from a reduction loop to a data parallel loop. In our example, the rf-block is:

B_rf = np.zeros((128, 128))     # the rf-buffer
for k in range(128):            # loop k is converted to a data parallel loop
    for i in range(128):        # loop i is a data parallel loop (unchanged)
        for j in range(128):    # loop j is a reduction loop (unchanged)
            B_rf[i, k] = B_rf[i, k] + A[i, j, k]

3. The write-back block, or wb-block, is a block that accumulates the rf-buffer into the result buffer. All the reduction loops are removed except the loop k for accumulation. In our example, the wb-block is:

for i in range(128):            # loop i is a data parallel loop (unchanged)
                                # loop j is removed because it is a reduction loop
    for k in range(128):        # loop k is a reduction loop (unchanged)
        B[i] = B[i] + B_rf[i, k]

Parameters#

loopLoopRV

The loop outside block for which we want to do rfactor

factor_axisint

The position where the new dimension is placed in the new introduced rfactor buffer

Returns#

rf_blockBlockRV

The block which computes partial results over each slice (i.e., the first block as described in the above illustration)

Examples#

Before rfactor, in TensorIR, the IR is:

@T.prim_func
def before_rfactor(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128, 128))
    B = T.match_buffer(b, (128,))
    for ii, i, j in T.grid(128, 128, 128):
    with T.block("B"):
        vii, vi, vj = T.axis.remap("SRR", [ii, i, j])
        with T.init():
            B[vii] = 0.0
        B[vii] = B[vii] + A[vii, vi, vj]

Create the schedule and do rfactor:

sch = tir.Schedule(before_rfactor)
_, _, k = sch.get_loops(sch.get_block("B"))
sch.rfactor(k, 0)
print(sch.mod["main"].script())

After applying rfactor, the IR becomes:

@T.prim_func
def after_rfactor(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [128, 128, 128])
    B = T.match_buffer(b, [128])
    B_rf = T.alloc_buffer([128, 128])
    for i2, ii, i in T.grid(128, 128, 128):
        with T.block("B_rf"):
            vi2, vii, vi = T.axis.remap("SSR", [i2, ii, i])
            with T.init():
                B_rf[vi2, vii] = 0.0
            B_rf[vi2, vii] = (B_rf[vi2, vii] + A[vii, vi, vi2])
    for ii, i2 in T.grid(128, 128):
        with T.block("B"):
            vii, vi2 = T.axis.remap("SR", [ii, i2])
            with T.init():
                B[vii] = 0.0
            B[vii] = B[vii] + B_rf[vi2, vii]

Note#

Rfactor requires: 1) loop has only one child block, and it is a reduction block; 2) loop is a reduction loop, i.e. the loop variable is bound to only reduction variables in the block binding; 3) loop is not parallelized, vectorized, unrolled or bound to any thread axis; 4) the block scope that loop is in is a stage pipeline; 5) the outermost loop outside the reduction block should have the reduction block as its first child block; 6) the outermost reduction loop should have only one child block; 7) a unary-extent loop that is not bound to any reduction or data parallel variables in the block binding should not appear under some reduction loop; 8) the reduction block should write to only one buffer, its init and body are both simple BufferStores, and the pattern is registered as an associative reducer (the pre-defined patterns include plus, multiplication, min and max); 9) each of the loops on top of the block cannot be bound to a data parallel and a reduction block binding at the same time; 10) factor_axis should be in range [-ndim(B) - 1, ndim(B)], where B is the buffer that the reduction block writes to; negative indexing is normalized according to numpy convention.

参数:
  • loop (LoopRV)

  • factor_axis (int)

返回类型:

BlockRV

rolling_buffer(block, write_buffer_index)[源代码]#

Compute the target buffer via rolling buffering: select the outermost rollable axis with a positive bound overlap that appears in the block's ancestor loops as the rolling axis, fold and circularize the buffer along the rolling dimension, and append a block predicate to avoid recomputing overlapping elements. It requires:

  1. The block is not an output block and has only RAW dependencies.

  2. The buffer is an intermediate buffer defined via alloc_buffer.

  3. The LCA of the producer and consumer of the buffer is a for loop; typically, the producer and consumer of the buffer are cascaded through compute_at.

  4. The access region of the buffer has at least one dimension that contains a positive bound overlap.

Parameters#

blockUnion[BlockRV, str]

The producer block of the buffer.

write_buffer_indexint

The index of the buffer in block’s write region.

Examples#

Before rolling_buffer, in TensorIR, the IR is:

@T.prim_func
def before_rolling_buffer(
    A: T.Buffer((12, 12), "int8"), C: T.Buffer((8, 8), "int8")
) -> None:
    # body
    # with T.block("root")
    B = T.alloc_buffer([10, 10], dtype="int8")
    for i0, i1 in T.grid(2, 2):
        for ax0, ax1, ax2, ax3 in T.grid(6, 6, 3, 3):
            with T.block("B"):
                ax0_1 = T.axis.spatial(10, i0 * 4 + ax0)
                ax1_1 = T.axis.spatial(10, i1 * 4 + ax1)
                rv0, rv1 = T.axis.remap("RR", [ax2, ax3])
                B[ax0_1, ax1_1] = T.max(
                    B[ax0_1, ax1_1], A[ax0_1 + rv0, ax1_1 + rv1]
                )
        for ax0, ax1, ax2, ax3 in T.grid(4, 4, 3, 3):
            with T.block("C"):
                ax0_1 = T.axis.spatial(8, i0 * 4 + ax0)
                ax1_1 = T.axis.spatial(8, i1 * 4 + ax1)
                rv0, rv1 = T.axis.remap("RR", [ax2, ax3])
                C[ax0_1, ax1_1] = T.max(
                    C[ax0_1, ax1_1], B[ax0_1 + rv0, ax1_1 + rv1]
                )

Create the schedule and do rolling_buffer:

sch = tir.Schedule(before_rolling_buffer)
sch.rolling_buffer(sch.get_block("B"), write_buffer_index=0)
print(sch.mod["main"].script())

After applying rolling_buffer, the IR becomes:

@T.prim_func
def after_rolling_buffer(
    A: T.Buffer((12, 12), "int8"),
    C: T.Buffer((8, 8), "int8")
) -> None:
    # body
    # with T.block("root")
    B = T.alloc_buffer([6, 10], dtype="int8")
    for i0, i1 in T.grid(2, 2):
        for ax0, ax1, ax2, ax3 in T.grid(6, 6, 3, 3):
            with T.block("B"):
                T.where((i0 < 1 or 2 <= ax0) and (i1 < 1 or 2 <= ax1))
                ax0_1 = T.axis.spatial(10, i0 * 4 + ax0)
                ax1_1 = T.axis.spatial(10, i1 * 4 + ax1)
                rv0, rv1 = T.axis.remap("RR", [ax2, ax3])
                B[ax0_1 % 6, ax1_1] = T.max(
                    B[ax0_1 % 6, ax1_1], A[ax0_1 + rv0, ax1_1 + rv1]
                )
        for ax0, ax1, ax2, ax3 in T.grid(4, 4, 3, 3):
            with T.block("C"):
                ax0_1 = T.axis.spatial(8, i0 * 4 + ax0)
                ax1_1 = T.axis.spatial(8, i1 * 4 + ax1)
                rv0, rv1 = T.axis.remap("RR", [ax2, ax3])
                C[ax0_1, ax1_1] = T.max(
                    C[ax0_1, ax1_1], B[ax0_1 % 6 + rv0, ax1_1 + rv1]
                )

Note#

The region_cover property of the consumer block of the target buffer will become false.

参数:
  • block (BlockRV | str)

  • write_buffer_index (int)

返回类型:

None

sample_categorical(candidates, probs, decision=None)[源代码]#

Sample an integer given the probability distribution

Parameters#

candidatesList[int]

The candidates to be sampled from

probsList[float]

The probability of each candidate

decisionOptional[int]

The sampling decision, if any

Returns#

resultExprRV

The random variable sampled from candidates
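
Examples#

A minimal usage sketch (it assumes that tir has been imported from tvm and that mod is an IRModule or PrimFunc to be scheduled; the candidate values are purely illustrative):

sch = tir.Schedule(mod)
# Sample one of the candidates with the given probabilities (which sum to 1).
tile = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.1, 0.2, 0.3, 0.4])
# The returned ExprRV can be fed into other primitives,
# e.g. sch.split(loop, factors=[None, tile]).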

参数:
返回类型:

PrimExpr

sample_compute_location(block, decision=None)[源代码]#

Sample a compute-at location of the given block

Parameters#

blockUnion[BlockRV, str]

The block whose compute-at location is to be sampled

decisionOptional[int]

The sampling decision

Returns#

resultLoopRV

The sampled loop where the input block is to be computed at
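
Examples#

An illustrative sketch (it assumes mod defines a block "B" that has a consumer; the sampled location is typically consumed by compute_at):

sch = tir.Schedule(mod)
block = sch.get_block("B")
loc = sch.sample_compute_location(block)
# Move the block to the sampled location; the special root/inline marks may also be sampled.
sch.compute_at(block, loc)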

参数:
  • block (BlockRV | str)

  • decision (int | None)

返回类型:

LoopRV

sample_partitioned_tile(loop, n, partition_pos=0, innerpart_factor=1, decision=None)[源代码]#

Sample the factors to a partitioned tile for a specific loop

Parameters#

loopLoopRV

The loop to be tiled

nint

The number of tiles to be sampled

partition_posint

The position at which the tiles are partitioned into two parts

innerpart_factorint

The factor of the second part

decision: Optional[List[int]]

The sampling decision, if any

Returns#

resultList[ExprRV]

A list of length n, the random partitioned tile sizes sampled
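
Examples#

A hedged sketch (it assumes mod defines a block "B" nested under a single loop whose extent admits the requested factorization; the argument values are illustrative):

sch = tir.Schedule(mod)
(i,) = sch.get_loops(sch.get_block("B"))
# Sample 4 tile factors for the loop, partitioned into two parts at position 2,
# with the inner part constrained by innerpart_factor.
factors = sch.sample_partitioned_tile(i, n=4, partition_pos=2, innerpart_factor=4)
sch.split(i, factors=factors)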

参数:
  • loop (LoopRV)

  • n (int)

  • partition_pos (int)

  • innerpart_factor (int)

  • decision (List[int] | None)

返回类型:

List[PrimExpr]

sample_perfect_tile(loop, n, max_innermost_factor=16, decision=None)[源代码]#

Sample the factors to perfect tile a specific loop

Parameters#

loopLoopRV

The loop to be tiled

nint

The number of tiles to be sampled

max_innermost_factorint

The maximum tile size allowed to be sampled in the innermost loop

decision: Optional[List[int]]

The sampling decision, if any

Returns#

resultList[ExprRV]

A list of length n, the random perfect tile sizes sampled
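
Examples#

A minimal sketch of the common sample-then-split pattern (it assumes mod defines a block "B" with two loops):

sch = tir.Schedule(mod)
i, j = sch.get_loops(sch.get_block("B"))
# The product of the sampled factors equals the extent of loop i ("perfect" tiling).
factors = sch.sample_perfect_tile(i, n=2, max_innermost_factor=16)
io, ii = sch.split(i, factors=factors)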

参数:
  • loop (LoopRV)

  • n (int)

  • max_innermost_factor (int)

  • decision (List[int] | None)

返回类型:

List[PrimExpr]

seed(seed)[源代码]#

Seed the randomness

Parameters#

seedint

The new random seed; -1 to use the device random source, otherwise a non-negative integer

参数:

seed (int)

返回类型:

None

set_axis_separator(block, buffer, axis_separators)[源代码]#

Set the axis separator of a buffer, where the buffer is specified by a block and a read or write index.

Parameters#

block : Union[BlockRV, str]

The block that accesses the target buffer. If a string, this must uniquely identify a block.

buffer: Union[Tuple[str,int], Buffer, str]

The buffer to be transformed, or a specification of how to identify the buffer to be transformed.

If buffer is a tuple of (str, int), the first item should be either “read” or “write”, and the second item is an index into the block’s read or write regions.

If buffer is a string, it is the name of the buffer, which must exist within the reads/writes of the block. In addition, the reads/writes of the block may not contain more than one buffer with this name.

If buffer is a Buffer object, it must exist within the reads/writes of the block.

axis_separators : Optional[List[int]]

The axis separators.

Examples#

Before set_axis_separator, in TensorIR, the IR is:

@T.prim_func
def before_set_axis_separator(
    A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")
) -> None:
    B = T.alloc_buffer((128, 128), dtype="float32")

    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0

Create the schedule and do set_axis_separator:

sch = tir.Schedule(before_set_axis_separator)
sch.set_axis_separator(sch.get_block("B"), buffer=("write", 0),
                       axis_separators=[1])
print(sch.mod["main"].script())

After applying set_axis_separator, the IR becomes:

@T.prim_func
def after_set_axis_separators(
    A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")
) -> None:
    B = T.alloc_buffer([128, 128], dtype="float32", axis_separators=[1])

    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * T.float32(2)
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + T.float32(1)
参数:
返回类型:

None

set_scope(block, buffer_index, storage_scope)[源代码]#

Set the storage scope of a buffer, where the buffer is specified by a block and a write index.

Parameters#

blockUnion[BlockRV, str]

The producer block of the buffer

buffer_indexint

The index of the buffer in block’s write region

storage_scopestr

The storage scope to be set

Examples#

Before set_scope, in TensorIR, the IR is:

@T.prim_func
def before_set_scope(
    A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")
) -> None:
    B = T.alloc_buffer((128, 128), dtype="float32")

    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0

Create the schedule and do set_scope:

sch = tir.Schedule(before_set_scope)
sch.set_scope(sch.get_block("B"), buffer_index=0, storage_scope="shared")
print(sch.mod["main"].script())

After applying set_scope, the IR becomes:

@T.prim_func
def after_set_scope(
    A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")
) -> None:
    B_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared")

    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B_shared[vi, vj] = A[vi, vj] * T.float32(2)
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B_shared[vi, vj] + T.float32(1)

Note#

set_scope requires the buffer to be an intermediate buffer defined via alloc_buffer.

参数:
返回类型:

None

show(*args, **kwargs)[源代码]#

A sugar for printing highlighted TVM script.

All parameters are forwarded to the underlying Module.show and Trace.show methods.

返回类型:

None

split(loop, factors, preserve_unit_iters=True, disable_predication=False)[源代码]#

Split a loop into a list of consecutive loops. It requires: 1) The loop can’t have annotation or thread binding. 2) The loop must start with 0. Predicates may be added to ensure the total number of iterations remains unchanged. In factors, at most one of the factors can be None, which will be automatically inferred.

Parameters#

loopLoopRV

The loop to be split

factors: List[Union[int, ExprRV, None]]

The splitting factors. Potential inputs are: None, an ExprRV, or positive constant integers.

preserve_unit_itersbool

Whether or not to preserve unit iterators in block bindings

disable_predicationbool

If enabled, don’t create a predicate for guarding the loop. This can be useful when splitting with scalable factors that the schedule writer knows are divisible by the loop bound.

Warning: enabling this feature may result in incorrect code generation if not used carefully.

Returns#

split_loopsList[LoopRV]

The new loops after split

Examples#

Before split, in TensorIR, the IR is:

@T.prim_func
def before_split(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and do split:

sch = tir.Schedule(before_split)
i, j = sch.get_loops(sch.get_block("B"))
sch.split(i, factors=[2, 64])
print(sch.mod["main"].script())

After applying split, the IR becomes:

@T.prim_func
def after_split(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    # the original loop is split into 2 loops
    for i0, i1, j in T.grid(2, 64, 128):
        with T.block("B"):
            vi = T.axis.S(128, i0 * 64 + i1)
            vj = T.axis.S(128, j)
            B[vi, vj] = A[vi, vj] * 2.0
参数:
返回类型:

List[LoopRV]

storage_align(block, buffer_index, axis, factor, offset)[源代码]#

Set alignment requirement for specific dimension such that stride[axis] == k * factor + offset for some k. This is useful to set memory layout for more friendly memory access pattern. For example, we can set alignment to be factor=2, offset=1 to avoid bank conflict for thread access on higher dimension in GPU shared memory.

Parameters#

blockUnion[BlockRV, str]

The producer block of the buffer.

buffer_indexint

The index of the buffer in block’s write region.

axisint

The dimension to be specified for alignment.

factorint

The factor multiple of alignment.

offsetint

The required offset factor.

Examples#

Before storage_align, in TensorIR, the IR is:

@T.prim_func
def before_storage_align(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.alloc_buffer((128, 128))
    C = T.match_buffer(c, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0

Create the schedule and do storage_align:

sch = tir.Schedule(before_storage_align)
sch.storage_align(sch.get_block("B"), buffer_index=0, axis=0, factor=128, offset=1)
print(sch.mod["main"].script())

After applying storage_align, the IR becomes:

@T.prim_func
def after_storage_align(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.alloc_buffer((128, 128))
    C = T.match_buffer(c, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            T.block_attr({"buffer_dim_align": [[[0, 128, 1]]]})
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0

After lowering passes, buffer B will have strides as [129, 1].

Note#

Storage_align requires the buffer to be an intermediate buffer defined via alloc_buffer.

参数:
  • block (BlockRV | str)

  • buffer_index (int)

  • axis (int)

  • factor (int)

  • offset (int)

返回类型:

None

tensorize(block_or_loop, tensor_intrin, preserve_unit_iters=True)[源代码]#

Tensorize the computation enclosed by loop with the tensor intrinsic.

Parameters#

block_or_loopUnion[BlockRV, LoopRV]

The loop to be tensorized.

tensor_intrinstr

The tensor intrin or the name of the tensor intrin.

preserve_unit_itersbool

Whether or not to preserve unit iterators in block bindings

Examples#

Before tensorize, in TensorIR, the IR is:

@T.prim_func
def before_tensorize(
    A: T.Buffer((128, 128), "float32"),
    B: T.Buffer((128, 128), "float32"),
    C: T.Buffer((128, 128), "float32"),
) -> None:
    # body
    # with T.block("root")
    for i_0, j_0, k_0, i_1, j_1, k_1 in T.grid(8, 8, 8, 16, 16, 16):
        with T.block("update"):
            vi = T.axis.spatial(128, i_0 * 16 + i_1)
            vj = T.axis.spatial(128, j_0 * 16 + j_1)
            vk = T.axis.reduce(128, k_0 * 16 + k_1)
            T.reads(C[vi, vj], A[vi, vk], B[vj, vk])
            T.writes(C[vi, vj])
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]

Declare and register the tensor intrinsic:

@T.prim_func
def mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), align=128, offset_factor=1)
    B = T.match_buffer(b, (16, 16), align=128, offset_factor=1)
    C = T.match_buffer(c, (16, 16), align=128, offset_factor=1)

    with T.block("root"):
        T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0 : 16])
        T.writes(C[0 : 16, 0 : 16])
        for i, j, k in T.grid(16, 16, 16):
            with T.block("update"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]


@T.prim_func
def mma_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), align=128, offset_factor=1)
    B = T.match_buffer(b, (16, 16), align=128, offset_factor=1)
    C = T.match_buffer(c, (16, 16), align=128, offset_factor=1)

    with T.block("root"):
        T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0 : 16])
        T.writes(C[0 : 16, 0 : 16])
        T.evaluate(
            T.tvm_mma_sync(
                C.data,
                C.elem_offset // 256,
                A.data,
                A.elem_offset // 256,
                B.data,
                B.elem_offset // 256,
                C.data,
                C.elem_offset // 256,
                dtype="handle",
            )
        )

tir.TensorIntrin.register("test_mma_intrin", mma_desc, mma_intrin)

Create the schedule and do tensorize:

sch = tir.Schedule(before_tensorize)
update = sch.get_block("update")
_, _, _, i1, _, _ = sch.get_loops(update)
sch.tensorize(i1, "test_mma_intrin")
print(sch.mod["main"].script())

After applying tensorize, the IR becomes:

@T.prim_func
def after_tensorize(
    A: T.Buffer((128, 128), "float32"),
    B: T.Buffer((128, 128), "float32"),
    C: T.Buffer((128, 128), "float32"),
) -> None:
    # body
    # with T.block("root")
    for i_0, j_0, k_0 in T.grid(8, 8, 8):
        with T.block("update_o"):
            vio, vjo, vko = T.axis.remap("SSR", [i_0, j_0, k_0])
            T.reads(
                C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16],
                A[vio * 16 : vio * 16 + 16, vko * 16 : vko * 16 + 16],
                B[vjo * 16 : vjo * 16 + 16, vko * 16 : vko * 16 + 16],
            )
            T.writes(C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16])
            A_1 = T.match_buffer(
                A[vio * 16 : vio * 16 + 16, vko * 16 : vko * 16 + 16],
                [16, 16],
                dtype="float32",
                offset_factor=1,
            )
            B_1 = T.match_buffer(
                B[vjo * 16 : vjo * 16 + 16, vko * 16 : vko * 16 + 16],
                [16, 16],
                dtype="float32",
                offset_factor=1,
            )
            C_1 = T.match_buffer(
                C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16],
                [16, 16],
                dtype="float32",
                offset_factor=1,
            )
            T.evaluate(
                T.tvm_mma_sync(
                    C_1.data,
                    C_1.elem_offset // 256,
                    A_1.data,
                    A_1.elem_offset // 256,
                    B_1.data,
                    B_1.elem_offset // 256,
                    C_1.data,
                    C_1.elem_offset // 256,
                    dtype="handle",
                )
            )
参数:
  • block_or_loop (BlockRV | LoopRV)

  • tensor_intrin (str)

  • preserve_unit_iters (bool)

返回类型:

None

transform_block_layout(block, index_map)[源代码]#

Apply a transformation represented by IndexMap to block

Parameters#

blockUnion[BlockRV, str]

The block to be transformed

index_mapUnion[IndexMap, Callable]

The transformation to apply.

Examples#

Before transform_block_layout, in TensorIR, the IR is:

@T.prim_func
def before_transform_block_layout(
    A: T.Buffer((16, 16), "float32"),
    B: T.Buffer((16, 16), "float32")
) -> None:
    for i, j in T.grid(16, 16):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and do transform_block_layout:

sch = tir.Schedule(before_transform_block_layout)
sch.transform_block_layout(sch.get_block("B"), lambda i, j: (i * 16 + j,))
print(sch.mod["main"].script())

After applying transform_block_layout, the IR becomes:

@T.prim_func
def after_transform_block_layout(
    A: T.Buffer((16, 16), "float32"),
    B: T.Buffer((16, 16), "float32")
) -> None:
    for i in range(256):
        with T.block("B"):
            vi, = T.axis.remap("S", [i])
            B[vi // 16, vi % 16] = A[vi // 16, vi % 16] * 2.0
参数:
返回类型:

None

transform_layout(block, buffer, index_map, pad_value=None, *, assume_injective_transform=False)[源代码]#

Apply a transformation represented by IndexMap to buffer

Parameters#

block : Union[BlockRV, str]

The block that accesses the target buffer. If a string, this must uniquely identify a block.

buffer: Union[Tuple[str,int], Buffer, str]

The buffer to be transformed, or a specification of how to identify the buffer to be transformed.

If buffer is a tuple of (str, int), the first item should be either “read” or “write”, and the second item is an index into the block’s read or write regions.

If buffer is a string, it is the name of the buffer, which must exist within the reads/writes of the block. In addition, the reads/writes of the block may not contain more than one buffer with this name.

If buffer is a Buffer object, it must exist within the reads/writes of the block.

index_map : Union[IndexMap, Callable]

The transformation to apply.

If index_map is a callable, and the returned list contains IndexMap.AXIS_SEPARATOR, the SetAxisSeparators primitive will be called in addition to the TransformLayout primitive.

pad_value: Optional[Union[int, float, PrimExpr, IndexMap, Callable]]

The value to be used for any padding introduced by the transformation. If the schedule contains a producer block for the specified buffer, the pad value will be written as part of the producer block if possible, or after the producer block otherwise. Otherwise, if the buffer is an input, will insert an annotation block to state that the padding contains the known value.

The pad value may not contain instances of BufferLoad, except where it loads a value from the buffer being transformed (e.g. to create a circular buffer with padding that consists of repeated elements).

Note: If applied to an input buffer, the calling scope is responsible for ensuring that the pad_value is present. Algebraic simplifications, branch elimination, and other optimizations may assume that this precondition is met, and may result in incorrect results being returned.

If None, the transformation may not introduce padding.

If an int, float or PrimExpr, the transformation is the specific value to be present in the padding.

If an IndexMap or Callable, the transformation is the value to be present in the padding in terms of the transformed index.

assume_injective_transform : bool

If set to true, the schedule primitive will assume the index_map is injective and skip checking overlapping of the mapped indices. This can be useful for complicated index_map that the analysis does not cover. It is the callers’ responsibility to ensure the index map is injective, otherwise, the correctness of the schedule is not guaranteed.

Examples#

Before transform_layout, in TensorIR, the IR is:

@T.prim_func
def before_transform_layout(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128), "float32")
    B = T.alloc_buffer((128, 128), "float32")
    C = T.match_buffer(c, (128, 128), "float32")
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0

Create the schedule and do transform_layout:

sch = tir.Schedule(before_transform_layout)
sch.transform_layout(sch.get_block("B"), buffer=("write",0),
                     index_map=lambda m, n: (m // 16, n // 16, m % 16, n % 16))
print(sch.mod["main"].script())

After applying transform_layout, the IR becomes:

@T.prim_func
def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128), "float32")
    B = T.alloc_buffer((8, 8, 16, 16), "float32")
    C = T.match_buffer(c, (128, 128), "float32")
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi // 16, vj // 16, vi % 16, vj % 16] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi // 16, vj // 16, vi % 16, vj % 16] + 1.0
参数:
返回类型:

None

unannotate(block_or_loop, ann_key)[源代码]#

Unannotate a block/loop’s annotation with key ann_key

Parameters#

block_or_loop: Union[BlockRV, LoopRV]

The block/loop to be unannotated

ann_keystr

The annotation key

Examples#

Before unannotate, in TensorIR, the IR is:

@T.prim_func
def before_unannotate(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.block_attr({"ann_key", "ann_value"})
            B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and do unannotate:

sch = tir.Schedule(before_unannotate)
sch.unannotate(sch.get_block("B"), "ann_key")
print(sch.mod["main"].script())

After applying unannotate, the IR becomes:

@T.prim_func
def after_unannotate(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
参数:
  • block_or_loop (BlockRV | LoopRV)

  • ann_key (str)

返回类型:

None

unroll(loop)[源代码]#

Unroll the input loop. It has no prerequisites.

Parameters#

loopLoopRV

The loop to be unrolled

Examples#

Before unroll, in TensorIR, the IR is:

@T.prim_func
def before_unroll(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and do unroll:

sch = tir.Schedule(before_unroll)
i, j = sch.get_loops(sch.get_block("B"))
sch.unroll(i)

After applying unroll, the IR becomes:

@T.prim_func
def after_unroll(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i in T.unroll(0, 128):
        for j in T.serial(0, 128):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] * 2.0
参数:

loop (LoopRV)

返回类型:

None

unsafe_hide_buffer_access(block, buf_type, buf_index_array)[源代码]#

Hide some buffer access in a given block. This is an unsafe schedule primitive.

Parameters#

blockBlockRV

The block where we hide read access.

buf_typestr

The buffer type: “read”/”write”.

buf_index_arrayList[int]

The array of buffer indices we hide access.

Note#

This schedule primitive is unsafe, and may fail dependency analysis. One use case of unsafe_hide_buffer_access is to hide the access to indices buffers (e.g. in sparse computation) so that we can further tensorize the block: the indices buffers appearing in the read/write regions may fail the pattern matching in the tensorize primitive, and hiding the access to these buffers can address the issue.
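
For illustration, a hedged sketch of the sparse-computation use case described above (sch, the block name "update" and the buffer indices are hypothetical):

# Hide the read access to the 2nd and 3rd read buffers (e.g. sparse index buffers)
# so that a subsequent sch.tensorize(...) only has to pattern-match the data buffers.
sch.unsafe_hide_buffer_access(sch.get_block("update"), "read", [1, 2])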

参数:
  • block (BlockRV)

  • buf_type (str)

  • buf_index_array (List[int])

返回类型:

None

unsafe_set_dtype(block, buffer_index, dtype)[源代码]#

Set the data type of a buffer, where the buffer is specified by a block and a write index.

This schedule primitive is unsafe and may change the correctness of program because of type conversion, please use with caution.

Parameters#

blockUnion[BlockRV, str]

The producer block of the buffer

buffer_indexint

The index of the buffer in block’s write region

dtypestr

The data type to be set

Examples#

Before unsafe_set_dtype, in TensorIR, the IR is:

@T.prim_func
def before_set_dtype(
    A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")
) -> None:
    B = T.alloc_buffer((128, 128), dtype="float32")

    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j]
            C[vi, vj] = B[vi, vj] + 1.0

Create the schedule and do unsafe_set_dtype:

sch = tir.Schedule(before_set_dtype)
sch.unsafe_set_dtype("B", buffer_index=0, dtype="float16")
print(sch.mod["main"].script())

After applying set_dtype, the IR becomes:

@T.prim_func
def after_set_dtype(
    A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")
) -> None:
    B = T.alloc_buffer((128, 128), dtype="float16")

    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = T.cast(A[vi, vj] * 2.0, "float16")
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j]
            C[vi, vj] = T.cast(B[vi, vj], "float32") + 1.0

Note#

unsafe_set_dtype requires the buffer to be an intermediate buffer defined via alloc_buffer.

参数:
  • block (BlockRV | str)

  • buffer_index (int)

  • dtype (str)

返回类型:

None

vectorize(loop)[源代码]#

Vectorize the input loop. It requires: 1) The scope block that the loop is in should have the stage-pipeline property; 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine bindings; 3) For each block under the loop, the loop can only be contained in data-parallel block iters’ bindings

Parameters#

loopLoopRV

The loop to be vectorized

Examples#

Before vectorize, in TensorIR, the IR is:

@T.prim_func
def before_vectorize(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and do vectorize:

sch = tir.Schedule(before_vectorize)
i, j = sch.get_loops(sch.get_block("B"))
sch.vectorize(j)

After applying vectorize, the IR becomes:

@T.prim_func
def after_vectorize(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i in T.serial(0, 128):
        for j in T.vectorized(0, 128):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] * 2.0
参数:

loop (LoopRV)

返回类型:

None

work_on(func_name)[源代码]#

Instruct the schedule to work on a function in the IRModule.

By default, the schedule works on the function with the name “main”, or the only function in the IRModule if there is only one. If there are multiple functions in the IRModule and none of them is named “main”, users have to call this method to explicitly specify which function to work on.

This sugar function will guide the GetBlock method if its func_name is not specified.

Parameters#

func_namestr

The name of the function to work on.
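
For illustration, a short sketch (it assumes mod is an IRModule containing PrimFuncs named "fused_add" and "fused_mul", neither of which is named "main"):

sch = tir.Schedule(mod)
sch.work_on("fused_add")      # subsequent primitives target "fused_add"
block = sch.get_block("B")    # no func_name needed once work_on has been called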

参数:

func_name (str)

返回类型:

None

property func_working_on: GlobalVar | None#

Returns the GlobalVar of the func that the schedule is currently working on

property mod: IRModule#

Returns the AST of the module being scheduled

property state: ScheduleState#

Returns the ScheduleState in the current schedule class

property trace: Trace | None#

Returns the internally maintained trace of scheduling program execution

class tvm.tir.ScheduleState(mod, *, debug_mask='none', enable_check=True)[源代码]#

The state of scheduling, which exposes a Replace method as the primary resort for all the scheduling primitives to manipulate the TensorIR.

The data structure contains the following information: 1) The AST being scheduled (mod) 2) The sref tree of schedulable statements (indicated by the srefs) 3) The dependency information of each block scope (block_info) 4) A reverse mapping from the AST nodes to that in the sref tree (get_sref) 5) A debug flag; if set, extra checking is enabled (debug_mask) 6) An enable-check flag; if False, some prerequisite checks are disabled.

Parameters#

modIRModule

The AST of the module being scheduled

debug_maskint

Do extra correctness checking after the object construction and each time after calling the Replace method.

enable_checkbool

Indicates whether we enable prerequisite checks for some schedule primitives or not, defaults to True.

Methods:

__init__(mod, *[, debug_mask, enable_check])

Construct a schedule state from an IRModule or a PrimFunc

_get_cached_flags(block_sref)

Get the cached flags of the corresponding block

get_block_scope(block_sref)

Get the BlockScope corresponding to the block sref

get_sref(stmt)

Return the corresponding sref that points to the stmt

replace(src_sref, tgt_stmt[, block_sref_reuse])

Replace the part of the AST, as being pointed to by src_sref, with a specific statement tgt_stmt, and maintain the sref tree accordingly.

__init__(mod, *, debug_mask='none', enable_check=True)[源代码]#

Construct a schedule state from an IRModule or a PrimFunc

Parameters#

modUnion[PrimFunc, IRModule]

The IRModule or PrimFunc to be scheduled

debug_maskUnion[str, int]

Do extra correctness checking after the class creation and each time after calling the Replace method. Possible choices of debug_mask: 1) “all” - Turn on all the checks 2) “none” - Turn off all the checks 3) An integer - Turn on checks according to the bitmasks provided in ScheduleDebugMask
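
A minimal construction sketch (it assumes mod is an IRModule or PrimFunc and that tir has been imported from tvm):

# Turn on all extra correctness checks; "none" disables them, and an integer
# bitmask (see ScheduleDebugMask) selects individual checks.
state = tir.ScheduleState(mod, debug_mask="all")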

参数:
返回类型:

None

_get_cached_flags(block_sref)[源代码]#

Get the cached flags of the corresponding block

Parameters#

block_srefStmtSRef

The block sref to be retrieved

Returns#

flagsCachedFlags

Three flags: affine_binding, region_cover, stage_pipeline

Note#

It is an API intended for internal testing use.

参数:

block_sref (StmtSRef)

返回类型:

CachedFlags

get_block_scope(block_sref)[源代码]#

Get the BlockScope corresponding to the block sref

Parameters#

block_srefStmtSRef

The block sref to be retrieved

Returns#

scopeBlockScope

The corresponding block scope

参数:

block_sref (StmtSRef)

返回类型:

BlockScope

get_sref(stmt)[源代码]#

Return the corresponding sref that points to the stmt

Parameters#

stmtUnion[Block, For]

The schedulable statement in the TensorIR to be retrieved for its sref

Returns#

srefStmtSRef

The corresponding sref

参数:

stmt (Block | For)

返回类型:

StmtSRef | None

replace(src_sref, tgt_stmt, block_sref_reuse=None)[源代码]#

Replace the part of the AST, as being pointed to by src_sref, with a specific statement tgt_stmt, and maintain the sref tree accordingly. Replace will try to perform copy on write as much as possible when the ScheduleState holds the only copy to the IRModule and IR nodes.

Only 3 types of replacements are allowed: from src_sref->stmt to tgt_stmt. 1) Block -> Block 2) Loop -> Loop 3) Loop -> BlockRealize

Parameters#

src_srefStmtSRef

The sref to the statement to be replaced in the TensorIR AST

tgt_stmtUnion[Block, For, BlockRealize]

The statement to be replaced to

block_sref_reuseOptional[Dict[Block, Block]] = None

Maps an old block (to be replaced in the subtree under src_sref->stmt) to a new block (replaced to, in the subtree under tgt_stmt), and enforces reuse of srefs between them (rather than creating new srefs), i.e. after the replacement, the sref that pointed to the old block will point to the new one

Note#

The reuse of loop srefs is detected automatically according to the reuse of loop vars.

参数:
返回类型:

None

参数:
class tvm.tir.Select(condition, true_value, false_value, span=None)[源代码]#

Select node.

Note#

Select may compute both true_value and false_value. Use tvm.tir.if_then_else instead if you want to get a conditional expression that only evaluates the correct branch.
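
A small sketch contrasting the two constructs (the expressions are purely illustrative):

import tvm

x = tvm.tir.Var("x", "int32")
# Select may evaluate both values; that is fine here because both are cheap and safe.
abs_x = tvm.tir.Select(x > 0, x, x * -1)
# When the not-taken value must not be evaluated (e.g. a guarded division), use if_then_else.
guarded = tvm.tir.if_then_else(x > 0, tvm.tir.floordiv(100, x), 0)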

Parameters#

conditionPrimExpr

The condition expression.

true_valuePrimExpr

The value to take when condition is true.

false_valuePrimExpr

The value to take when condition is false.

spanOptional[Span]

The location of this expression in the source code.

参数:
class tvm.tir.SeqStmt(seq, span=None)[源代码]#

Sequence of statements.

Parameters#

seqList[Stmt]

The statements

spanOptional[Span]

The location of the stmt in the source code.

参数:
class tvm.tir.Shuffle(vectors, indices, span=None)[源代码]#

Shuffle node.

Parameters#

vectorsList[PrimExpr]

The vectors

indicesList[PrimExpr]

The indices

spanOptional[Span]

The location of this expression in the source code.

参数:
class tvm.tir.SizeVar(name, dtype, span=None)[源代码]#

Symbolic variable to represent a tensor index size, which is greater than or equal to zero.

Parameters#

namestr

The name

dtypeUnion[str, ir.Type]

The data type

spanOptional[Span]

The location of this expression in the source code.

参数:
class tvm.tir.Stmt[源代码]#

Base class of all the statements.

class tvm.tir.StmtSRef[源代码]#

An object that refers to schedulable elements in the TensorIR, aka “sref”.

Glossary

  • Block sref: An StmtSRef that points to a TensorIR block.

  • Loop sref: An StmtSRef that points to a TensorIR for loop.

  • Parent sref: The parent sref of an sref is the block/loop sref that points to its closest schedulable statement among its ancestors on the TensorIR AST.

  • Root sref: The sref to the root block. Every sref has exactly one parent sref except for the root sref.

  • Sref tree: The parent-children relationship of srefs that forms a tree, uniquely determined by the TensorIR AST.

Methods:

inline_mark()

A special StmtSRef, which doesn't point to any stmt in the AST, only serving as a "mark" to hint compute-at to do the work of compute-inline

root_mark()

A special StmtSRef, which doesn't point to any stmt in the AST, only serving as a "mark" to hint compute-at to do nothing

Attributes:

parent

The parent sref

stmt

The block/for stmt the object refers to

static inline_mark()[源代码]#

A special StmtSRef, which doesn’t point to any stmt in the AST, only serving as a “mark” to hint compute-at to do the work of compute-inline

返回类型:

StmtSRef

static root_mark()[源代码]#

A special StmtSRef, which doesn’t point to any stmt in the AST, only serving as a “mark” to hint compute-at to do nothing

返回类型:

StmtSRef

property parent: StmtSRef | None#

The parent sref

property stmt: Block | For | None#

The block/for stmt the object refers to

class tvm.tir.StringImm(value, span=None)[源代码]#

String constant.

Parameters#

valuestr

The value of the function.

spanOptional[Span]

The location of this expression in the source code.

参数:
class tvm.tir.Sub(a, b, span=None)[源代码]#

Sub node.

Parameters#

aPrimExpr

The left hand operand.

bPrimExpr

The right hand operand.

spanOptional[Span]

The location of this expression in the source code.

参数:
class tvm.tir.TensorIntrin(desc, impl)[源代码]#

A tensor intrinsic.

Parameters#

descPrimFunc

The function to describe the computation.

implPrimFunc

The function of the implementation for the execution.

Methods:

get(name[, allow_missing])

Look up a tensor intrinsic by its name.

register(name, desc, impl[, override])

Register a tensor intrinsic with its name.

static get(name, allow_missing=False)[源代码]#

Look up a tensor intrinsic by its name.

Parameters#

namestr

The name of the TensorIntrin to look up.

allow_missingbool

Whether to allow missing tensor intrin. If False, raise an error if the tensor intrin doesn’t exist.

Returns#

resultOptional[TensorIntrin]

The TensorIntrin with the specified name, or None if not found.
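
A brief lookup sketch (it assumes the "test_mma_intrin" registration from the tensorize example above has already been executed):

intrin = tir.TensorIntrin.get("test_mma_intrin", allow_missing=True)
if intrin is None:
    print("test_mma_intrin is not registered")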

参数:
返回类型:

TensorIntrin | None

static register(name, desc, impl, override=False)[源代码]#

Register a tensor intrinsic with its name.

Parameters#

namestr

The name of the TensorIntrin to register.

descPrimFunc

The function to describe the computation.

implPrimFunc

The function of the implementation for the execution.

override: bool

Whether to override an existing intrinsic.

参数:
class tvm.tir.Var(name, dtype, span=None)[源代码]#

Symbolic variable.

Parameters#

namestr

The name

dtypeUnion[str, ir.Type]

The data type

spanOptional[Span]

The location of this expression in the source code.

参数:
class tvm.tir.While(condition, body, span=None)[源代码]#

While node.

Parameters#

conditionPrimExpr

The termination condition.

bodyStmt

The body statement.

spanOptional[Span]

The location of the stmt in the source code.

参数:
tvm.tir.TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, dtype_bits_hint)[源代码]#

Backend function to allocate temporary workspace

Parameters#

device_typeint

The device type on which the space will be allocated.

device_idint

The device id on which the space will be allocated.

nbytesint

The size of the space requested.

dtype_code_hintint

The type code of the array elements. Only used in certain backends such as OpenGL.

dtype_bits_hintint

The type bits of the array elements. Only used in certain backends such as OpenGL.

Returns#

callPrimExpr

The call expression.

tvm.tir.TVMBackendFreeWorkspace(device_type, device_id, ptr)[源代码]#

Backend function to free temporary workspace.

Parameters#

device_typeint

The device type on which the space was allocated.

device_idint

The device id on which the space was allocated.

ptrVar

The result allocated space pointer.

Returns#

callPrimExpr

The call expression.

tvm.tir.abs(x, span=None)[源代码]#

Get absolute value of the input element-wise.

Parameters#

xPrimExpr

Input argument.

spanOptional[Span]

The location of this operator in the source code.

Returns#

yPrimExpr

The result.

tvm.tir.acos(x)[源代码]#

Take acos of input x.

Parameters#

xPrimExpr

Input argument.

Returns#

yPrimExpr

The result.

tvm.tir.acosh(x)[源代码]#

Take acosh of input x.

Parameters#

xPrimExpr

Input argument.

Returns#

yPrimExpr

The result.

tvm.tir.add(lhs, rhs, span=None)[源代码]#

Generic add operator.

Parameters#

lhsobject

The left operand.

rhsobject

The right operand.

spanOptional[Span]

The location of this operator in the source.

Returns#

optvm.Expr

The result Expr of add operation.

tvm.tir.address_of(buffer_load, span=None)[源代码]#

Returns the address of an element in the buffer

Parameters#

buffer_load: BufferLoad

The buffer load.

spanOptional[Span]

The location of this operator in the source code.

Returns#

callPrimExpr

The call expression.

tvm.tir.all(*args, span=None)[源代码]#

Create a new expression of the intersection of all conditions in the arguments

Parameters#

argslist

List of symbolic boolean expressions

spanOptional[Span]

The location of this operator in the source code.

Returns#

expr: Expr

Expression
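
A tiny sketch covering both all and any (the bound checks are illustrative):

import tvm

i = tvm.tir.Var("i", "int32")
n = tvm.tir.Var("n", "int32")
in_bounds = tvm.tir.all(i >= 0, i < n)        # conjunction of the conditions
on_border = tvm.tir.any(i <= 0, i >= n - 1)   # disjunction of the conditions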

tvm.tir.any(*args, span=None)[源代码]#

Create a new expression of the union of all conditions in the arguments

Parameters#

argslist

List of symbolic boolean expressions

spanOptional[Span]

The location of this operator in the source code.

Returns#

expr: Expr

Expression

tvm.tir.asin(x)[源代码]#

Take asin of input x.

Parameters#

xPrimExpr

Input argument.

Returns#

yPrimExpr

The result.

tvm.tir.asinh(x)[源代码]#

Take asinh of input x.

Parameters#

xPrimExpr

Input argument.

Returns#

yPrimExpr

The result.

tvm.tir.assume(cond=None)[源代码]#

Provide a true statement that can be used for simplifications

Parameters#

condExpr

The constraint condition.

Returns#

callPrimExpr

The call expression.

tvm.tir.atan(x)[源代码]#

Take atan of input x.

Parameters#

xPrimExpr

Input argument.

Returns#

yPrimExpr

The result.

tvm.tir.atan2(x1, x2)[源代码]#

Take arctan2(x1, x2).

Parameters#

x1PrimExpr

Input argument.

x2PrimExpr

Input argument.

Returns#

yPrimExpr

The result.

tvm.tir.atanh(x)[源代码]#

Take atanh of input x.

Parameters#

xPrimExpr

Input argument.

Returns#

yPrimExpr

The result.

tvm.tir.bijective_layout(src_layout, dst_layout)[源代码]#

Create a bijective layout mapping.

Parameters#

src_layoutstr or Layout

source layout.

dst_layoutstr or Layout

destination layout.

Returns#

bijective_layoutBijectiveLayout

The created bijective layout
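
A short usage sketch (the concrete layouts and index are illustrative):

import tvm

bl = tvm.tir.bijective_layout("NCHW", "NCHW16c")
dst_index = bl.forward_index([0, 32, 7, 7])   # index expressed in the NCHW16c layout
src_index = bl.backward_index(dst_index)      # maps back to the NCHW index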

参数:
返回类型:

BijectiveLayout

tvm.tir.bitwise_and(x, y, span=None)[源代码]#

Take bitwise and of two values

Parameters#

xPrimExpr

Left operand

yPrimExpr

Right operand

spanOptional[Span]

The location of this operator in the source code.

Returns#

resPrimExpr

The result.

tvm.tir.bitwise_not(x, span=None)[源代码]#

Take bitwise not of input value

Parameters#

xPrimExpr

Input operand

spanOptional[Span]

The location of this operator in the source code.

Returns#

resPrimExpr

The result.

tvm.tir.bitwise_or(x, y, span=None)[源代码]#

Take bitwise or of two values

Parameters#

xPrimExpr

Left operand

yPrimExpr

Right operand

spanOptional[Span]

The location of this operator in the source code.

Returns#

resPrimExpr

The result.

tvm.tir.bitwise_xor(x, y, span=None)[源代码]#

Take bitwise xor of two values

Parameters#

xPrimExpr

Left operand

yPrimExpr

Right operand

spanOptional[Span]

The location of this operator in the source code.

Returns#

resPrimExpr

The result.

tvm.tir.call_cpacked(*args, span=None)[源代码]#

Build expression by call an external packed function.

Same as call_packed, except that the first argument is the function name (as in call_extern), and the last argument is the resource handle.

Parameters#

argslist of Expr or Buffer.

Positional arguments.

spanOptional[Span]

The location of this operator in the source code.

Returns#

callPrimExpr

The call expression.

See Also#

te.extern : Create tensor with extern function call.

tvm.tir.call_cpacked_lowered(*args, span=None)[源代码]#

Lowered version of call c-packed. Same as call_packed, except that the first argument is the function name (as in call_extern), and the last argument is the resource handle.

Parameters#

argslist of Expr or Buffer.

Positional arguments.

spanOptional[Span]

The location of this operator in the source code.

Returns#

callPrimExpr

The call expression.

See Also#

te.extern : Create tensor with extern function call.

tvm.tir.call_extern(dtype, func_name, *args, span=None)[源代码]#

Build expression by calling an extern function.

Parameters#

dtypestr

The data type of the result.

func_name: str

The extern function name.

argslist

Positional arguments.

spanOptional[Span]

The location of this operator in the source code.

Returns#

callPrimExpr

The call expression.

tvm.tir.call_intrin(dtype, func_name, *args, span=None)[源代码]#

Build expression by calling an intrinsic function.

Intrinsics can be overloaded with multiple data types via the intrinsic translation rule.

Parameters#

dtypestr

The data type of the result.

func_name: str

The intrinsic function name.

argslist

Positional arguments.

spanOptional[Span]

The location of this operator in the source code.

Returns#

callPrimExpr

The call expression.

tvm.tir.call_llvm_intrin(dtype, name, *args, span=None)[源代码]#

Build expression by calling a llvm intrinsic function

Parameters#

dtypestr

The data type of the result.

namestr

The name of the llvm intrinsic function.

argslist

Positional arguments.

spanOptional[Span]

The location of this operator in the source code.

Returns#

callPrimExpr

The call expression.

tvm.tir.call_llvm_pure_intrin(dtype, name, *args, span=None)[源代码]#

Build expression by calling a pure llvm intrinsic function

Parameters#

dtypestr

The data type of the result.

namestr

The name of the llvm intrinsic function.

argslist

Positional arguments.

spanOptional[Span]

The location of this operator in the source code.

Returns#

callPrimExpr

The call expression.

tvm.tir.call_packed(*args, span=None)[源代码]#

Build expression by call an external packed function.

The argument to packed function can be Expr or Buffer. The argument is the corresponding POD type when Expr is presented.

When the argument is Buffer, the corresponding PackedFunc will receive a TVMArrayHandle whose content is valid during the callback period. If the PackedFunc is a python callback, then the corresponding argument is NDArray.

Parameters#

argslist of Expr or Buffer.

Positional arguments.

spanOptional[Span]

The location of this operator in the source code.

Returns#

callPrimExpr

The call expression.

See Also#

te.extern : Create tensor with extern function call.

tvm.tir.call_packed_lowered(*args, span=None)[源代码]#

Lowered version of call packed. The argument to packed function can be Expr or Buffer. The argument is the corresponding POD type when Expr is presented. When the argument is Buffer, the corresponding PackedFunc will receive a TVMArrayHandle whose content is valid during the callback period. If the PackedFunc is a python callback, then the corresponding argument is NDArray.

Parameters#

argslist of Expr or Buffer.

Positional arguments.

spanOptional[Span]

The location of this operator in the source code.

Returns#

callPrimExpr

The call expression.

See Also#

te.extern : Create tensor with extern function call.

tvm.tir.call_pure_extern(dtype, func_name, *args, span=None)[源代码]#

Build expression by calling a pure extern function.

Parameters#

dtypestr

The data type of the result.

func_name: str

The extern function name.

argslist

Positional arguments.

spanOptional[Span]

The location of this operator in the source code.

Returns#

callPrimExpr

The call expression.

tvm.tir.call_tir(global_var, *args)[源代码]#

Performs a call into another PrimFunc in the same IRModule

Returns#

callPrimExpr

The call expression.

参数:

global_var (GlobalVar)

tvm.tir.ceil(x, span=None)[源代码]#

Take ceil of float input x.

Parameters#

xPrimExpr

Input argument.

spanOptional[Span]

The location of this operator in the source code.

Returns#

yPrimExpr

The result.

tvm.tir.ceildiv(lhs, rhs, span=None)[源代码]#

Generic ceildiv operator.

Parameters#

lhsobject

The left operand.

rhsobject

The right operand.

spanOptional[Span]

The location of this operator in the source.

Returns#

optvm.Expr

The result Expr of ceildiv operation.

tvm.tir.clz(x)[源代码]#

Count leading zero bits of an integer x.

Parameters#

xPrimExpr

Input 32 or 64 bit integer. The result is undefined if the input is 0.

Returns#

yPrimExpr

The result.

tvm.tir.comm_reducer(fcombine, fidentity, name='reduce')[源代码]#

Create a commutative reducer for reduction.

Parameters#

fcombinefunction(Expr -> Expr -> Expr)

A binary function which takes two Exprs as input and returns an Expr.

fidentityfunction(str -> Expr)

A function which takes a type string as input to return a const Expr.

Returns#

reducerfunction

A function which creates a reduce expression over axis. There are two ways to use it:

  1. accept (expr, axis, where) to produce a Reduce Expr over the specified axis;

  2. simply use it with multiple Exprs.

Example#

n = te.var("n")
m = te.var("m")
mysum = te.comm_reducer(lambda x, y: x+y,
    lambda t: tvm.tir.const(0, dtype=t), name="mysum")
A = te.placeholder((n, m), name="A")
k = te.reduce_axis((0, m), name="k")
B = te.compute((n,), lambda i: mysum(A[i, k], axis=k), name="B")
tvm.tir.copysign(x1, x2)[源代码]#

Change the sign of x1 to that of x2, element-wise.

Parameters#

x1PrimExpr

Input argument.

x2PrimExpr

Input argument.

Returns#

yPrimExpr

The result.

tvm.tir.cos(x)[源代码]#

Take cos of input x.

Parameters#

xPrimExpr

Input argument.

Returns#

yPrimExpr

The result.

tvm.tir.cosh(x)[源代码]#

Take cosh of input x.

Parameters#

xPrimExpr

Input argument.

Returns#

yPrimExpr

The result.

tvm.tir.create_barriers(barrier_count)[源代码]#

TVM intrinsic to create N barriers

Parameters#

barrier_countint

The number of barriers to create.

Returns#

callPrimExpr

The call expression.

tvm.tir.decl_buffer(shape, dtype=None, name='buffer', data=None, strides=None, elem_offset=None, scope='', data_alignment=-1, offset_factor=0, buffer_type='', axis_separators=None, span=None)[源代码]#

Declare a new symbolic buffer.

Normally buffer is created automatically during lower and build. This is only needed if user want to specify their own buffer layout.

See the note below for detailed discussion on usage of buffer.

Parameters#

shapetuple of Expr

The shape of the buffer.

dtypestr, optional

The data type of the buffer.

namestr, optional

The name of the buffer.

dataVar, optional

The data pointer in the buffer.

strides: array of Expr

The stride of the buffer.

elem_offset: Expr, optional

The beginning offset of the array to data. In terms of number of elements of dtype.

scope: str, optional

The storage scope of the buffer, if not global. If scope equals empty string, it means it is global memory.

data_alignment: int, optional

The alignment of data pointer in bytes. If -1 is passed, the alignment will be set to TVM’s internal default.

offset_factor: int, optional

The factor of the elem_offset field; when set, elem_offset is required to be a multiple of offset_factor. If 0 is passed, the alignment will be set to 1. If non-zero is passed, we will create a Var for elem_offset if elem_offset is not None.

buffer_type: str, optional, {“”, “auto_broadcast”}

auto_broadcast buffer allows one to implement broadcast computation without considering whether the dimension size equals one. TVM maps buffer[i][j][k] -> buffer[i][0][k] if dimension j’s shape equals 1.

axis_separatorslist of int, optional

If passed, a list of separators between groups of axes, each of which is flattened to an output axis. For flat memory spaces, should either be None, or an empty list.

span: Optional[Span]

The location of the decl_buffer creation in the source.

Returns#

buffertvm.tir.Buffer

The created buffer

Example#

Here’s an example of how broadcast buffer can be used to define a symbolic broadcast operation,

m0, m1, m2 = te.var("m0"), te.var("m1"), te.var("m2")
n0, n1, n2 = te.var("n0"), te.var("n1"), te.var("n2")
o0, o1, o2 = te.var("o0"), te.var("o1"), te.var("o2")
A = te.placeholder((m0, m1, m2), name='A')
B = te.placeholder((n0, n1, n2), name='B')
C = te.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name='C')
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast")
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast")
s = te.create_schedule(C.op)
fadd = tvm.build(s, [A, B, C], target='llvm', name='bcast_add', binds={A:Ab, B:Bb})
dev = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(size=(2, 4, 3)).astype(A.dtype), dev)
b = tvm.nd.array(np.random.uniform(size=(2, 1, 3)).astype(B.dtype), dev)
c = tvm.nd.array(np.zeros((2, 4, 3), dtype=C.dtype), dev)
fadd(a, b, c)
tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy())

Note#

Buffer data structure reflects the DLTensor structure in dlpack. While DLTensor data structure is very general, it is usually helpful to create function that only handles specific case of data structure and make compiled function benefit from it.

If the user passes strides and elem_offset as None when constructing the function, then the function will be specialized for DLTensors that are compact and aligned. If the user passes a fully generic symbolic array as the strides, then the resulting function becomes fully generic.

tvm.tir.div(a, b, span=None)[源代码]#

Compute a / b as in C/C++ semantics.

Parameters#

aPrimExpr

The left hand operand, known to be non-negative.

bPrimExpr

The right hand operand, known to be non-negative.

spanOptional[Span]

The location of this operator in the source.

Returns#

resPrimExpr

The result expression.

Note#

When operands are integers, returns truncdiv(a, b, span).

tvm.tir.end_profile_intrinsic(id)[源代码]#

End profile intrinsic.

Parameters#

idint

The intrinsic id.

Returns#

callPrimExpr

The call expression.

tvm.tir.erf(x)[源代码]#

Take gauss error function of the input x.

Parameters#

xPrimExpr

Input argument.

Returns#

yPrimExpr

The result.

tvm.tir.exp(x)[源代码]#

Take exponential of input x.

Parameters#

xPrimExpr

Input argument.

Returns#

yPrimExpr

The result.

tvm.tir.exp10(x)[源代码]#

Calculate 10**x

Parameters#

xPrimExpr

Input argument.

Returns#

yPrimExpr

The result.

tvm.tir.exp2(x)[源代码]#

Calculate 2**x

Parameters#

xPrimExpr

Input argument.

Returns#

yPrimExpr

The result.

tvm.tir.floor(x, span=None)[源代码]#

Take floor of float input x.

Parameters#

xPrimExpr

Input argument.

spanOptional[Span]

The location of this operator in the source code.

Returns#

yPrimExpr

The result.

参数:

x (PrimExprWithOp)

tvm.tir.floordiv(a, b, span=None)[源代码]#

Compute the floordiv of two expressions.

Parameters#

aPrimExpr

The left hand operand

bPrimExpr

The right hand operand

spanOptional[Span]

The location of this operator in the source.

Returns#

resPrimExpr

The result expression.

tvm.tir.floormod(a, b, span=None)[源代码]#

Compute the floormod of two expressions.

Parameters#

aPrimExpr

The left hand operand

bPrimExpr

The right hand operand

spanOptional[Span]

The location of this operator in the source.

Returns#

resPrimExpr

The result expression.

tvm.tir.fmod(x, y)[源代码]#

Return the remainder of x divided by y with the same sign as x.

Parameters#

xPrimExpr

Input argument.

yPrimExpr

Input argument.

Returns#

zPrimExpr

The result.

tvm.tir.get_active_lane_mask(dtype, base, limit)[源代码]#

Calculate a predicate mask given an upper bound (limit) and a current value (base).

It will be lowered to the llvm.get.active.lane.mask intrinsic. (https://llvm.org/docs/LangRef.html#llvm-get-active-lane-mask-intrinsics)

Parameters#

dtypestr

The data type of the result.

basePrimExpr

An expression representing the base.

limitPrimExpr

An expression representing the limit.

tvm.tir.hypot(x1, x2)[源代码]#

Equivalent to sqrt(x1**2 + x2**2), element-wise.

Parameters#

x1PrimExpr

Input argument.

x2PrimExpr

Input argument.

Returns#

yPrimExpr

The result.

tvm.tir.if_then_else(cond, t, f, span=None)[源代码]#

Conditional selection expression.

Parameters#

condPrimExpr

The condition

tPrimExpr

The result expression if cond is true.

fPrimExpr

The result expression if cond is false.

spanOptional[Span]

The location of this operator in the source.

Returns#

resultNode

The result of conditional expression.

Note#

Unlike Select, if_then_else will not execute the branch that does not satisfy the condition. You can use it to guard against out-of-bounds access. Unlike Select, if_then_else cannot be vectorized if some lanes in the vector have different conditions.
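A minimal sketch of the out-of-bounds guard mentioned above (the names A, B, n are illustrative):

import tvm
from tvm import te, tir

n = te.var("n")
A = te.placeholder((n,), name="A", dtype="float32")
# Read A[i + 1] only when it is in bounds; otherwise fall back to 0.0.
# Unlike Select, the out-of-bounds load is never executed.
B = te.compute(
    (n,),
    lambda i: tir.if_then_else(i + 1 < n, A[i + 1], tir.FloatImm("float32", 0.0)),
    name="B",
)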

tvm.tir.indexdiv(a, b, span=None)[源代码]#

Compute floor(a / b) where a and b are non-negative.

Parameters#

aPrimExpr

The left hand operand, known to be non-negative.

bPrimExpr

The right hand operand, known to be non-negative.

spanOptional[Span]

The location of this operator in the source.

Returns#

resPrimExpr

The result expression.

Note#

Use this function to split non-negative indices. This function may take advantage of operands’ non-negativeness.

tvm.tir.indexmod(a, b, span=None)[源代码]#

Compute the remainder of indexdiv. a and b are non-negative.

Parameters#

aPrimExpr

The left hand operand, known to be non-negative.

bPrimExpr

The right hand operand, known to be non-negative.

spanOptional[Span]

The location of this operator in the source.

Returns#

resPrimExpr

The result expression.

Note#

Use this function to split non-negative indices. This function may take advantage of operands’ non-negativeness.
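A short sketch of splitting a non-negative flat index into a (row, column) pair with indexdiv/indexmod (the shapes and names are illustrative):

import tvm
from tvm import te, tir

A = te.placeholder((16, 8), name="A")
# View the 2-D buffer through a flat index i in [0, 128):
# row = i // 8, col = i % 8, both known to be non-negative.
B = te.compute(
    (128,),
    lambda i: A[tir.indexdiv(i, 8), tir.indexmod(i, 8)],
    name="B",
)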

tvm.tir.infinity(dtype, span=None)[源代码]#

infinity value of dtype

Parameters#

dtypestr

The data type.

spanOptional[Span]

The location of this operator in the source code.

Returns#

valuetvm.Expr

The infinity value of dtype.

Parameters:

span (Span | None)

Return type:

Any

tvm.tir.isfinite(x, span=None)[源代码]#

Check if input value is finite.

Parameters#

xPrimExpr

Input argument.

spanOptional[Span]

The location of this operator in the source code.

Returns#

yPrimExpr

The result.

tvm.tir.isinf(x, span=None)[源代码]#

Check if input value is infinite.

Parameters#

xPrimExpr

Input argument.

spanOptional[Span]

The location of this operator in the source code.

Returns#

yPrimExpr

The result.

tvm.tir.isnan(x, span=None)[源代码]#

Check if input value is Nan.

Parameters#

xPrimExpr

Input argument.

spanOptional[Span]

The location of this operator in the source code.

Returns#

yPrimExpr

The result.

tvm.tir.isnullptr(x, span=None)[源代码]#

Check if input value is nullptr.

Parameters#

xPrimExpr

Input argument.

spanOptional[Span]

The location of this operator in the source code.

Returns#

yPrimExpr

The result.

tvm.tir.layout(layout_str, dtype='int32')[源代码]#

Create a layout node from a string.

Parameters#

layout_strstr

A layout representation is composed of upper cases, lower cases and numbers, where upper case indicates a primal axis and the corresponding lower case with factor size indicates the subordinate axis. For example, NCHW16c can describe a 5-D tensor of [batch_size, channel, height, width, channel_block]. Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel).

dtypestr

The dtype of generated axes vars in the returned layout. It is required to be integer type.

Returns#

layoutLayout

The created layout

Parameters:
  • layout_str (str)

  • dtype (str)

Return type:

Layout
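A small usage sketch, assuming the Layout helper methods index_of and factor_of:

import tvm

# "NCHW16c": primal axes N, C, H, W plus a subordinate axis c of factor 16.
layout = tvm.tir.layout("NCHW16c")
print(layout.index_of("C"))   # 1: position of the primal channel axis
print(layout.factor_of("c"))  # 16: split factor of the subordinate axis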

tvm.tir.ldexp(x1, x2)[源代码]#

Returns x1 * (2 ** x2).

Parameters#

x1PrimExpr

Input argument.

x2PrimExpr

Input argument.

Returns#

yPrimExpr

The result.

tvm.tir.likely(cond, span=None)[源代码]#

Mark condition as likely.

Parameters#

condPrimExpr

Input argument.

spanOptional[Span]

The location of this operator in the source code.

Returns#

yPrimExpr

The marked expression.

tvm.tir.log(x)[源代码]#

Take log of input x.

Parameters#

xPrimExpr

Input argument.

Returns#

yPrimExpr

The result.

tvm.tir.log10(x)[源代码]#

Take log10 of input x.

Parameters#

xPrimExpr

Input argument.

Returns#

yPrimExpr

The result.

tvm.tir.log1p(x)[源代码]#

Take log(x + 1) with respect to input x.

Parameters#

xPrimExpr

Input argument.

Returns#

yPrimExpr

The result.

tvm.tir.log2(x)[源代码]#

Take log2 of input x.

Parameters#

xPrimExpr

Input argument.

Returns#

yPrimExpr

The result.

tvm.tir.lookup_param(param_name, span=None)[源代码]#

Returns the param by name

Parameters#

param_namestr

The name of param.

spanOptional[Span]

The location of this operator in the source code.

Returns#

callPrimExpr

The call expression.

tvm.tir.max(expr, axis, where=None, init=None, *args)#

Create a max expression over axis.

Parameters#

exprPrimExpr

The source expression.

axisIterVar

The reduction IterVar axis

whereoptional, Expr

Filtering predicate of the reduction.

Returns#

valuePrimExpr

The result value.

Example#

m = te.var("m")
n = te.var("n")
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n), name="k")

# there are two ways to use this max reducer:
# mode 1, accept (expr, axis, where) to produce a Reduce Expr
# tvm.max represents tvm.te.max or tvm.tir.max.
B = te.compute((m,), lambda i: tvm.max(A[i, k], axis=k), name="B")

# mode 2, simply use it with multiple Exprs:
max_res = tvm.max(m, n)
tvm.tir.max_value(dtype, span=None)[源代码]#

maximum value of dtype

Parameters#

dtypestr

The data type.

spanOptional[Span]

The location of this operator in the source code.

Returns#

valuetvm.Expr

The maximum value of dtype.

Parameters:

span (Span | None)

Return type:

Any

tvm.tir.min(expr, axis, where=None, init=None, *args)#

Create a min expression over axis.

Parameters#

exprPrimExpr

The source expression.

axisIterVar

The reduction IterVar axis

whereoptional, Expr

Filtering predicate of the reduction.

Returns#

valuePrimExpr

The result value.

Example#

m = te.var("m")
n = te.var("n")
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n), name="k")

# there are two ways to use this min reducer:
# mode 1, accept (expr, axis, where) to produce a Reduce Expr
# tvm.min represents tvm.te.min or tvm.tir.min.
B = te.compute((m,), lambda i: tvm.min(A[i, k], axis=k), name="B")

# mode 2, simply use it with multiple Exprs:
min_res = tvm.min(m, n)
tvm.tir.min_value(dtype, span=None)[源代码]#

minimum value of dtype

Parameters#

dtypestr

The data type.

spanOptional[Span]

The location of this operator in the source code.

Returns#

valuetvm.Expr

The minimum value of dtype.
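A quick sketch of querying dtype limits as TIR constants:

import tvm

print(tvm.tir.min_value("int8"))     # the int8 minimum, -128
print(tvm.tir.max_value("float16"))  # the largest finite float16, 65504
print(tvm.tir.infinity("float32"))   # positive infinity as a float32 constant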

tvm.tir.mma_fill(dtype, local_size, local_ptr, offset)[源代码]#

TVM intrinsic for zero-initializing an MMA accumulation register

Parameters#

dtypestr

The data type of the result.

local_sizeIntImm

The number of elements.

local_ptrVar

The destination pointer variable.

offsetExpr

The destination offset.

Returns#

callPrimExpr

The call expression.

tvm.tir.mma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride)[源代码]#

TVM intrinsic for storing the result of PTX MMA into a destination pointer

Parameters#

dtypestr

The data type of the result.

mIntImm

The shape of mma fragment.

nIntImm

The shape of mma fragment.

dst_ptrVar

The destination pointer variable.

src_ptrVar

The source pointer variable.

src_offsetExpr

The source offset.

dst_strideVar

The destination stride.

Returns#

callPrimExpr

The call expression.

tvm.tir.multiply(lhs, rhs, span=None)[源代码]#

Generic multiply operator.

Parameters#

lhsobject

The left operand.

rhsobject

The right operand.

spanOptional[Span]

The location of this operator in the source.

Returns#

optvm.Expr

The result Expr of the multiply operation.

tvm.tir.nearbyint(x, span=None)[源代码]#

Round elements of the array to the nearest integer. This intrinsic uses llvm.nearbyint instead of llvm.round, which is faster but may produce results that differ from te.round. Notably, nearbyint rounds according to the current rounding mode, whereas te.round (llvm.round) ignores it. For the differences between the two, see: https://en.cppreference.com/w/cpp/numeric/math/round and https://en.cppreference.com/w/cpp/numeric/math/nearbyint

Parameters#

xPrimExpr

Input argument.

spanOptional[Span]

The location of this operator in the source code.

Returns#

yPrimExpr

The result.

tvm.tir.nextafter(x1, x2)[源代码]#

Return the next floating-point value after x1 towards x2.

Parameters#

x1PrimExpr

Input argument.

x2PrimExpr

Input argument.

Returns#

yPrimExpr

The result.

tvm.tir.popcount(x)[源代码]#

Count the number of set bits in input x.

Parameters#

xPrimExpr

Input argument.

Returns#

yPrimExpr

The result.

tvm.tir.pow(x, y, span=None)[源代码]#

x power y

Parameters#

xPrimExpr

Input argument.

yPrimExpr

The exponent

spanOptional[Span]

The location of this operator in the source code.

Returns#

zPrimExpr

The result.

tvm.tir.power(x, y, span=None)[源代码]#

x power y

Parameters#

xPrimExpr

Input argument.

yPrimExpr

The exponent

spanOptional[Span]

The location of this operator in the source code.

Returns#

zPrimExpr

The result.

tvm.tir.ptx_arrive_barrier(barrier_id)[源代码]#

TVM intrinsic for ptx barrier arrival using mbarrier.arrive https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive

Parameters#

barrier_idint

The ID of the barrier shared memory pointer.

Returns#

callPrimExpr

The call expression.

tvm.tir.ptx_arrive_barrier_expect_tx(barrier_id, byte_count)[源代码]#

TVM intrinsic for ptx barrier arrival with expect tx using mbarrier.arrive.expect_tx https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-expect-tx-operation

Parameters#

barrier_idint

The ID of the barrier shared memory pointer.

byte_countint

Increases the tx count of the mbarrier object to track completion of additional async transactions.

Returns#

callPrimExpr

The call expression.

tvm.tir.ptx_commit_group()[源代码]#

TVM intrinsic for ptx async copy commit https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-commit-group

Returns#

callPrimExpr

The call expression.

tvm.tir.ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes)[源代码]#

TVM intrinsic for ptx async copy from global to shared memory using cp.async https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async

Parameters#

dtypestr

The data type of the result.

shared_ptrVar

The shared memory pointer variable.

shared_offsetExpr

The offset of shared memory pointer.

global_ptrVar

The global memory pointer variable.

global_offsetExpr

The offset of global memory pointer.

bytesint

The data size to copy.

Returns#

callPrimExpr

The call expression.

tvm.tir.ptx_cp_async_barrier(barrier_id)[源代码]#

TVM intrinsic for ptx async copy barrier using cp.async.mbarrier.arrive https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-cp-async-mbarrier-arrive

Parameters#

barrier_idint

The ID of the barrier shared memory pointer.

Returns#

callPrimExpr

The call expression.

tvm.tir.ptx_cp_async_bulk(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes, barrier_id)[源代码]#

TVM intrinsic for ptx async copy from global to shared memory using cp.async.bulk https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk

Parameters#

dtypestr

The data type of the result.

shared_ptrVar

The shared memory pointer variable.

shared_offsetExpr

The offset of shared memory pointer.

global_ptrVar

The global memory pointer variable.

global_offsetExpr

The offset of global memory pointer.

bytesint

The data size to copy.

barrier_idint

The ID of the barrier shared memory pointer.

Returns#

callPrimExpr

The call expression.

tvm.tir.ptx_init_barrier_thread_count(barrier_id, thread_count)[源代码]#

TVM intrinsic for ptx barrier initialization of thread count using mbarrier.init https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init

Parameters#

barrier_idint

The ID of the barrier shared memory pointer.

thread_countint

Number of threads expected to arrive at the barrier.

Returns#

callPrimExpr

The call expression.

tvm.tir.ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, smem_offset)[源代码]#

TVM intrinsic for ptx load matrix from shared memory https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix

Parameters#

dtypestr

The data type of the result.

transbool

The matrix is loaded in column-major format.

numIntImm

The number of matrices.

typeLiteral[“.b16”]

The data type of the matrices.

local_ptrVar

The local pointer variable.

local_offsetExpr

The offset of local pointer.

smem_ptrVar

The shared memory pointer variable.

smem_offsetExpr

The offset of the shared memory pointer.

Returns#

callPrimExpr

The call expression.

tvm.tir.ptx_mma(dtype, shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, multiplicand_a, a_index, multiplicand_b, b_index, accumulator, c_index, saturate, operator=None)[源代码]#

TVM intrinsic for ptx tensor core mma instructions https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma

Parameters#

dtypestr

The data type of the result.

shapestr

The shape of mma fragment.

A_layoutLiteral[“row”, “col”]

The layout of multiplicand fragment A.

B_layoutLiteral[“row”, “col”]

The layout of multiplicand fragment B.

A_dtypestr

The data type of multiplicand fragment A.

B_dtypestr

The data type of multiplicand fragment B.

C_dtypestr

The data type of accumulator fragment C.

multiplicand_aVar

The multiplicand fragment A variable.

a_indexExpr

The index of multiplicand fragment A.

multiplicand_bVar

The multiplicand fragment B variable.

b_indexExpr

The index of multiplicand fragment B.

accumulatorVar

The accumulator fragment C variable.

c_indexExpr

The index of accumulator fragment C.

saturatebool

The optional saturation at the output.

operatorOptional[Literal[“xor”, “and”]]

The 1-bit operator.

Returns#

callPrimExpr

The call expression.

tvm.tir.ptx_mma_sp(dtype, shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, multiplicand_a, a_index, multiplicand_b, b_index, accumulator, c_index, metadata, meta_index, sparse_selector, saturate)[源代码]#

TVM intrinsic for sparse tensor core ptx instructions https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-sparse-mma

Parameters#

dtypestr

The data type of the result.

shapestr

The shape of mma fragment.

A_layoutLiteral[“row”, “col”]

The layout of multiplicand fragment A.

B_layoutLiteral[“row”, “col”]

The layout of multiplicand fragment B.

A_dtypestr

The data type of multiplicand fragment A.

B_dtypestr

The data type of multiplicand fragment B.

C_dtypestr

The data type of accumulator fragment C.

multiplicand_aVar

The multiplicand fragment A variable.

a_indexExpr

The index of multiplicand fragment A.

multiplicand_bVar

The multiplicand fragment B variable.

b_indexExpr

The index of multiplicand fragment B.

accumulatorVar

The accumulator fragment C variable.

c_indexExpr

The index of accumulator fragment C.

metadataExpr

The metadata of operand.

meta_indexExpr

The metadata index of operand.

sparse_selectorExpr

The sparse selector indicating the thread that stores the metadata.

saturatebool

The optional saturation at the output.

Returns#

callPrimExpr

The call expression.

tvm.tir.ptx_wait_barrier(barrier_id)[源代码]#

TVM intrinsic for ptx barrier wait using mbarrier.try_wait https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-mbarrier-try-wait

Parameters#

barrier_idint

The ID of the barrier shared memory pointer.

Returns#

callPrimExpr

The call expression.

tvm.tir.ptx_wait_group(num)[源代码]#

TVM intrinsic for ptx async copy wait https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-wait-group

Parameters#

numint

The number of the most recent uncommitted pending cp.async groups to wait.

Returns#

callPrimExpr

The call expression.

tvm.tir.q_multiply_shift(x, y, q, s)[源代码]#

Execute a multiplication between two Q-numbers x and y followed by a right shift s. The mathematical expression is:

out = round(x*y*2^-s)

More about Q-numbers here: https://en.wikipedia.org/wiki/Q_(number_format) The rounding rule is to the nearest value, rounding half up (i.e., round(x.1) = x and round (x.5) = x+1)

Parameters#

xPrimExpr

First Q-number

yPrimExpr

Second Q-number

qPrimExpr

Number of fractional bits in x and y. Needs to be > 0

sPrimExpr

Integer shift

Returns#

yPrimExpr

The result.
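A plain-Python illustration of the rounding formula above (fixed-point arithmetic only, not the TIR intrinsic itself; the Q31 encoding is an assumption chosen for the example):

# Multiply 0.75 by 0.5 in Q31 fixed point and shift right by 31.
q = 31
x = round(0.75 * 2**q)       # Q31 encoding of 0.75
y = round(0.5 * 2**q)        # Q31 encoding of 0.5
s = 31
out = round(x * y * 2**-s)   # out = round(x*y*2^-s)
print(out / 2**q)            # ~0.375, the expected product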

tvm.tir.q_multiply_shift_per_axis(x, y, ls, rs, q, is_lshift_required, is_rshift_required)[源代码]#

Execute a multiplication between two Q-numbers x and y

Parameters#

xPrimExpr

First Q-number.

yPrimExpr

Second Q-number.

lsPrimExpr

Integer left shift.

rsPrimExpr

Integer right shift.

qIntImm

Number of fractional bits in x and y. Needs to be > 0.

is_lshift_requiredIntImm

Whether we need to do left shift or not.

is_rshift_requiredIntImm

Whether we need to do right shift or not.

Returns#

zPrimExpr

The result.

Parameters:
tvm.tir.reinterpret(dtype, value, span=None)[源代码]#

Reinterpret the bits of value as dtype.

Parameters#

dtypestr

The data type.

valuePrimExpr

The input value.

spanOptional[Span]

The location of this operator in the source code.

Returns#

valuetvm.Expr

The reinterpret cast value of dtype.

Parameters:

span (Span | None)

Return type:

Any
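A minimal sketch: reinterpret the bit pattern of a float32 constant as int32 (1.0f corresponds to the bits 0x3F800000):

import tvm

x = tvm.tir.FloatImm("float32", 1.0)
bits = tvm.tir.reinterpret("int32", x)  # a tir.Call wrapping the reinterpret
print(bits.dtype)                       # int32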

tvm.tir.ret(val)[源代码]#

Create a tir return expression

Parameters#

valExpr

The returned tir expression, whose data type is int, float or void pointer.

Returns#

retPrimExpr

The return expression

tvm.tir.round(x, span=None)[源代码]#

Round elements of the array to the nearest integer.

Parameters#

xPrimExpr

Input argument.

spanOptional[Span]

The location of this operator in the source code.

Returns#

yPrimExpr

The result.

tvm.tir.rsqrt(x)[源代码]#

Take reciprocal of square root of input x.

Parameters#

xPrimExpr

Input argument.

Returns#

yPrimExpr

The result.

tvm.tir.shift_left(x, y, span=None)[源代码]#

Return the result of x left shifted by y bits.

Parameters#

xPrimExpr

Input argument.

yPrimExpr

Input argument.

Returns#

zPrimExpr

The result.

tvm.tir.shift_right(x, y, span=None)[源代码]#

Return the result of x right shifted by y bits.

Parameters#

xPrimExpr

Input argument.

yPrimExpr

Input argument.

Returns#

zPrimExpr

The result.

tvm.tir.sigmoid(x)[源代码]#

Quick function to get sigmoid

Parameters#

xPrimExpr

Input argument.

Returns#

yPrimExpr

The result.

tvm.tir.sin(x)[源代码]#

Take sin of input x.

Parameters#

xPrimExpr

Input argument.

Returns#

yPrimExpr

The result.

tvm.tir.sinh(x)[源代码]#

Take sinh of input x.

Parameters#

xPrimExpr

Input argument.

Returns#

yPrimExpr

The result.

tvm.tir.sqrt(x)[源代码]#

Take square root of input x.

Parameters#

xPrimExpr

Input argument.

Returns#

yPrimExpr

The result.

tvm.tir.start_profile_intrinsic(id)[源代码]#

Start profile intrinsic.

Parameters#

idint

The intrinsic id.

Returns#

callPrimExpr

The call expression.

tvm.tir.stmt_list(stmt)[源代码]#

Make list of stmt from blocks.

Parameters#

stmtStmt

The input statement.

Returns#

stmt_listList[Stmt]

The unpacked list of statements

Parameters:

stmt (Stmt)

Return type:

List[Stmt]

tvm.tir.stmt_seq(*args)[源代码]#

Make sequence of statements

Parameters#

*argsUnion[PrimExpr, Stmt]

List of statements to be combined as sequence.

Returns#

stmtStmt

The combined statement.

Parameters:

args (PrimExpr | Stmt)

Return type:

SeqStmt
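A small sketch combining and unpacking statements with stmt_seq and stmt_list:

import tvm

s1 = tvm.tir.Evaluate(tvm.tir.IntImm("int32", 0))
s2 = tvm.tir.Evaluate(tvm.tir.IntImm("int32", 1))
seq = tvm.tir.stmt_seq(s1, s2)      # a SeqStmt containing both statements
print(len(tvm.tir.stmt_list(seq)))  # 2: unpacked back into a flat list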

tvm.tir.subtract(lhs, rhs, span=None)[源代码]#

Generic subtract operator.

Parameters#

lhsobject

The left operand.

rhsobject

The right operand.

spanOptional[Span]

The location of this operator in the source.

Returns#

optvm.Expr

The result Expr of the subtract operation.

tvm.tir.sum(expr, axis, where=None, init=None, *args)#

Create a sum expression over axis.

Parameters#

exprPrimExpr

The source expression.

axisIterVar

The reduction IterVar axis

whereoptional, Expr

Filtering predicate of the reduction.

Returns#

valuePrimExpr

The result value.

Example#

m = te.var("m")
n = te.var("n")
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n), name="k")

# there are two ways to use this sum reducer:
# mode 1, accept (expr, axis, where) to produce a Reduce Expr
# tvm.sum represents tvm.te.sum or tvm.tir.sum.
B = te.compute((m,), lambda i: tvm.sum(A[i, k], axis=k), name="B")

# mode 2, simply use it with multiple Exprs:
sum_res = tvm.sum(m, n)
tvm.tir.tan(x)[源代码]#

Take tan of input x.

Parameters#

xPrimExpr

Input argument.

Returns#

yPrimExpr

The result.

tvm.tir.tanh(x)[源代码]#

Take hyperbolic tanh of input x.

Parameters#

xPrimExpr

Input argument.

Returns#

yPrimExpr

The result.

tvm.tir.trace(args, trace_action='tvm.default_trace_action')[源代码]#

Trace tensor data at the runtime.

The trace function allows tracing of specific tensors at runtime. The value to be traced should come as the last argument. The trace action should be specified; by default tvm.default_trace_action is used.

Parameters#

argslist of Expr or Buffers.

Positional arguments.

trace_actionstr.

The name of the trace action.

Returns#

callPrimExpr

The call expression.

See Also#

tvm.tir.call_packed : Creates packed function.

tvm.tir.trunc(x, span=None)[源代码]#

Get truncated value of the input.

The truncated value of the scalar x is the nearest integer i which is closer to zero than x is.

Parameters#

xPrimExpr

Input argument.

spanOptional[Span]

The location of this operator in the source code.

Returns#

yPrimExpr

The result.

tvm.tir.truncdiv(a, b, span=None)[源代码]#

Compute the truncdiv of two expressions.

Parameters#

aPrimExpr

The left hand operand

bPrimExpr

The right hand operand

spanOptional[Span]

The location of this operator in the source.

Returns#

resPrimExpr

The result expression.

Note#

This is the default integer division behavior in C.

tvm.tir.truncmod(a, b, span=None)[源代码]#

Compute the truncmod of two expressions.

Parameters#

aPrimExpr

The left hand operand

bPrimExpr

The right hand operand

spanOptional[Span]

The location of this operator in the source.

Returns#

resPrimExpr

The result expression.

Note#

This is the default integer remainder behavior in C.

tvm.tir.tvm_access_ptr(ptype, data, offset, extent, rw_mask)[源代码]#

Get head access address with memory access pattern info

Parameters#

ptypeExpr

The data type of pointer.

dataDType*

The data of pointer.

offsetint

The offset of pointer.

extentint

The extent of pointer.

rw_maskint

The read write mask.

Returns#

callPrimExpr

The call expression.

tvm.tir.tvm_bmma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c)[源代码]#

TVM intrinsic for tensor core bmma_sync operators

Parameters#

fragment_dVar

The bwmma fragment_d.

index_dExpr

The fragment_d index.

fragment_aVar

The bwmma fragment_a.

index_aExpr

The fragment_a index.

fragment_bVar

The bwmma fragment_b.

index_bExpr

The fragment_b index.

fragment_cVar

The bwmma fragment_c.

index_cExpr

The fragment_c index.

Returns#

callPrimExpr

The call expression.

tvm.tir.tvm_check_return(expected, return_unexpected, nested_call)[源代码]#

Check the return value of a nested call against an expected value.

Parameters#

expectedint

The expected return code.

return_unexpectedint

The unexpected return code.

nested_callPrimExpr

The call expression to check return.

Returns#

callPrimExpr

The call expression.

tvm.tir.tvm_fill_fragment(fragment, m, n, k, index, value)[源代码]#

TVM intrinsic for tensor core fill_fragment operators

Parameters#

fragmentVar

The wmma fragment

mUIntImm

The shape of wmma fragment.

nUIntImm

The shape of wmma fragment.

kUIntImm

The shape of wmma fragment.

indexExpr

The fragment index.

valueExpr

The value to be filled in fragment.

Returns#

callPrimExpr

The call expression.

tvm.tir.tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout)[源代码]#

TVM intrinsic for tensor core load operators

Parameters#

fragmentVar

The wmma fragment.

mUIntImm

The shape of wmma fragment.

nUIntImm

The shape of wmma fragment.

kUIntImm

The shape of wmma fragment.

indexExpr

The fragment index.

buffer_ptrExpr

The fragment buffer pointer.

strideExpr

The fragment stride.

layoutLiteral[“row_major”, “column_major”]

The fragment layout.

Returns#

callPrimExpr

The call expression.

tvm.tir.tvm_mma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c)[源代码]#

TVM intrinsic for tensor core mma_sync operators

Parameters#

fragment_dVar

The wmma fragment_d.

index_dExpr

The fragment_d index.

fragment_aVar

The wmma fragment_a.

index_aExpr

The fragment_a index.

fragment_bVar

The wmma fragment_b.

index_bExpr

The fragment_b index.

fragment_cVar

The wmma fragment_c.

index_cExpr

The fragment_c index.

Returns#

callPrimExpr

The call expression.

tvm.tir.tvm_stack_alloca(dtype_str, num)[源代码]#

Return new on stack dtype[num]

Parameters#

dtype_strstr

The data type of array.

numint

The size of array.

Returns#

callPrimExpr

The call expression.

tvm.tir.tvm_stack_make_array(data, shape, strides, ndim, arr_dtype, elem_offset)[源代码]#

Allocate a NDArray(DLTensor) on stack, return the handle

Parameters#

dataExpr

The data of array.

shapeExpr

The shape of array.

stridesExpr

The strides of array.

ndimExpr

The dimensions of array.

arr_dtypeExpr

The data type of array.

elem_offsetExpr

The element offset of array.

Returns#

callPrimExpr

The call expression.

tvm.tir.tvm_stack_make_shape(*args)[源代码]#

Allocate a shape tuple on stack, return the handle

Parameters#

argsint

The tuple shape.

Returns#

callPrimExpr

The call expression.

tvm.tir.tvm_store_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout)[源代码]#

TVM intrinsic for tensor core store operators

Parameters#

fragmentVar

The wmma fragment.

mUIntImm

The shape of wmma fragment.

nUIntImm

The shape of wmma fragment.

kUIntImm

The shape of wmma fragment.

indexExpr

The fragment index.

buffer_ptrExpr

The fragment buffer pointer.

strideExpr

The fragment stride.

layoutLiteral[“row_major”, “column_major”]

The fragment layout.

Returns#

callPrimExpr

The call expression.

tvm.tir.tvm_struct_get(arr, index, field, dtype)[源代码]#

Get struct field value in array

Parameters#

dtypestr

The data type of the result.

arrStructType*

The array of struct.

indexint

The index of struct.

fieldint

The field of struct.

Returns#

callPrimExpr

The call expression.

tvm.tir.tvm_struct_set(arr, index, field, value)[源代码]#

Set value in struct field in array

Parameters#

arrStructType*

The array of struct.

indexint

The index of struct.

fieldint

The field of struct.

valueExpr

The value to be set in field.

Returns#

callPrimExpr

The call expression.

tvm.tir.tvm_thread_allreduce(*freduce_args)[源代码]#

Perform allreduce inside threadblock.

Parameters#

freduce_argsExpr

The args.

Returns#

callPrimExpr

The call expression.

tvm.tir.tvm_throw_last_error()[源代码]#

Throw TVMGetLastError()

Returns#

retPrimExpr

The return expression

tvm.tir.tvm_tuple(*value)[源代码]#

Create a tuple structure in value field of AttrStmt

Parameters#

valueExpr

The value in tuple.

Returns#

callPrimExpr

The call expression.

tvm.tir.type_annotation(dtype)[源代码]#

Create a type annotation expression

Parameters#

dtypeExpr

The data type.

Returns#

callPrimExpr

The call expression.

tvm.tir.undef()[源代码]#

Returns an uninitialized, arbitrary value

Returns#

callPrimExpr

The call expression.

tvm.tir.vectorcombine(dtype, vec1, vec2)[源代码]#

Concat two vectors

Parameters#

vec1list

The input vector.

vec2list

The input vector.

Returns#

callPrimExpr

The call expression.

tvm.tir.vectorhigh(dtype, vec)[源代码]#

Get the high half of the vector

Parameters#

dtypestr

The data type of the result.

veclist

The input vector.

Returns#

callPrimExpr

The call expression.

tvm.tir.vectorlow(dtype, vec)[源代码]#

Get the low half of the vector

Parameters#

dtypestr

The data type of the result.

veclist

The input vector.

Returns#

callPrimExpr

The call expression.

tvm.tir.vscale()[源代码]#

Get the target’s vscale value. It will be lowered to the llvm.vscale intrinsic (https://llvm.org/docs/LangRef.html#llvm-vscale-intrinsic).

Returns#

callPrimExpr

Call to the vscale intrinsic

tvm.tir.transform#

Namespace of all TIR transformations

Classes:

HoistedConditionals(value[, names, module, ...])

Flags for use in HoistExpressionConfig.conditional_types

HoistedLetBindings(value[, names, module, ...])

Flags for use in HoistExpressionConfig.let_binding_types

PrimFuncPass()

A pass that works on each tvm.tir.PrimFunc() in a module.

Functions:

AnnotateDeviceRegions()

Annotate locations that should be run on the device

AnnotateEntryFunc()

Set a PrimFunc as the entry point if it is only function in IRModule.

Apply(ftransform)

Apply ftransform to each function in the Module.

ApplyLayoutTransforms()

Reshape buffers that appear in the "layout_transform_map" function attribute.

BF16ComputeLegalize()

Legalize bf16 compute Ops.

BF16StorageLegalize()

Legalize bf16 storage types to u16.

BindTarget(target)

Annotate a PrimFunc with a given target.

CoProcSync()

Detect and insert sync points to co-processor.

CombineContextCall()

Combine context calls in the host function.

CommonSubexprElimTIR([enable_cse_tir, ...])

Replace redundant computations by new variables.

CompactBufferAllocation([is_strict])

Compact the buffer access region.

ConvertBlocksToOpaque()

Substitute all the block vars with the PrimExprs they are bound to, indicated by the corresponding iter_values in BlockRealize, and then convert the blocks into opaque ones by removing all the iter_values in BlockRealize and iter_vars in Block.

ConvertForLoopsToSerial()

Convert Parallel For Loops to Serial For Loops.

ConvertSSA()

Convert an IRModule to be SSA form.

DecorateDeviceScope()

Decorate all the function's body as device function.

DefaultGPUSchedule()

The pass sets default thread bindings for PrimFuncs, including symbolic shape functions, allowing their build and execution on GPU devices.

ExtractPrimFuncConstants()

Collects and unifies tir non-scalar constants into the module's attr 'Constants' array.

FP8ComputeLegalize([promote_dtype_str])

Legalize fp8 compute Ops.

FP8StorageLegalize()

Legalize fp8 storage types to u8.

Filter(fcond)

Filter out PrimFuncs that do not satisfy the given condition.

FlattenBuffer()

Flatten multi-dimensional BufferLoad and BufferStore into single-dimensional BufferLoad/BufferStore for TIR that does not contain opaque blocks.

ForceNarrowIndexToInt32()

Force narrow down indexing expressions and integer buffers to int32 dtype.

HoistExpression()

Generalized version of HoistIfThenElse.

HoistIfThenElse([variant])

Hoist loop-invariant IfThenElse nodes to outside the eligible loops.

InferFragment()

Infer the TensorCore fragment information using tensor intrinsics.

InjectCopyIntrin(pragma_key, fintrin)

Inject copy intrinsics with optional padding.

InjectDoubleBuffer()

Inject double buffer statements.

InjectPTXAsyncCopy()

Rewrite global to shared memory copy on CUDA with asynchronous copy.

InjectPermutedLayout()

Inject permuted layout in mma

InjectPrefetch()

Inject prefetch instructions into stmt.

InjectRollingBuffer()

Inject rolling buffer statements.

InjectSoftwarePipeline()

Transform annotated loops into pipelined ones that parallelize producers and consumers

InjectVirtualThread()

Inject virtual thread loops.

InlinePrivateFunctions()

Inline calls to private functions

InstallDebugSpans()

Add line information from the TIR printer as spans on each statement and expression.

InstrumentBoundCheckers()

Instruments bound checkers.

InstrumentProfileIntrinsics()

Insert intrinsic calls to instrument function and loop level profiling.

LegalizePackedCalls()

Legalize packed calls to have their arguments wrapped in TVMValues

LiftAttrScope(attr_key)

Lift common attrs with attr_key to outer scope.

LiftThreadBinding()

Lift the same thread bindings to their LCA loops.

LoopPartition()

Partition loops in the statement.

LowerAutoCopy()

Automatically do memory optimizations for auto copy blocks

LowerCrossThreadReduction()

Lower cross-thread reduction from thread bindings to intrinsic function calls.

LowerCustomDatatypes()

Lower custom datatypes.

LowerDeviceKernelLaunch()

Lower cross-device function calls.

LowerDeviceStorageAccessInfo()

Lower attached storage access information on device.

LowerInitBlock()

Lower block init stmt into IfThenElse statements.

LowerIntrin()

Lower target specific intrinsic calls.

LowerMatchBuffer()

Remove match buffers inside the block.

LowerOpaqueBlock()

Remove the block to ensure that the TIR can not be scheduled again.

LowerTVMBuiltin()

Lower tvm builtin intrinsics.

LowerThreadAllreduce()

Lower cross thread allreduce.

LowerWarpMemory()

Lower warp memory access to low-level device related function calls.

MakePackedAPI()

Transform the PrimFuncs in the module to a packed func API.

MakeUnpackedAPI()

Transform the PrimFuncs in the module to a C API compatible with internal calls.

ManifestSharedMemoryLocalStage()

Add the explicit local stage for the shared memory access on GPU.

MergeSharedMemoryAllocations()

This pass merges multiple TIR-level shared memory allocations into one allocation.

NarrowDataType(target_bits)

Narrow down PrimExpr datatype in stmt to target_bits.

PlanAndUpdateBufferAllocationLocation()

Locate the buffer allocation at the exact position (usually the lca of the buffer accesses).

PointerValueTypeRewrite()

Rewrite the pointer content type of arguments, as well as Alloc internal to the function to use the most frequently accessed type for load/store to avoid pointer casting in backend when possible.

ReduceBranchingThroughOvercompute()

Reduce branching by introducing overcompute

RemoveAssume()

Remove all instances of builtin::assume

RemoveNoOp()

Remove No Op from the Stmt.

RemoveStoreUndef()

Remove stores of undefined values from the Stmt.

RemoveWeightLayoutRewriteBlock([...])

Remove weight layout rewrite block before benchmarking during tuning stage.

RenormalizeSplitPattern()

Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv())

RewriteUnsafeSelect()

Detect and rewrite unsafe select that contains memory access.

Simplify()

Run arithmetic simplifications on the statements and expressions.

SkipAssert()

Skip assert stmt.

SplitHostDevice()

Split the function into a host function and device functions.

StorageFlatten(cache_line_size[, ...])

Flatten the multi-dimensional read/write to 1D.

StorageRewrite()

Rewrite storage allocation pattern.

TextureFlatten()

Flatten the multi-dimensional read/write to 2D.

ThreadSync(storage_scope)

Insert sync between parallel read/write of shared buffers.

TransformMmaBufferLayout()

Transform mma buffer layout

UnifyThreadBinding()

Unify all the thread bindings for "blockIdx.x/y/z", "threadIdx.x/y/z", and "vthread.x/y/z".

UnrollLoop()

Unroll the constant loop marked by unroll.

VectorizeLoop([enable_vectorize])

Lower vectorization loops.

VerifyMemory()

Verify if func contains illegal host side direct memory access.

VerifyVTCMLimit(limit)

Verify if the size of the allocated vtcm memory satisfies the limit.

prim_func_pass([pass_func, opt_level, name, ...])

Decorate a function pass.

class tvm.tir.transform.HoistedConditionals(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[源代码]#

Flags for use in HoistExpressionConfig.conditional_types

Each bitflag represents a type of expression that should be hoisted to the outermost loop possible.

Methods:

_generate_next_value_(start, count, last_values)

Generate the next value when not given.

Attributes:

All

Enable all hoisting of conditionals

BooleanExpression

If set, look for hoist candidates in all boolean expressions

IfElseExpr

If set, look for hoist candidates in tir.if_then_else

IfElseStmt

If set, look for hoist candidates in IfElseStmt

Never

No hoisting of conditionals

UsingBlockVar

If set, allow hoisting of conditionals that use a block variable (e.g. threadIdx.x).

_generate_next_value_(start, count, last_values)#

Generate the next value when not given.

name: the name of the member start: the initial start value or None count: the number of existing members last_values: the last value assigned or None

All = 15#

Enable all hoisting of conditionals

BooleanExpression = 4#

If set, look for hoist candidates in all boolean expressions

IfElseExpr = 2#

If set, look for hoist candidates in tir.if_then_else

IfElseStmt = 1#

If set, look for hoist candidates in IfElseStmt

Never = 0#

No hoisting of conditionals

UsingBlockVar = 8#

If set, allow hoisting of conditionals that use a block variable (e.g. threadIdx.x)

class tvm.tir.transform.HoistedLetBindings(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[源代码]#

Flags for use in HoistExpressionConfig.let_binding_types

Each bitflag represents a type of let binding expression that should be hoisted to the outermost loop possible.

Methods:

_generate_next_value_(start, count, last_values)

Generate the next value when not given.

Attributes:

All

Enable all hoisting of let bindings

LetExpr

Bindings occurring in Let expressions

LetStmt

Bindings occurring in LetStmt

Never

No hoisting of let bindings

RequiredByConditional

Bindings that are used by a hoisted conditional

_generate_next_value_(start, count, last_values)#

Generate the next value when not given.

name: the name of the member start: the initial start value or None count: the number of existing members last_values: the last value assigned or None

All = 7#

Enable all hoisting of let bindings

LetExpr = 4#

Bindings occurring in Let expressions

LetStmt = 2#

Bindings occurring in LetStmt

Never = 0#

No hoisting of let bindings

RequiredByConditional = 1#

Bindings that are used by a hoisted conditional

class tvm.tir.transform.PrimFuncPass[源代码]#

A pass that works on each tvm.tir.PrimFunc() in a module. A function pass class should be created through tvm.tir.transform.prim_func_pass().

tvm.tir.transform.AnnotateDeviceRegions()[源代码]#

Annotate locations that should be run on the device

Insert AttrStmt nodes specifying a target on which regions within the PrimFunc should be executed. Only modifies functions that have a tvm::attr::kTarget attribute, and where that target defines a host.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.AnnotateEntryFunc()[源代码]#

Set a PrimFunc as the entry point if it is only function in IRModule.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.Apply(ftransform)[源代码]#

Apply ftransform to each function in the Module.

This function is a thin wrapper around tvm.tir.transform.prim_func_pass

Parameters#

ftransform: tvm.tir.PrimFunc -> tvm.tir.PrimFunc

The transformation pass.

Returns#

fpasstvm.transform.Pass

The result pass
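A hedged sketch of Apply: attach an attribute to every PrimFunc in a module (the attribute name "custom_tag" is purely illustrative):

import tvm

# Wrap a per-function transformation into a module-level pass.
tag_funcs = tvm.tir.transform.Apply(
    lambda func: func.with_attr("custom_tag", 1)
)
# Applying it to an IRModule `mod` would return a module whose
# PrimFuncs each carry the "custom_tag" attribute:
# mod = tag_funcs(mod)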

tvm.tir.transform.ApplyLayoutTransforms()[源代码]#

Reshape buffers that appear in the “layout_transform_map” function attribute.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.BF16ComputeLegalize()[源代码]#

Legalize bf16 compute Ops.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.BF16StorageLegalize()[源代码]#

Legalize bf16 storage types to u16.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.BindTarget(target)[源代码]#

Annotate a PrimFunc with a given target.

Parameters#

targettvm.target.Target

The target.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.CoProcSync()[源代码]#

Detect and insert sync points to co-processor.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.CombineContextCall()[源代码]#

Combine context calls in the host function.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.CommonSubexprElimTIR(enable_cse_tir=True, identify_equiv_terms=False)[源代码]#

Replace redundant computations by new variables.

Returns#

fpasstvm.transform.Pass

The result pass

Parameters:
  • enable_cse_tir (bool)

  • identify_equiv_terms (bool)

tvm.tir.transform.CompactBufferAllocation(is_strict=True)[源代码]#

Compact the buffer access region by removing the buffer regions that are not accessed, i.e., narrowing the buffer shape and adjusting the access region if necessary.

Example#

Before narrowing, B is a [16, 16] buffer, but only a skinny vector B[i, 0:16] is accessed.

for i in range(0, 16):
    with T.block():
        B = T.alloc_buffer(16, 16)
        for j in range(0, 16):
            B[i, j] = A[i, j] + 1
        for j in range(0, 16):
            C[i, j] = B[i, j] + 1

This pass narrows the buffer shape and adjusts its accessed region accordingly. In this particular case, because only a 1 * 16 vector of B is accessed, the pass narrows B to shape [1, 16] and changes the access B[i, j] to B[0, j].

for i in range(0, 16):
    with T.block():
        B = T.alloc_buffer(1, 16)
        for j in range(0, 16):
            B[0, j] = A[i, j] + 1
        for j in range(0, 16):
            C[i, j] = B[0, j] + 1

Parameters#

is_strictbool

Ensure that the compacted shape is always smaller than the original shape. Otherwise, the shape is allowed to grow to match the actually accessed buffer regions.

Returns#

fpasstvm.transform.Pass

The result pass

Parameters:

is_strict (bool)

tvm.tir.transform.ConvertBlocksToOpaque()[源代码]#

Substitute all the block vars with the PrimExprs they are bound to, indicated by the corresponding iter_values in BlockRealize, and then convert the blocks into opaque ones by removing all the iter_values in BlockRealize and iter_vars in Block.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.ConvertForLoopsToSerial()[源代码]#

Convert Parallel For Loops to Serial For Loops.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.ConvertSSA()[源代码]#

Convert an IRModule to be SSA form.

This pass handles cases where the same tir.Var appears in multiple functions within the same module. For example, after extracting a fragment from one function into another, the same tir.Var may be defined both within the body of the original function and as a parameter of the hoisted function.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.DecorateDeviceScope()[源代码]#

Decorate all the function’s body as device function.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.DefaultGPUSchedule()[源代码]#

The pass sets default thread bindings for PrimFuncs, including symbolic shape functions, allowing them to be built and executed on GPU devices. It examines all the blocks within the PrimFunc and conducts loop fusion, splitting, and reordering operations based on the loop extents and target information, such as the maximum number of thread blocks and the maximum number of threads per block.

The primary objective of this pass is not to optimize performance, but rather to generate a valid GPU kernel for unscheduled or symbolic-shape PrimFuncs. The pass currently only works for CUDA targets.

Returns#

ret: tvm.transform.Pass

tvm.tir.transform.ExtractPrimFuncConstants()[源代码]#

Collects and unifies tir non-scalar constants into the module’s attr ‘Constants’ array.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.FP8ComputeLegalize(promote_dtype_str='float32')[源代码]#

Legalize fp8 compute Ops.

Parameters#

promote_dtypestr

The data type we promote fp8 to, options: float16/float32.

Returns#

fpasstvm.transform.Pass

The result pass

Parameters:

promote_dtype_str (str)

tvm.tir.transform.FP8StorageLegalize()[源代码]#

Legalize fp8 storage types to u8.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.Filter(fcond)[源代码]#

Filter out PrimFuncs that do not satisfy the given condition. fcond should be a function that takes a PrimFunc and returns a boolean.

Returns#

fpasstvm.transform.Pass

The result pass

Parameters:

fcond (Callable)
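A hedged sketch of Filter with a trivial predicate, keeping only PrimFuncs with exactly two parameters (the criterion is purely illustrative):

import tvm

keep_binary = tvm.tir.transform.Filter(
    lambda func: len(func.params) == 2
)
# mod = keep_binary(mod)  # drops every PrimFunc with a different arity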

tvm.tir.transform.FlattenBuffer()[源代码]#

Flatten multi-dimensional BufferLoad and BufferStore into single-dimensional BufferLoad/BufferStore for TIR that does not contain opaque blocks.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.ForceNarrowIndexToInt32()[源代码]#

Force narrow down indexing expressions and integer buffers to int32 dtype.

Returns#

fpasstvm.transform.Pass

The result pass

Note#

This pass should not be used in default cases.

tvm.tir.transform.HoistExpression()[源代码]#

Generalized version of HoistIfThenElse.

Hoist loop-invariant expressions to outside the eligible loops. Searches for expressions in:

  • LetStmt bindings

  • IfThenElse conditions

  • Boolean operators

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.HoistIfThenElse(variant=None)[源代码]#

Hoist loop-invariant IfThenElse nodes to outside the eligible loops.

Parameters#

variantOptional[String]

The variant of the pass. variant can be any one of the following values: [“basic”, None (default)].

The basic variant supports basic hoisting scenarios, where the For and If nodes are expected to appear consecutively and no global-scope variables or more advanced scenarios are involved.

The default variant supports all hoisting scenarios, i.e., “basic” plus “advanced”, controlled with PassContext configs like below:

config={“tir.HoistIfThenElse”: {“support_block_scope_hosting”: True}}

Returns#

fpasstvm.transform.Pass

The result pass

Parameters:

variant (str | None)
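A sketch of enabling the advanced hoisting scenarios through the PassContext config mentioned above (the small TE workload is only there to produce an IRModule to run the pass on):

import tvm
from tvm import te

n = te.var("n")
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1, name="B")
mod = tvm.lower(te.create_schedule(B.op), [A, B])  # a small IRModule

with tvm.transform.PassContext(
    config={"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
):
    mod = tvm.tir.transform.HoistIfThenElse()(mod)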

tvm.tir.transform.InferFragment()[源代码]#

Infer the TensorCore fragment information using tensor intrinsics.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.InjectCopyIntrin(pragma_key, fintrin)[源代码]#

Inject copy intrinsics with optional padding.

Parameters#

pragma_keystr

The pragma key for hint of copy.

fintrinfunction

The function with signature copyintrin(src, dst, pad_before, pad_after, pad_value)

Returns#

fpasstvm.transform.Pass

The result pass

Parameters:

pragma_key (str)

tvm.tir.transform.InjectDoubleBuffer()[源代码]#

Inject double buffer statements.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.InjectPTXAsyncCopy()[源代码]#

Rewrite global to shared memory copy on CUDA with asynchronous copy.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.InjectPermutedLayout()[源代码]#

Inject permuted layout in mma

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.InjectPrefetch()[源代码]#

Inject prefetch instructions into stmt.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.InjectRollingBuffer()[源代码]#

Inject rolling buffer statements.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.InjectSoftwarePipeline()[源代码]#

Transform annotated loops into pipelined ones that parallelize producers and consumers

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.InjectVirtualThread()[源代码]#

Inject virtual thread loops.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.InlinePrivateFunctions()[源代码]#

Inline calls to private functions

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.InstallDebugSpans()[源代码]#

Add line information from the TIR printer as spans on each statement and expression.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.InstrumentBoundCheckers()[源代码]#

Instruments bound checkers.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.InstrumentProfileIntrinsics()[源代码]#

Insert intrinsic calls to instrument function and loop level profiling.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.LegalizePackedCalls()[源代码]#

Legalize packed calls to have their arguments wrapped in TVMValues

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.LiftAttrScope(attr_key)[源代码]#

Lift common attrs with attr_key to outer scope.

Parameters#

attr_keystr

The attribute key to be checked.

Returns#

fpasstvm.transform.Pass

The result pass

Parameters:

attr_key (str)

tvm.tir.transform.LiftThreadBinding()[源代码]#

Lift the same thread bindings to their LCA loops.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.LoopPartition()[源代码]#

Partition loops in the statement.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.LowerAutoCopy()[源代码]#

Automatically do memory optimizations for auto copy blocks

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.LowerCrossThreadReduction()[源代码]#

Lower cross-thread reduction from thread bindings to intrinsic function calls.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.LowerCustomDatatypes()[源代码]#

Lower custom datatypes.

See tvm::datatypes::Registry for more information on adding custom datatypes.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.LowerDeviceKernelLaunch()[源代码]#

Lower cross-device function calls.

Prior to this pass, host to device calls are represented as subroutine calls, with environment parameters (e.g. env_thread) specified internally. The device function is an internal function, without a tvm::attr::kGlobalSymbol attribute.

After this pass, host to device calls are represented as tvm_call_packed built-in. The device function is an externally-exposed function, with a non-empty tvm::attr::kGlobalSymbol attribute.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.LowerDeviceStorageAccessInfo()[源代码]#

Lower attached storage access information on device.

Returns#

fpasstvm.transform.Pass

The result pass

Note#

Run this pass after all storage access analysis finish.

tvm.tir.transform.LowerInitBlock()[源代码]#

Lower block init stmt into IfThenElse statements.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.LowerIntrin()[源代码]#

Lower target specific intrinsic calls.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.LowerMatchBuffer()[源代码]#

Remove match buffers inside the block. Also, it will validate the binding.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.LowerOpaqueBlock()[源代码]#

Remove the block to ensure that the TIR can not be scheduled again.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.LowerTVMBuiltin()[源代码]#

Lower tvm builtin intrinsics.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.LowerThreadAllreduce()[源代码]#

Lower cross thread allreduce.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.LowerWarpMemory()[源代码]#

Lower warp memory access to low-level device related function calls.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.MakePackedAPI()[源代码]#

Transform the PrimFuncs in the module to a packed func API.

Prior to this pass, the PrimFunc may have Buffer arguments defined in the PrimFuncNode::buffer_map. This pass consumes the buffer_map, using it to generate TVMArgs and TVMRetValue* arguments that implement the PackedFunc API.

For static shapes, the BufferNode::shape, BufferNode::strides, and BufferNode::elem_offset member variables are used to generate runtime checks on the corresponding member variables in the user-provided DLTensor* or tvm.nd.array argument. (e.g. A PrimFunc that accepts a buffer of shape [16,32] validates that the DLTensor::shape array is [16,32].)

For dynamic Buffers, in which one or more of these BufferNode member variables use tir.Var that are not defined by other PrimFunc parameters, these are instead used to define the variables based on the corresponding DLTensor members. (e.g. A PrimFunc that accepts a buffer of shape [tir.Var(“n”), tir.Var(“m”)], when passed a DLTensor of shape [16,32], will define n = 16 and m = 32, based on the argument’s shape.)

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.MakeUnpackedAPI()[源代码]#

Transform the PrimFuncs in the module to a C API compatible with internal calls.

Prior to this pass, the PrimFunc may have Buffer arguments defined in the PrimFuncNode::buffer_map. This pass consumes the buffer_map, using it to generate T* arguments (e.g. float32*) that can be directly called by a C API.

For static shapes, no runtime validation is performed to confirm that the argument buffer’s shape matches the expected shape. For dynamic shapes, MakeUnpackedAPI requires that the dynamic parameters be passed as separate tir.Var parameters.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.ManifestSharedMemoryLocalStage()[源代码]#

Add the explicit local stage for the shared memory access on GPU.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.MergeSharedMemoryAllocations()[源代码]#

This pass merges multiple TIR-level shared memory allocations into one allocation.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.NarrowDataType(target_bits)[源代码]#

Narrow down PrimExpr datatype in stmt to target_bits.

Parameters#

target_bitsint

The target bit configuration.

Returns#

fpasstvm.transform.Pass

The result pass

Note#

Run this pass after StorageFlatten.

Parameters:

target_bits (int)

tvm.tir.transform.PlanAndUpdateBufferAllocationLocation()[源代码]#

Locate the buffer allocation at the exact position (usually the lca of the buffer accesses). This pass will inject opaque blocks with alloc_buffers at the allocation site.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.PointerValueTypeRewrite()[源代码]#

Rewrite the pointer content type of arguments, as well as Alloc internal to the function to use the most frequently accessed type for load/store to avoid pointer casting in backend when possible.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.ReduceBranchingThroughOvercompute()[源代码]#

Reduce branching by introducing overcompute

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.RemoveAssume()[源代码]#

Remove all instances of builtin::assume

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.RemoveNoOp()[源代码]#

Remove No Op from the Stmt.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.RemoveStoreUndef()[源代码]#

Remove stores of undefined values from the Stmt.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.RemoveWeightLayoutRewriteBlock(skip_ndarray_rewrite=False)[源代码]#

Remove weight layout rewrite block before benchmarking during tuning stage.

Parameters#

skip_ndarray_rewritebool

If True, exact rewrite of NDArray, according to the given index map, will be skipped. Only the shape of the NDArray is transformed correctly, and the content of the destination array will be filled with random values.

When this pass is called many times during MetaSchedule tuning, the raw data of NDArray, before and after rewrite, does not matter. Since NDArray layout rewrite, using IndexMap’s MapNDArray, is currently slow, skipping the exact rewrite is sometimes necessary.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.RenormalizeSplitPattern()[源代码]#

Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv())

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.RewriteUnsafeSelect()[源代码]#

Detect and rewrite unsafe select that contains memory access.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.Simplify()[源代码]#

Run arithmetic simplifications on the statements and expressions.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.SkipAssert()[源代码]#

Skip assert stmt.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.SplitHostDevice()[源代码]#

Split the function into a host function and device functions.

Returns#

fpasstvm.transform.Pass

The result pass

tvm.tir.transform.StorageFlatten(cache_line_size, create_bound_attribute=False)[源代码]#

Flatten the multi-dimensional read/write to 1D.

Parameters#

cache_line_size: int

The size of CPU cache line.

create_bound_attribute:

Whether to create bound attributes.

Returns#

fpasstvm.transform.Pass

The result pass

Parameters:

create_bound_attribute (bool)

tvm.tir.transform.StorageRewrite()[源代码]#

Rewrite storage allocation pattern.

Moves the allocation to the outermost possible scope and tries to share space between allocations to make a static allocation plan when possible.

Returns#

fpass : tvm.transform.Pass

The result pass

tvm.tir.transform.TextureFlatten()[源代码]#

Flatten the multi-dimensional read/write to 2D.

Returns#

fpass : tvm.transform.Pass

The result pass

tvm.tir.transform.ThreadSync(storage_scope)[源代码]#

Insert sync between parallel read/write of shared buffers.

Parameters#

storage_scope: str

The target storage scope.

Returns#

fpass : tvm.transform.Pass

The result pass

参数:

storage_scope (str)
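
This pass is normally part of the GPU lowering pipeline; a minimal sketch of constructing it (mod is a placeholder for an already lowered GPU IRModule):

import tvm

sync = tvm.tir.transform.ThreadSync("shared")
# synced = sync(mod)  # inserts barriers around parallel reads/writes of "shared" buffers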

tvm.tir.transform.TransformMmaBufferLayout()[源代码]#

Transform mma buffer layout

Returns#

fpass : tvm.transform.Pass

The result pass

tvm.tir.transform.UnifyThreadBinding()[源代码]#

Unify all the thread bindings for “blockIdx.x/y/z”, “threadIdx.x/y/z”, and “vthread.x/y/z”. Before the unification, two vars that are bound to a thread axis (e.g., “threadIdx.x”) use different IterVars and variables in their AttrStmts. After the unification, we use a consolidated IterVar and a variable for them.

Returns#

fpass : tvm.transform.Pass

The result pass

Note#

vthread is a legacy behavior that will be deprecated, though thread bindings of vthread are still unified in this pass. Please use vthread.x, vthread.y and vthread.z instead.

tvm.tir.transform.UnrollLoop()[源代码]#

Unroll the constant loop marked by unroll.

This pass also automatically attaches a pragma unroll tag to loops that meet the criteria.

Returns#

fpass : tvm.transform.Pass

The result pass

tvm.tir.transform.VectorizeLoop(enable_vectorize=True)[源代码]#

Lower vectorization loops.

Parameters#

enable_vectorize : bool

Whether vectorization is enabled. Will lower to scalar loop when it is turned off.

Returns#

fpass : tvm.transform.Pass

The result pass

参数:

enable_vectorize (bool)
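
A minimal sketch, assuming a recent TVM build with TVMScript available:

import tvm
from tvm.script import tir as T

@tvm.script.ir_module
class Mod:
    @T.prim_func
    def main(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32")):
        for i in T.vectorized(4):
            B[i] = A[i] + T.float32(1)

vectorized = tvm.tir.transform.VectorizeLoop()(Mod)       # lowers the loop to vector operations
scalarized = tvm.tir.transform.VectorizeLoop(False)(Mod)  # falls back to a scalar loop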

tvm.tir.transform.VerifyMemory()[源代码]#

Verify if func contains illegal host side direct memory access.

Returns#

fpass : tvm.transform.Pass

The result pass

tvm.tir.transform.VerifyVTCMLimit(limit)[源代码]#

Verify if the size of the allocated vtcm memory satisfies the limit.

Returns#

fpass : tvm.transform.Pass

The result pass

参数:

limit (int)

tvm.tir.transform.prim_func_pass(pass_func=None, opt_level=None, name=None, required=None, traceable=False)[源代码]#

Decorate a function pass.

This function returns a callback when pass_func is not provided. Otherwise, it returns the created function pass using the given optimization function.

Parameters#

pass_func : Optional[Callable[(tvm.tir.PrimFunc, IRModule, PassContext) -> tvm.tir.PrimFunc]]

The transformation function or class.

opt_level : int

The optimization level of this function pass.

name : Optional[str]

The name of the function pass. The name could be empty. In this case, the name of the optimization function will be used as the pass name.

required : Optional[List[str]]

The list of passes that the function pass is dependent on.

Returns#

create_function_pass : Union[Callable, PrimFuncPass]

A decorator is returned if pass_func is not provided; otherwise the decorated result is returned. The returned decorator has two behaviors depending on the input: a new PrimFuncPass is returned when we decorate a pass function, and a new PrimFuncPass class is returned when we decorate a class type.

Examples#

The following code block decorates a function pass class.

@tvm.tir.transform.prim_func_pass(opt_level=1)
class TestReplaceFunc:
    def __init__(self, new_func):
        self.new_func = new_func

    def transform_function(self, func, mod, ctx):
        # just for demo purposes
        # transform func to new_func
        return self.new_func

The following code creates a function pass by decorating a user defined transform function.

@tvm.tir.transform.prim_func_pass(opt_level=2)
def transform(func, mod, ctx):
    # my transformations here.
    return func

function_pass = transform
assert isinstance(function_pass, tvm.tir.transform.PrimFuncPass)
assert function_pass.info.opt_level == 2

# Given a module m, the optimization could be invoked as the following:
updated_mod = function_pass(m)
# Now the pass has been applied to every PrimFunc in the provided
# module m, and the updated module is returned.
参数:
返回类型:

Callable | PrimFuncPass

tvm.tir.analysis#

Namespace of all TIR analysis utils.

Classes:

Block(iter_vars, reads, writes, name_hint, body)

Block node.

Buffer()

Symbolic data buffer in TVM.

BufferRegion(buffer, region)

BufferRegion node.

IRModule([functions, type_definitions, ...])

IRModule that holds functions and type definitions.

Object()

Base class for all tvm's runtime objects.

PrimExpr()

Base class of all primitive expressions.

PrimFunc(params, body[, ret_type, ...])

A function declaration expression.

Stmt()

Base class of all the statements.

Var(name, dtype[, span])

Symbolic variable.

Functions:

OOBChecker()

Detect out of bounds memory access in arrays.

apply_prim_func_arg_and_result_memory_constraints(...)

Returns func written to capture the memory (aka storage) scope constraints for each of the func's parameters given by arg_and_result_memory_scopes.

assert_pure_function(func)

Asserts that the function is a pure function

calculate_allocated_bytes(func_or_mod)

Calculate allocated memory per memory scope required by TIR PrimFuncs.

calculate_constant_bytes(func, ...)

Calculate the constant size in bytes needed by the TIR allocates inside the TIR PrimFunc.

calculate_workspace_bytes(func, ...)

Calculate the workspace size in bytes needed by the TIR allocates inside the TIR PrimFunc.

detect_buffer_access_lca(func)

Detect the lowest common ancestor(LCA) of buffer access, including both high-level access (BufferLoad, BufferStore) and low-level access (BufferLoad, BufferStore and opaque access).

estimate_tir_flops(stmt_or_mod)

Estimate the FLOPs of a TIR fragment.

expr_deep_equal(lhs, rhs)

Deeply compare two nested expressions.

find_anchor_block(mod)

Find the "anchor block" of the given module.

get_block_access_region(block, buffer_var_map)

Detect which regions of tensors in this block are read or written to.

get_block_read_write_region(block, ...)

Auto detect the block read/write region according to its body stmt.

get_prim_func_arg_and_result_memory_constraints(...)

Returns the memory (aka storage) scope constraints for all the arguments and result of func.

get_vtcm_compaction_passes()

Utility function to get the list of lowering passes to be applied to calculate the compacted VTCM allocation size

is_pure_function(func)

Checks if the function is a pure function

undefined_vars(node[, defs])

Find undefined vars in a TIR statement or expression.

verify_gpu_code(func, constraints)

Verify if the GPU code in the module satisfies the given constraints.

verify_memory(func)

Verify if func contains illegal host side direct memory access.

verify_ssa(func)

Verify if the func is in SSA form.

verify_well_formed(obj[, assert_mode])

Verify if the given TIR is well-formed. The verification includes:

class tvm.tir.analysis.Block(iter_vars, reads, writes, name_hint, body, init=None, alloc_buffers=None, match_buffers=None, annotations=None, span=None)[源代码]

Block node.

Parameters#

iter_vars : List[IterVar]

The block iteration variables.

reads : List[BufferRegion]

The read buffer regions of the block.

writes: List[BufferRegion]

The write buffer regions of the block.

name_hint: str

The name hint of the block.

body: Stmt

The body of the block.

init: Optional[Stmt]

The init block of the reduction block

alloc_buffers: Optional[list[Buffer]]

The buffer allocations

match_buffers: Optional[List[MatchBufferRegion]]

The subregion buffer match

annotations: Optional[Mapping[str, Object]]

Additional annotation hints.

span : Optional[Span]

The location of this block in the source code.

参数:
class tvm.tir.analysis.Buffer[源代码]

Symbolic data buffer in TVM.

Buffer provides a way to represent the data layout specialization of a data structure in TVM.

Do not construct directly, use decl_buffer() instead. See the documentation of decl_buffer() for more details.

See Also#

decl_buffer : Declare a buffer

Methods:

access_ptr(access_mask[, ptr_type, ...])

Get an access pointer to the head of buffer.

get_flattened_buffer()

Generate a Buffer that is a flattened version of this buffer.

offset_of(indices)

Determine the offset of the provided indices in the flattened buffer.

scope()

Return the storage scope associated with this buffer.

vload(begin[, dtype])

Generate an Expr that loads dtype from begin index.

vstore(begin, value)

Generate a Stmt that stores value at begin index.

access_ptr(access_mask, ptr_type='handle', content_lanes=1, offset=0, extent=None)[源代码]

Get an access pointer to the head of buffer.

This is the recommended method to get the buffer data address when interacting with external functions.

Parameters#

access_mask : int

The access pattern mask, indicating whether the access will read or write the data content.

ptr_type : str, optional

The data type of the result pointer. Do not specify unless we want to cast pointer to specific type.

content_lanes: int, optional

The number of lanes for the data type. This value is greater than one for vector types.

offset: Expr, optional

The offset of pointer. We can use it to offset by the number of elements from the address of ptr.

extent: Expr, optional

The extent of pointer.

Examples#

# Get access ptr for read
buffer.access_ptr("r")
# Get access ptr for read/write with bitmask
buffer.access_ptr(Buffer.READ | Buffer.WRITE)
# Get access ptr for read/write with str flag
buffer.access_ptr("rw")
# Get access ptr for read with offset
buffer.access_ptr("r", offset = 100)
# Get access ptr for read with extent
buffer.access_ptr("r", extent = 100)
get_flattened_buffer()[源代码]

Generate a Buffer that is a flattened version of this buffer.

Returns#

flattened : Buffer

The corresponding flat buffer.

offset_of(indices)[源代码]

Determine the offset of the provided indices in the flattened buffer.

Parameters#

indices : Union[PrimExpr, List[PrimExpr]]

The indices of the element in the original buffer.

Returns#

flattened_indices: List[PrimExpr]

The offset indices of the element in the flattened buffer.
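
A minimal sketch using a buffer created with decl_buffer:

import tvm

# a compact, row-major 4 x 8 buffer
buf = tvm.tir.decl_buffer((4, 8), "float32", name="A")
offset = buf.offset_of([1, 2])
# `offset` is a one-element list whose entry corresponds to 1 * 8 + 2 = 26
# in this compact row-major layout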

scope()[源代码]

Return the storage scope associated with this buffer.

Returns#

scope : str

The storage scope associated with this buffer.

vload(begin, dtype=None)[源代码]

Generate an Expr that loads dtype from begin index.

Parameters#

begin : Array of Expr

The beginning index in unit of Buffer.dtype

dtype : str

The data type to be loaded; it can be a vector type whose lanes are a multiple of the lanes of Buffer.dtype.

Returns#

load : Expr

The corresponding load expression.

vstore(begin, value)[源代码]

Generate a Stmt that stores value at begin index.

Parameters#

begin : Array of Expr

The beginning index in unit of Buffer.dtype

value : Expr

The value to be stored.

Returns#

store : Stmt

The corresponding store stmt.
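
A minimal sketch of vload and vstore on a buffer created with decl_buffer:

import tvm

A = tvm.tir.decl_buffer((16,), "float32", name="A")
load = A.vload([0])                                      # a BufferLoad of A[0]
store = A.vstore([0], tvm.tir.FloatImm("float32", 1.0))  # a BufferStore: A[0] = 1.0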

class tvm.tir.analysis.BufferRegion(buffer, region)[源代码]

BufferRegion node.

Parameters#

buffer : Buffer

The buffer of the buffer region

region : List[Range]

The region array of the buffer region

参数:
class tvm.tir.analysis.IRModule(functions=None, type_definitions=None, attrs=None, global_infos=None)[源代码]

IRModule that holds functions and type definitions.

IRModule is the basic unit for all IR transformations across the stack.

Parameters#

functions: Optional[dict].

Map of global var to BaseFunc

Methods:

__getitem__(var)

Lookup a global definition by name or by variable.

__setitem__(var, val)

Add a mapping to the module.

astext([show_meta_data, annotate])

Get the text format of the expression.

from_expr(expr[, functions, type_defs])

Construct a module from a standalone expression.

functions_items()

Get items in self.functions.items() in alphabetical order.

get_attr(attr_key)

Get the IRModule attribute.

get_constructor(tag)

Look up an ADT constructor by tag.

get_global_type_var(name)

Get a global type variable in the function by name.

get_global_type_vars()

Collect all global type vars defined in this module.

get_global_var(name)

Get a global variable in the function by name.

get_global_vars()

Collect all global vars defined in this module.

update(other)

Insert functions from another Module into the current one.

update_func(var, func)

Update the function corresponding to a global variable in the module.

update_global_info(name, global_info)

Update global info in the module

with_attr(attr_key, attr_value)

Copy the IRModule and add an attribute to it.

with_attrs(attr_map)

Copy the IRModule and add the given attribute map to it.

without_attr(attr_key)

Copy the IRModule and remove an attribute key and its associated value.

__getitem__(var)[源代码]

Lookup a global definition by name or by variable.

Parameters#

var: Union[String, GlobalVar, GlobalTypeVar]

The name or global variable.

Returns#

val: Union[Function, Type]

The definition referenced by var (either a function or type).

__setitem__(var, val)[源代码]

Add a mapping to the module.

Parameters#

var: GlobalVar

The global variable.

val: Union[Function, Type]

The value.

astext(show_meta_data=True, annotate=None)[源代码]

Get the text format of the expression.

Parameters#

show_meta_data : bool

Whether to include meta data section in the text if there is meta data.

annotate: Optional[Object->str]

Optionally annotate function to provide additional information in the comment block.

Returns#

text : str

The text format of the expression.

Notes#

The meta data section is necessary to fully parse the text format. However, it can contain dumps that are big (e.g. constant weights), so it can be helpful to skip printing the meta data section.

static from_expr(expr, functions=None, type_defs=None)[源代码]

Construct a module from a standalone expression.

Parameters#

expr: RelayExpr

The starting expression

functions: Optional[dict]

Map of global vars to function definitions

type_defs: Optional[dict]

Map of global type vars to type definitions

Returns#

mod: Module

A module containing the passed definitions, where expr is set as the entry point (wrapped in a function if necessary)

functions_items()[源代码]

Get items in self.functions.items() in alphabetical order.

Returns#

items: List[Tuple[GlobalVar, Function]]

The functions items.

get_attr(attr_key)[源代码]

Get the IRModule attribute.

Parameters#

attr_key : str

The attribute key.

Returns#

attr_value : Any

Attribute value

get_constructor(tag)[源代码]

Look up an ADT constructor by tag.

Parameters#

tag: int

The tag for a constructor.

Returns#

constructor: Constructor

The constructor associated with the given tag.

Raises#

tvm.error.TVMError if the corresponding constructor cannot be found.

get_global_type_var(name)[源代码]

Get a global type variable in the function by name.

Parameters#

name: str

The name of the global type variable.

Returns#

global_type_var: GlobalTypeVar

The global variable mapped to name.

Raises#

tvm.error.TVMError if we cannot find corresponding global type var.

get_global_type_vars()[源代码]

Collect all global type vars defined in this module.

Returns#

global_type_vars: Array[GlobalTypeVar]

An array of global type vars.

get_global_var(name)[源代码]

Get a global variable in the function by name.

Parameters#

name: str

The name of the global variable.

Returns#

global_var: GlobalVar

The global variable mapped to name.

Raises#

tvm.error.TVMError if we cannot find corresponding global var.
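
A minimal sketch, assuming a recent TVM build with TVMScript available (the function name identity is arbitrary):

import tvm
from tvm.script import tir as T

@T.prim_func
def identity(A: T.Buffer((8,), "float32"), B: T.Buffer((8,), "float32")):
    for i in range(8):
        B[i] = A[i]

mod = tvm.IRModule({"identity": identity})
gv = mod.get_global_var("identity")
# mod[gv] and mod["identity"] look up the same PrimFunc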

get_global_vars()[源代码]

Collect all global vars defined in this module.

Returns#

global_vars: Array[GlobalVar]

An array of global vars.

update(other)[源代码]

Insert functions from another Module into the current one.

Parameters#

other: IRModule

The module to merge into the current Module.

update_func(var, func)[源代码]

Update the function corresponding to a global variable in the module.

Parameters#

var: GlobalVar

The global variable.

func: tvm.relay.Function

The function to be inserted.

update_global_info(name, global_info)[源代码]

Update global info in the module

Parameters#

name: str

The name for the global info.

global_info: List[GlobalInfo]

The global info to be updated.

with_attr(attr_key, attr_value)[源代码]

Copy the IRModule and add an attribute to it.

Parameters#

attr_key : str

The attribute key.

attr_value : Object

The new attribute value.

Returns#

mod : IRModule

A new copy of the IRModule with the attribute

with_attrs(attr_map)[源代码]

Copy the IRModule and add the given attribute map to it.

Parameters#

attr_map : Union[DictAttrs, Dict[str, Object]]

The attribute map

Returns#

mod : IRModule

A new copy of the IRModule with the attribute

参数:

attr_map (DictAttrs | Dict[str, Object])

返回类型:

IRModule

without_attr(attr_key)[源代码]

Copy the IRModule and remove an attribute key and its associated value.

Parameters#

attr_key : str

The attribute key.

Returns#

mod : IRModule

A new copy of the IRModule without the attribute

参数:

attr_key (str)

返回类型:

IRModule

class tvm.tir.analysis.Object[源代码]

Base class for all tvm’s runtime objects.

Methods:

_move()

Create an RValue reference to the object and mark the object as moved.

_move()[源代码]

Create an RValue reference to the object and mark the object as moved.

This is an advanced developer API that can be useful when passing a unique reference to an Object that you no longer need into a function.

A unique reference can trigger a copy-on-write optimization that avoids copying when we transform an object.

Note#

All references to the object become invalid after it is moved. Be very careful when using this feature.

Examples#

x = tvm.tir.Var("x", "int32")
x0 = x
some_packed_func(x._move())
# both x0 and x will point to None after the function call.

Returns#

rvalue : The rvalue reference.

class tvm.tir.analysis.PrimExpr[源代码]

Base class of all primitive expressions.

PrimExpr is used in the low-level code optimizations and integer analysis.

class tvm.tir.analysis.PrimFunc(params, body, ret_type=None, buffer_map=None, attrs=None, span=None)[源代码]

A function declaration expression.

Parameters#

params: List[Union[tvm.tir.Var, tvm.tir.Buffer]]

List of input parameters to the function.

body: tvm.tir.Stmt

The body of the function.

ret_type: tvm.ir.Type

The return type annotation of the function.

buffer_map : Map[tvm.tir.Var, tvm.tir.Buffer]

The buffer binding map.

attrs: Optional[tvm.Attrs]

Attributes of the function, can be None

span : Optional[Span]

The location of this function in the source code.

Methods:

specialize(param_map)

Specialize parameters of PrimFunc

with_body(new_body[, span])

Create a new PrimFunc with the same signature but a new body.

specialize(param_map)[源代码]

Specialize parameters of PrimFunc

Parameters#

param_map : Mapping[Var, Union[PrimExpr, Buffer]]

The mapping from function params to the instance

Examples#

We can define a Meta TIR function with symbolic shape:

@T.prim_func
def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32) -> None:
    A = T.match_buffer(a, (m, n), "float32")
    B = T.match_buffer(b, (m, n), "float32")

    for i, j in T.grid(m, n):
        with T.block():
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj]

Then we can make it specialized with given shapes or buffers.

a, _, m, n = mem_copy.params
func = mem_copy.specialize({a: tir.decl_buffer((16, 16))})
# or
func = mem_copy.specialize({n: 16, m: 16})

The specialized function:

@T.prim_func
def mem_copy_16_16(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float32")
    B = T.match_buffer(b, (16, 16), "float32")

    for i, j in T.grid(16, 16):
        with T.block():
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj]

Returns#

func : PrimFunc

The new function with parameters specialized.

参数:

param_map (Mapping[Var, PrimExpr | Buffer])

with_body(new_body, span=None)[源代码]

Create a new PrimFunc with the same signature but a new body.

Parameters#

new_body : Stmt

The new body.

span : Optional[Span]

The location of this function in the source code.

Returns#

new_func : PrimFunc

The created new function.

参数:

span (Span | None)

class tvm.tir.analysis.Stmt[源代码]

Base class of all the statements.

class tvm.tir.analysis.Var(name, dtype, span=None)[源代码]

Symbolic variable.

Parameters#

name : str

The name

dtype : Union[str, ir.Type]

The data type

span : Optional[Span]

The location of this expression in the source code.

参数:
tvm.tir.analysis.OOBChecker()[源代码]

Detect out of bounds memory access in arrays.

Returns#

fpass : tvm.transform.Pass

The result pass

tvm.tir.analysis.apply_prim_func_arg_and_result_memory_constraints(func, relay_func_type, arg_and_result_memory_scopes)[源代码]

Returns func written to capture the memory (aka storage) scope constraints for each of the func’s parameters given by arg_and_result_memory_scopes. However, arg_and_result_memory_scopes should be w.r.t. the func’s representation as a Relay Function of relay_func_type before lowering and conversion to DPS.

Visible for testing.

CAUTION: This is experimental. The resulting PrimFunc may not have fully accounted for all new memory scopes.

Parameters#

func: tvm.tir.PrimFunc

The function to retrieve constraints from.

relay_func_type: tvm.relay.FuncType

The type of the Relay Function from which the func was derived.

arg_and_result_memory_scopes: Array[AnyStr]

Memory constraints for funcs args and result in Relay form. The empty string denotes ‘no constraint’.

Returns#

result: tvm.tir.PrimFunc

The rewritten func.

参数:
返回类型:

PrimFunc

tvm.tir.analysis.assert_pure_function(func)[源代码]

Asserts that the function is a pure function

参数:

func (PrimFunc)

返回类型:

bool

tvm.tir.analysis.calculate_allocated_bytes(func_or_mod)[源代码]

Calculate allocated memory per memory scope required by TIR PrimFuncs.

Parameters#

func_or_mod: Union[PrimFunc, IRModule]

The function or module to be detected. If a module is passed, allocated memory is calculated for all PrimFuncs inside the module

Returns#

result : Union[Dict[str, int], Dict[str, Dict[str, int]]]

Allocated memory size per scope in bytes for each function in the IRModule returned as a dict with function names as keys and a dict of allocated sizes as values. If a single PrimFunc is passed, the function name is returned as “main”

参数:

func_or_mod (PrimFunc | IRModule)

返回类型:

Dict[str, int] | Dict[str, Dict[str, int]]

tvm.tir.analysis.calculate_constant_bytes(func, constant_byte_alignment)[源代码]

Calculate the constant size in bytes needed by the TIR allocates inside the TIR PrimFunc.

Parameters#

func: tvm.tir.PrimFunc

The function to be detected.

constant_byte_alignment : int

The byte alignment required for each tensor

Returns#

result : int

Constant size in bytes.

参数:
返回类型:

int

tvm.tir.analysis.calculate_workspace_bytes(func, workspace_byte_alignment)[源代码]

Calculate the workspace size in bytes needed by the TIR allocates inside the TIR PrimFunc.

Parameters#

func: tvm.tir.PrimFunc

The function to be detected.

workspace_byte_alignment : int

The byte alignment required for each tensor

Returns#

result : int

Workspace size in bytes.

参数:
返回类型:

int

tvm.tir.analysis.detect_buffer_access_lca(func)[源代码]

Detect the lowest common ancestor(LCA) of buffer access, including both high-level access (BufferLoad, BufferStore) and low-level access (BufferLoad, BufferStore and opaque access). The LCA may be a For loop or a Block.

Parameters#

func: tvm.tir.PrimFunc

The function to be detected.

Returns#

result : Dict[Buffer, Stmt]

Map from buffer to the LCA of all access to it.

参数:

func (PrimFunc)

返回类型:

Dict[Buffer, Stmt]

tvm.tir.analysis.estimate_tir_flops(stmt_or_mod)[源代码]

Estimate the FLOPs of a TIR fragment.

Parameters#

stmt_or_mod: Union[Stmt, IRModule]

The TIR fragment or IRModule to be estimated.

Returns#

flops: float

The estimated FLOPs.

参数:

stmt_or_mod (Stmt | IRModule)

返回类型:

float
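
A minimal sketch on a small TVMScript matmul, assuming a recent TVM build with TVMScript available:

import tvm
from tvm.script import tir as T

@tvm.script.ir_module
class MatMul:
    @T.prim_func
    def main(A: T.Buffer((16, 16), "float32"),
             B: T.Buffer((16, 16), "float32"),
             C: T.Buffer((16, 16), "float32")):
        for i, j, k in T.grid(16, 16, 16):
            with T.block("C"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = T.float32(0)
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

flops = tvm.tir.analysis.estimate_tir_flops(MatMul)
# roughly 2 * 16 ** 3 = 8192 for the multiply-accumulate part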

tvm.tir.analysis.expr_deep_equal(lhs, rhs)[源代码]

Deeply compare two nested expressions.

Parameters#

lhs : PrimExpr

The left operand.

rhs : PrimExpr

The right operand.

Returns#

result : bool

The comparison result

Note#

This function does not remap variable bindings; it will not return true for (let x = 1 in x + 1) vs (let y = 1 in y + 1), unless x.same_as(y). Use tvm.ir.structural_equal to handle structural variable remapping.

Due to the restriction of not remapping variables, this function can run faster than StructuralEqual and can be used as a utility function during arithmetic simplifications.

Always consider tvm.ir.structural_equal first, which handles structural remapping.

See Also#

tvm.ir.structural_equal

参数:
返回类型:

bool
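
A minimal sketch illustrating the note on variable remapping:

import tvm

x = tvm.tir.Var("x", "int32")
y = tvm.tir.Var("y", "int32")
assert tvm.tir.analysis.expr_deep_equal(x + 1, x + 1)      # same structure, same vars
assert not tvm.tir.analysis.expr_deep_equal(x + 1, y + 1)  # variables are not remapped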

tvm.tir.analysis.find_anchor_block(mod)[源代码]

Find the “anchor block” of the given module.

We define the anchor block to be the block with (1) an init statement and (2) having the biggest flops count. The latter condition is only used when there are multiple blocks with an init statement.

For example, if the input module is conv2d + fused spatial blocks, conv2d is the anchor block. The input module may not contain more than one such block. For example, a module having two conv2d is not allowed as an input.

However, a module created from winograd convolution has multiple blocks with an init statement (input transform, batched GEMM, and output transform). We use the second condition, the flops count, to determine that the batched GEMM block is the anchor block.

Parameters#

mod: tvm.ir.IRModule

The input TIR module.

Returns#

anchor_block: Block

The anchor block if found, None otherwise.

参数:

mod (IRModule)

返回类型:

Block

tvm.tir.analysis.get_block_access_region(block, buffer_var_map)[源代码]
Detect which regions of tensors in this block are read or written to.

Regions are sorted by order of appearance in the AST.

Parameters#

block: tvm.tir.Block

The block in which we are detecting read/write regions.

buffer_var_map : Dict[Var, Buffer]

The outside buffers which may access the block. Mapping from buffer var to the buffer

Returns#

result : List[List[BufferRegion]]

Array of access regions. There are three arrays of BufferRegion:
  • first: read regions

  • second: write regions

  • third: opaque regions

参数:
返回类型:

List[List[BufferRegion]]

tvm.tir.analysis.get_block_read_write_region(block, buffer_var_map)[源代码]
Auto detect the block read/write region according to its body stmt.

An opaque access will be counted as both a read and a write access

Parameters#

block: tvm.tir.Block

The block in which we are detecting read/write regions.

buffer_var_map : Dict[Var, Buffer]

The outside buffers which may access the block. Mapping from buffer var to the buffer

Returns#

result : List[List[BufferRegion]]

An array only consisting of the read regions and write regions of the input block

参数:
返回类型:

List[List[BufferRegion]]

tvm.tir.analysis.get_prim_func_arg_and_result_memory_constraints(func, relay_func_type)[源代码]

Returns the memory (aka storage) scope constraints for all the arguments and result of func. However the result will be w.r.t. the func’s representation as a Relay Function of relay_func_type before lowering and conversion to DPS.

Visible for testing.

Parameters#

func: tvm.tir.PrimFunc

The function to retrieve constraints from.

relay_func_type: tvm.relay.FuncType

The type of the Relay Function from which the func was derived.

Returns#

result: List[AnyStr]

Memory scope constraints for funcs args and result in Relay form. The empty string denotes ‘no constraint’.

参数:
返回类型:

List[str]

tvm.tir.analysis.get_vtcm_compaction_passes()[源代码]

Utility function to get the list of lowering passes to be applied to calculate the compacted VTCM allocation size

Returns#

result : List[tvm.transform.Pass]

The list of passes.

返回类型:

List[Pass]

tvm.tir.analysis.is_pure_function(func)[源代码]

Checks if the function is a pure function

参数:

func (PrimFunc)

返回类型:

bool

tvm.tir.analysis.undefined_vars(node, defs=None)[源代码]

Find undefined vars in a TIR statement or expression.

Parameters#

node: Union[Stmt, PrimExpr]

The TIR statement or expression to be checked.

defs: Optional[List[Var]]

The vars that are already defined.

Returns#

result : List[Var]

The undefined vars.

参数:
返回类型:

List[Var]
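
A minimal sketch:

import tvm

x = tvm.tir.Var("x", "int32")
y = tvm.tir.Var("y", "int32")
stmt = tvm.tir.LetStmt(x, tvm.tir.IntImm("int32", 1), tvm.tir.Evaluate(x + y))
print(tvm.tir.analysis.undefined_vars(stmt))   # [y]; x is bound by the LetStmt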

tvm.tir.analysis.verify_gpu_code(func, constraints)[源代码]

Verify if the GPU code in the module satisfies the given constraints.

Parameters#

func: tvm.tir.PrimFunc

The function to be verified.

constraints : Dict[str, int]

The attribute constraints.

Returns#

result : bool

The result of verification.

参数:
返回类型:

None

tvm.tir.analysis.verify_memory(func)[源代码]

Verify if func contains illegal host side direct memory access.

Parameters#

func: tvm.tir.PrimFunc

The function to be verified.

Returns#

result : bool

The result of verification.

参数:

func (PrimFunc)

返回类型:

bool

tvm.tir.analysis.verify_ssa(func)[源代码]

Verify if the func is in SSA form.

Parameters#

func: tvm.tir.PrimFunc

The function to be verified.

Returns#

result : bool

The result of verification.

参数:

func (PrimFunc)

返回类型:

bool

tvm.tir.analysis.verify_well_formed(obj, assert_mode=True)[源代码]
Verify if the given TIR is well-formed. The verification includes:
  • Check that expressions do not contain vars that are defined outside the block.

Parameters#

obj: Union[tvm.tir.PrimFunc, tvm.ir.IRModule]

The function or module to be verified.

assert_mode: bool

Whether to raise an error when the function is not well-formed.

Returns#

result: bool

Whether it is a well-formed TIR function.

参数:
返回类型:

bool

tvm.tir.stmt_functor#

Statement functor utilities for IR transformations

Functions:

ir_transform(stmt, preorder, postorder[, ...])

Recursively visit and transform ir nodes in post DFS order.

post_order_visit(stmt, fvisit)

Recursively visit the ir in post DFS order node, apply fvisit

pre_order_visit(stmt, fvisit)

Recursive pre-order visit on stmt AST, applying fvisit on each node.

renew_defs(func)

Re-generate the definition nodes for a TIR, including VarDef, BufferDef.

substitute(node, vmap)

Substitute the var specified by vmap.

tvm.tir.stmt_functor.ir_transform(stmt, preorder, postorder, only_enable=None)[源代码]#

Recursively visit and transform ir nodes in post DFS order.

Parameters#

stmt : tvm.tir.Stmt

The input to be transformed.

preorder: function

The function called before recursive mutation. If preorder returns None, the transform proceeds with the recursive call. If preorder returns a non-None tvm.tir.Stmt/Expr, the transformer simply returns it and does not recurse further.

postorder : function

The function called after recursive mutation.

only_enable : Optional[List[str]]

List of node type keys (e.g. "tir.Add") to restrict the transform to; preorder/postorder are only invoked on nodes of these types, while other nodes are traversed but left unchanged.

Returns#

result : tvm.tir.Stmt

The result.
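
A minimal sketch that rewrites every Add node into a Mul node, purely to show the mechanics; the type key "tir.Add" restricts the transform to Add nodes:

import tvm
from tvm import tir

x = tir.Var("x", "int32")
stmt = tir.Evaluate(x + 1)

def postorder(op):
    if isinstance(op, tir.Add):
        return tir.Mul(op.a, op.b)   # demo rewrite: a + b  ->  a * b
    return None                      # keep every other node unchanged

new_stmt = tir.stmt_functor.ir_transform(stmt, None, postorder, ["tir.Add"])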

tvm.tir.stmt_functor.post_order_visit(stmt, fvisit)[源代码]#
Recursively visit the IR in post-DFS order and apply fvisit to each node.

Each node is guaranteed to be visited only once.

Parameters#

fvisit: function

The visitor function.
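
A minimal sketch that records the class name of every visited node:

import tvm
from tvm import tir

x = tir.Var("x", "int32")
stmt = tir.Evaluate(x + 1)

visited = []
tir.stmt_functor.post_order_visit(stmt, lambda op: visited.append(type(op).__name__))
# visited lists node class names in post-DFS order, e.g. ['Var', 'IntImm', 'Add', 'Evaluate']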

tvm.tir.stmt_functor.pre_order_visit(stmt, fvisit)[源代码]#
Recursive pre-order visit on stmt AST, applying fvisit on each node.

If fvisit returns False, it won’t visit the children of the node.

Parameters#

fvisit: function of the signature Object -> bool

The visitor function.

tvm.tir.stmt_functor.renew_defs(func)[源代码]#

Re-generate the definition nodes for a TIR function, including VarDef and BufferDef. This pass works as a simple deep copy that duplicates a function with different Vars and Buffers but the same behavior.

Parameters#

func: PrimFunc

The input function

Returns#

result : PrimFunc

The new generated func.

参数:

func (PrimFunc)

tvm.tir.stmt_functor.substitute(node, vmap)[源代码]#

Substitute the var specified by vmap.

Parameters#

node: ObjectRef

The input.

vmap : Dict[Var, PrimExpr]

The variable mapping.

Returns#

result : tvm.tir.Stmt

The result.
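
A minimal sketch:

import tvm
from tvm import tir

x = tir.Var("x", "int32")
y = tir.Var("y", "int32")
stmt = tir.Evaluate(x + 1)
new_stmt = tir.stmt_functor.substitute(stmt, {x: y})
# every occurrence of x in stmt is replaced by y, giving Evaluate(y + 1)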