tvm.tir.transform

Namespace of all TIR transformations

tvm.tir.transform.AnnotateDeviceRegions()

Annotate locations that should be run on the device.

Insert AttrStmt nodes specifying the target on which regions within the PrimFunc should be executed. Only modifies functions that have a tvm::attr::kTarget attribute, and where that target defines a host.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.AnnotateEntryFunc()

Set a PrimFunc as the entry point if it is the only function in the IRModule.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.Apply(ftransform)

Apply ftransform to each function in the Module.

This function is a thin wrapper around tvm.tir.transform.prim_func_pass.

Parameters:

ftransform (tvm.tir.PrimFunc -> tvm.tir.PrimFunc) -- The transformation function.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.ApplyLayoutTransforms()

Reshape buffers that appear in the "layout_transform_map" function attribute.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.BF16ComputeLegalize()

Legalize bf16 compute Ops.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.BF16StorageLegalize()

Legalize bf16 storage types to u16.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass
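bfloat16 keeps the upper 16 bits of an IEEE float32 (sign, 8-bit exponent, 7-bit mantissa), which is why a bf16 value can be stored as a plain u16. The pure-Python sketch below models only that storage format and does not call TVM; the round-to-nearest-even rounding and the missing NaN/Inf handling are assumptions of the sketch, not a statement about the exact code the pass emits.

```python
import struct


def float32_to_bf16_bits(x: float) -> int:
    """Round a float32 value to bfloat16 and return its u16 bit pattern.

    Uses round-to-nearest-even; NaN/Inf special cases are not handled (sketch only).
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # A bias of 0x7FFF plus the lowest kept bit implements round-half-to-even.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding_bias) >> 16) & 0xFFFF


def bf16_bits_to_float32(b: int) -> float:
    """Widen a u16 bfloat16 pattern back to float32 by appending 16 zero bits."""
    return struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))[0]


print(hex(float32_to_bf16_bits(1.0)))  # 0x3f80
print(bf16_bits_to_float32(float32_to_bf16_bits(3.14159265)))  # 3.140625
```

Widening back to float32 is exact, so only the narrowing direction loses precision.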

tvm.tir.transform.BindTarget(target)

Annotate a PrimFunc with a given target.

Parameters:

target (tvm.target.Target) -- The target.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.CoProcSync()

Detect and insert sync points for the co-processor.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.CombineContextCall()

Combine context calls in the host function.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.CommonSubexprElimTIR(enable_cse_tir: bool = True, identify_equiv_terms: bool = False)

Replace redundant computations by new variables.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass
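At the source level, common subexpression elimination binds a repeated computation to a new variable and reuses it. The pure-Python pair below only shows the before/after shape of that rewrite; the variable name cse_var_1 is illustrative, and the sketch says nothing about which TIR expressions the pass considers eligible.

```python
def before_cse(a: int, b: int, c: int) -> int:
    x = (a + b) * c
    y = (a + b) * 2  # (a + b) is computed a second time
    return x + y


def after_cse(a: int, b: int, c: int) -> int:
    cse_var_1 = a + b  # the redundant computation is replaced by a new variable
    x = cse_var_1 * c
    y = cse_var_1 * 2
    return x + y
```

Both functions return the same value for all inputs; the second evaluates a + b once.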

tvm.tir.transform.CompactBufferAllocation(is_strict: bool = True)

Compact the buffer access region by removing the buffer regions that are not accessed, i.e. narrowing the buffer shape and adjusting the access region if necessary.

Example

Before narrowing, B is a [16, 16] buffer, but in each iteration of i only a skinny vector B[i, 0:16] is accessed.

for i in range(0, 16):
    with T.block():
        B = T.alloc_buffer((16, 16))
        for j in range(0, 16):
            B[i, j] = A[i, j] + 1
        for j in range(0, 16):
            C[i, j] = B[i, j] + 1

This pass narrows the buffer shape and adjusts its accessed region accordingly. In this particular case, because only a 1 * 16 vector of B is accessed, the pass narrows B to shape [1, 16] and changes the access B[i, j] to B[0, j].

for i in range(0, 16):
    with T.block():
        B = T.alloc_buffer((1, 16))
        for j in range(0, 16):
            B[0, j] = A[i, j] + 1
        for j in range(0, 16):
            C[i, j] = B[0, j] + 1

Parameters:

is_strict (bool) -- Ensure that the compacted shape is always smaller than the original shape. Otherwise, the shape is allowed to grow to match the actually accessed buffer regions.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass
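The narrowing in the example above amounts to shrinking the buffer to the extent of the accessed region and subtracting the region's minimum from every index. The pure-Python sketch below replays that example on nested lists; it is a model of the idea rather than TVM code, and the accessed region is supplied by hand instead of being inferred as the pass does.

```python
from typing import Callable, List, Tuple

Range = Tuple[int, int]  # (min, extent) of the accessed indices along one dimension


def compact(region: List[Range]) -> Tuple[Tuple[int, ...], Callable[..., Tuple[int, ...]]]:
    """Return the narrowed buffer shape and an index remapping for an accessed region."""
    shape = tuple(extent for _, extent in region)
    mins = [lo for lo, _ in region]

    def remap(*idx: int) -> Tuple[int, ...]:
        # Shift each index by the region minimum so the narrowed buffer starts at 0.
        return tuple(x - lo for x, lo in zip(idx, mins))

    return shape, remap


A = [[float(16 * i + j) for j in range(16)] for i in range(16)]
C = [[0.0] * 16 for _ in range(16)]
for i in range(16):
    # In block instance i, only the row B[i, 0:16] is touched.
    shape, remap = compact([(i, 1), (0, 16)])
    B = [[0.0] * shape[1] for _ in range(shape[0])]  # shape == (1, 16)
    for j in range(16):
        r, c = remap(i, j)  # B[i, j] -> B[0, j]
        B[r][c] = A[i][j] + 1
    for j in range(16):
        r, c = remap(i, j)
        C[i][j] = B[r][c] + 1
```

The narrowed buffer holds 16 elements instead of 256, and C ends up equal to A + 2 either way.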

tvm.tir.transform.ConvertBlocksToOpaque()

Substitute all the block vars with the PrimExprs they are bound to, indicated by the corresponding iter_values in BlockRealize, and then convert the blocks into opaque ones by removing all the iter_values in BlockRealize and iter_vars in Block.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.ConvertForLoopsToSerial()

Convert parallel for loops to serial for loops.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.ConvertSSA()

Convert an IRModule to SSA form.

This pass handles cases where the same tir.Var appears in multiple functions within the same module. For example, after a fragment is extracted from one function into another, the same tir.Var may be defined both within the body of the original function and as a parameter of the hoisted function.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.DecorateDeviceScope()

Decorate the entire function body as a device function.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.DefaultGPUSchedule()

This pass sets default thread bindings for PrimFuncs, including functions with symbolic shapes, so that they can be built and executed on GPU devices. It examines all the blocks within the PrimFunc and performs loop fusion, splitting, and reordering based on the loop extents and target information, such as the maximum number of thread blocks and the maximum number of threads per block.

The primary objective of this pass is not to optimize performance, but to generate a valid GPU kernel for unscheduled or symbolic-shape PrimFuncs. Currently, the pass only works for CUDA targets.

Returns:

ret -- The result pass

Return type:

tvm.transform.Pass
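To make the fuse-and-split idea concrete, the sketch below fuses a loop nest into one extent and splits it into (serial, blockIdx.x, threadIdx.x) extents bounded by two limits, max_threads_per_block and max_blocks. It is a simplified pure-Python model written for illustration; the function fuse_and_split is hypothetical, and the actual heuristic of DefaultGPUSchedule (including how it treats symbolic extents) is not reproduced here.

```python
import math
from typing import Sequence, Tuple


def ceildiv(a: int, b: int) -> int:
    return -(-a // b)


def fuse_and_split(extents: Sequence[int], max_threads_per_block: int,
                   max_blocks: int) -> Tuple[int, int, int]:
    """Return (serial, blocks, threads) extents covering the fused loop nest."""
    fused = math.prod(extents)                  # fuse all loops into one
    threads = min(fused, max_threads_per_block)
    blocks_needed = ceildiv(fused, threads)
    blocks = min(blocks_needed, max_blocks)
    serial = ceildiv(blocks_needed, blocks)     # leftover work becomes an outer serial loop
    return serial, blocks, threads


# 128 x 64 iterations with limits of 1024 threads per block and 65535 blocks.
print(fuse_and_split([128, 64], 1024, 65535))  # (1, 8, 1024)
```

Because serial * blocks * threads can exceed the fused extent, a kernel generated this way would need a bounds guard around the loop body.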

tvm.tir.transform.ExtractPrimFuncConstants()

Collect and unify non-scalar TIR constants into the module's 'Constants' attribute array.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.FP8ComputeLegalize(promote_dtype_str: str = 'float32')

Legalize fp8 compute Ops.

Parameters:

promote_dtype_str (str) -- The data type to promote fp8 to. Options: float16, float32.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.FP8StorageLegalize()

Legalize fp8 storage types to u8.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.Filter(fcond: Callable)

Filter out PrimFuncs that do not satisfy the given condition. fcond should be a function that takes a PrimFunc and returns a boolean.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.FlattenBuffer()

Flatten multi-dimensional BufferLoad and BufferStore nodes into single-dimensional BufferLoad/BufferStore nodes, for TIR that does not contain opaque blocks.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass
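Flattening replaces an N-dimensional index with a single offset. Assuming a compact row-major layout with no explicit strides, the offset follows Horner's rule, as in this pure-Python sketch (a model of the index arithmetic, not TVM's implementation):

```python
from typing import Sequence


def flatten_index(indices: Sequence[int], shape: Sequence[int]) -> int:
    """Row-major offset ((i0 * s1 + i1) * s2 + i2) ... for a compact buffer."""
    if len(indices) != len(shape):
        raise ValueError("rank mismatch")
    offset = 0
    for idx, extent in zip(indices, shape):
        if not 0 <= idx < extent:
            raise IndexError(f"index {idx} out of range for extent {extent}")
        offset = offset * extent + idx
    return offset


# B[2, 3] in a [16, 16] buffer becomes B_flat[2 * 16 + 3].
print(flatten_index((2, 3), (16, 16)))  # 35
```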

tvm.tir.transform.ForceNarrowIndexToInt32()

Force indexing expressions and integer buffers to be narrowed to the int32 dtype.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

Note

This pass should not be used in default cases.

tvm.tir.transform.HoistExpression()

Generalized version of HoistIfThenElse.

Hoist loop-invariant expressions to outside the eligible loops. Searches for expressions in:

  • LetStmt bindings

  • IfThenElse conditions

  • Boolean operators

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.HoistIfThenElse(variant: str | None = None)

Hoist loop-invariant IfThenElse nodes to outside the eligible loops.

Parameters:

variant (Optional[String]) --

The variant of the pass. variant can take one of the following values: ["basic", None (default)].

The basic variant supports basic hoisting scenarios, in which the For and If nodes are expected to be consecutive and no global scope variables or more advanced scenarios are involved.

The default variant supports all hoisting scenarios, i.e. {"Basic" + "Advanced"}, controlled by PassContext configs such as:

config={"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass
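The effect of hoisting a loop-invariant IfThenElse can be shown with two equivalent pure-Python functions: after hoisting, the condition is tested once outside the loop and the loop is duplicated into each branch. This illustrates the shape of the transformation only; it is not TVM code.

```python
from typing import List


def unhoisted(xs: List[int], flag: bool) -> List[int]:
    out = []
    for x in xs:
        if flag:  # loop-invariant condition, re-evaluated on every iteration
            out.append(x * 2)
        else:
            out.append(x + 1)
    return out


def hoisted(xs: List[int], flag: bool) -> List[int]:
    out = []
    if flag:  # condition hoisted: evaluated once, loop duplicated per branch
        for x in xs:
            out.append(x * 2)
    else:
        for x in xs:
            out.append(x + 1)
    return out
```

The trade-off is code size: each branch now carries its own copy of the loop.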

class tvm.tir.transform.HoistedConditionals(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)

Flags for use in HoistExpressionConfig.conditional_types

Each bitflag represents a type of expression that should be hoisted to the outermost loop possible.

All = 15

Enable all hoisting of conditionals

BooleanExpression = 4

If set, look for hoist candidates in all boolean expressions

IfElseExpr = 2

If set, look for hoist candidates in tir.if_then_else

IfElseStmt = 1

If set, look for hoist candidates in IfElseStmt

Never = 0

No hoisting of conditionals

UsingBlockVar = 8

If set, allow hoisting of conditionals that use a block variable (e.g. threadIdx.x)

class tvm.tir.transform.HoistedLetBindings(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)

Flags for use in HoistExpressionConfig.let_binding_types

Each bitflag represents a type of let binding expression that should be hoisted to the outermost loop possible.

All = 7

Enable all hoisting of let bindings

LetExpr = 4

Bindings occurring in Let expressions

LetStmt = 2

Bindings occurring in LetStmt

Never = 0

No hoisting of let bindings

RequiredByConditional = 1

Bindings that are used by a hoisted conditional
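Both classes are bit flags: members are combined with |, and each All member is the union of the non-zero members. The sketch below mirrors the documented values with Python's enum.IntFlag so the arithmetic can be checked without TVM installed; it is a local stand-in, not the classes exported by tvm.tir.transform.

```python
import enum


class HoistedConditionals(enum.IntFlag):
    # Local mirror of the documented values (not imported from TVM).
    Never = 0
    IfElseStmt = 1
    IfElseExpr = 2
    BooleanExpression = 4
    UsingBlockVar = 8
    All = 15


class HoistedLetBindings(enum.IntFlag):
    # Local mirror of the documented values (not imported from TVM).
    Never = 0
    RequiredByConditional = 1
    LetStmt = 2
    LetExpr = 4
    All = 7


# Select if/else statements and expressions, but not generic boolean expressions.
selected = HoistedConditionals.IfElseStmt | HoistedConditionals.IfElseExpr
print(int(selected))  # 3
```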

tvm.tir.transform.InferFragment()

Infer the TensorCore fragment information using tensor intrinsics.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.InjectCopyIntrin(pragma_key: str, fintrin)

Inject copy intrinsics into the loops marked with the given pragma key.

Parameters:
  • pragma_key (str) -- The pragma key that hints a copy.

  • fintrin (function) -- The function with signature copyintrin(src, dst, pad_before, pad_after, pad_value)

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.InjectDoubleBuffer()

Inject double buffer statements.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.InjectPTXAsyncCopy()

Rewrite global-to-shared memory copies on CUDA as asynchronous copies.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.InjectPermutedLayout()

Inject a permuted layout in MMA.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.InjectPrefetch()

Inject prefetch instructions into the stmt.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.InjectRollingBuffer()

Inject rolling buffer statements.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.InjectSoftwarePipeline()

Transform annotated loops into pipelined loops that parallelize producers and consumers.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass
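Software pipelining overlaps stages from different iterations. For a generic two-stage loop in which stage 0 loads tile i and stage 1 computes on tile i, the pipelined loop has a prologue, a steady state, and an epilogue. The pure-Python sketch below only enumerates that schedule; the loop annotations that drive InjectSoftwarePipeline and its buffer handling are not modeled.

```python
from typing import List, Tuple

Op = Tuple[str, int]  # (stage name, iteration index)


def pipelined_schedule(n: int) -> List[List[Op]]:
    """Two-stage schedule for n >= 1 iterations; ops within one step may run concurrently."""
    if n < 1:
        raise ValueError("need at least one iteration")
    steps: List[List[Op]] = [[("load", 0)]]              # prologue
    for i in range(n - 1):
        steps.append([("load", i + 1), ("compute", i)])  # steady state
    steps.append([("compute", n - 1)])                   # epilogue
    return steps


for step in pipelined_schedule(3):
    print(step)
```

Because iteration i + 1 loads while iteration i computes, the steady state needs two buffers in flight, which is the same idea InjectDoubleBuffer applies.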

tvm.tir.transform.InjectVirtualThread()

Inject virtual thread loops.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.InlinePrivateFunctions()

Inline calls to private functions.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.InstallDebugSpans()

Add line information from the TIR printer as spans on each statement and expression.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.InstrumentBoundCheckers()

Instrument bound checkers.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.InstrumentProfileIntrinsics()

Insert intrinsic calls to instrument function and loop level profiling.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.LegalizePackedCalls()

Legalize packed calls so that their arguments are wrapped in TVMValues.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.LiftAttrScope(attr_key: str)

Lift common attrs with attr_key to outer scope.

Parameters:

attr_key (str) -- The attribute key to be checked.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.LiftThreadBinding()

Lift identical thread bindings to their LCA (lowest common ancestor) loops.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.LoopPartition()

Partition loops in the stmt.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass
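Loop partitioning splits a loop's iteration range so that a condition on the loop variable becomes constant inside each piece, which removes the branch from the loop body. The pure-Python sketch below shows that general idea for a condition i < limit; how the TVM pass chooses its partition points is not modeled.

```python
from typing import List


def with_branch(n: int, limit: int) -> List[int]:
    out = []
    for i in range(n):
        if i < limit:  # branch depends on the loop variable
            out.append(i * 2)
        else:
            out.append(0)
    return out


def partitioned(n: int, limit: int) -> List[int]:
    split = max(0, min(limit, n))  # first iteration where the condition is false
    out = []
    for i in range(0, split):      # condition always true here: no branch needed
        out.append(i * 2)
    for i in range(split, n):      # condition always false here
        out.append(0)
    return out
```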

tvm.tir.transform.LowerAutoCopy()

Automatically apply memory optimizations to auto-copy blocks.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.LowerCrossThreadReduction()

Lower cross-thread reduction from thread bindings to intrinsic function calls.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.LowerCustomDatatypes()

Lower custom datatypes.

See tvm::datatypes::Registry for more information on adding custom datatypes.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.LowerDeviceKernelLaunch()

Lower cross-device function calls.

Prior to this pass, host-to-device calls are represented as subroutine calls, with environment parameters (e.g. env_thread) specified internally. The device function is an internal function, without a tvm::attr::kGlobalSymbol attribute.

After this pass, host-to-device calls are represented as calls to the tvm_call_packed built-in. The device function is an externally-exposed function, with a non-empty tvm::attr::kGlobalSymbol attribute.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.LowerDeviceStorageAccessInfo()

Lower attached storage access information on device.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

Note

Run this pass after all storage access analysis has finished.

tvm.tir.transform.LowerInitBlock()

Lower block init statements into IfThenElse statements.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.LowerIntrin()

Lower target-specific intrinsic calls.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.LowerMatchBuffer()

Remove match buffers inside the block. This pass also validates the binding.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.LowerOpaqueBlock()

Remove the block to ensure that the TIR cannot be scheduled again.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.LowerTVMBuiltin()

Lower TVM builtin intrinsics.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.LowerThreadAllreduce()

Lower cross-thread allreduce.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.LowerWarpMemory()

Lower warp memory access to low-level device related function calls.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.MakePackedAPI()

Transform the PrimFuncs in the module to a packed func API.

Prior to this pass, the PrimFunc may have Buffer arguments defined in the PrimFuncNode::buffer_map. This pass consumes the buffer_map, using it to generate TVMArgs and TVMRetValue* arguments that implement the PackedFunc API.

For static shapes, the BufferNode::shape, BufferNode::strides, and BufferNode::elem_offset member variables are used to generate runtime checks on the corresponding member variables in the user-provided DLTensor* or tvm.nd.array argument. For example, a PrimFunc that accepts a buffer of shape [16, 32] validates that the DLTensor::shape array is [16, 32].

For dynamic Buffers, in which one or more of these BufferNode member variables use tir.Var that are not defined by other PrimFunc parameters, these variables are instead defined from the corresponding DLTensor members. For example, a PrimFunc that accepts a buffer of shape [tir.Var("n"), tir.Var("m")], when passed a DLTensor of shape [16, 32], will define n = 16 and m = 32 based on the argument's shape.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass
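The two cases above (static extents are checked, unbound symbolic extents are defined) can be modeled in a few lines of pure Python. In this sketch a shape entry is either an int (a static extent) or a str naming a symbolic tir.Var; the function bind_shape and the env dictionary are hypothetical stand-ins for the checks and bindings the pass generates, not TVM APIs.

```python
from typing import Dict, Sequence, Union

Dim = Union[int, str]  # int: static extent; str: name of a symbolic tir.Var


def bind_shape(expected: Sequence[Dim], actual: Sequence[int],
               env: Dict[str, int]) -> Dict[str, int]:
    """Check static extents and bind symbolic ones, mimicking generated runtime code."""
    if len(expected) != len(actual):
        raise ValueError(f"ndim mismatch: expected {len(expected)}, got {len(actual)}")
    for axis, (dim, extent) in enumerate(zip(expected, actual)):
        if isinstance(dim, int):
            if dim != extent:  # static shape: runtime check
                raise ValueError(f"axis {axis}: expected {dim}, got {extent}")
        elif dim in env:       # variable already defined: the extents must agree
            if env[dim] != extent:
                raise ValueError(f"axis {axis}: {dim} is {env[dim]}, got {extent}")
        else:                  # first occurrence defines the variable
            env[dim] = extent
    return env


print(bind_shape(["n", "m"], [16, 32], {}))  # {'n': 16, 'm': 32}
```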

tvm.tir.transform.MakeUnpackedAPI()

Transform the PrimFuncs in the module to a C API compatible with internal calls.

Prior to this pass, the PrimFunc may have Buffer arguments defined in the PrimFuncNode::buffer_map. This pass consumes the buffer_map, using it to generate T* arguments (e.g. float32*) that can be directly called by a C API.

For static shapes, no runtime validation is performed to confirm that the argument buffer's shape matches the expected shape. For dynamic shapes, MakeUnpackedAPI requires that the dynamic parameters be passed as separate tir.Var parameters.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.ManifestSharedMemoryLocalStage()

Add the explicit local stage for the shared memory access on GPU.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.MergeSharedMemoryAllocations()

This pass merges multiple TIR-level shared memory allocations into one allocation.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.NarrowDataType(target_bits: int)

Narrow down the PrimExpr datatypes in the stmt to target_bits.

Parameters:

target_bits (int) -- The target bit configuration.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

Note

Run this pass after StorageFlatten.

tvm.tir.transform.PlanAndUpdateBufferAllocationLocation()

Move each buffer allocation to its exact position (usually the LCA of the buffer accesses). This pass injects opaque blocks with alloc_buffers at the allocation sites.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.PointerValueTypeRewrite()

Rewrite the pointer content type of arguments, as well as Allocs internal to the function, to use the most frequently accessed type for load/store, avoiding pointer casts in the backend when possible.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

class tvm.tir.transform.PrimFuncPass

A pass that works on each tvm.tir.PrimFunc() in a module. A PrimFunc pass class should be created through tvm.tir.transform.prim_func_pass.

tvm.tir.transform.ReduceBranchingThroughOvercompute()

Reduce branching by introducing overcompute.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.RemoveAssume()

Remove all instances of builtin::assume.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.RemoveNoOp()

Remove no-op statements from the Stmt.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.RemoveStoreUndef()

Remove stores of undefined values from the Stmt.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.RemoveWeightLayoutRewriteBlock(skip_ndarray_rewrite=False)

Remove the weight layout rewrite block before benchmarking during the tuning stage.

Parameters:

skip_ndarray_rewrite (bool) --

If True, the exact rewrite of the NDArray according to the given index map is skipped. Only the shape of the NDArray is transformed correctly, and the content of the destination array is filled with random values.

When this pass is called many times during MetaSchedule tuning, the raw data of the NDArray before and after the rewrite does not matter. Since the NDArray layout rewrite using IndexMap's MapNDArray is currently slow, skipping the exact rewrite is sometimes necessary.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.RenormalizeSplitPattern()

Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv()).

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass
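The renormalization relies on the identity floordiv(floormod(x, c1), c2) == floormod(floordiv(x, c2), c1 / c2), which holds for every integer x when c2 divides c1 (the divisibility requirement is an assumption of this sketch about when the rewrite applies). Python's // and % are floor division and floor modulo for positive divisors, so the identity can be checked exhaustively over a range:

```python
def split_before(x: int, c1: int, c2: int) -> int:
    return (x % c1) // c2          # floordiv(floormod(x, c1), c2)


def split_after(x: int, c1: int, c2: int) -> int:
    assert c1 % c2 == 0, "sketch assumes c2 divides c1"
    return (x // c2) % (c1 // c2)  # floormod(floordiv(x, c2), c1 / c2)


cases = [(8, 2), (12, 4), (16, 16), (6, 1)]
print(all(split_before(x, c1, c2) == split_after(x, c1, c2)
          for c1, c2 in cases for x in range(-100, 100)))  # True
```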

tvm.tir.transform.RewriteUnsafeSelect()

Detect and rewrite unsafe select that contains memory access.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.Simplify()

Run arithmetic simplifications on the statements and expressions.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.SkipAssert()

Skip assert statements.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.SplitHostDevice()

Split the function into a host function and device functions.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.StorageFlatten(cache_line_size, create_bound_attribute: bool = False)

Flatten multi-dimensional reads and writes to 1D.

Parameters:
  • cache_line_size (int) -- The size of the CPU cache line.

  • create_bound_attribute (bool) -- Whether to create bound attributes.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.StorageRewrite()

Rewrite the storage allocation pattern.

Move allocations to the outermost possible scope, and try to share space between allocations to make a static allocation plan when possible.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.TextureFlatten()

Flatten multi-dimensional reads and writes to 2D.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.ThreadSync(storage_scope: str)

Insert sync between parallel read/write of shared buffers.

Parameters:

storage_scope (str) -- The target storage scope.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.TransformMmaBufferLayout()

Transform the MMA buffer layout.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.UnifyThreadBinding()

Unify all the thread bindings for "blockIdx.x/y/z", "threadIdx.x/y/z", and "vthread.x/y/z". Before the unification, two vars that are bound to a thread axis (e.g., "threadIdx.x") use different IterVars and variables in their AttrStmts. After the unification, a single consolidated IterVar and variable are used for them.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

Note

vthread is a legacy behavior that will be deprecated, though thread bindings of vthread are still unified by this pass. Please use vthread.x, vthread.y and vthread.z instead.

tvm.tir.transform.UnrollLoop()

Unroll constant loops marked with unroll.

This pass also automatically attaches the pragma unroll tag to loops that meet the criteria.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.UseAssumeToReduceBranches()

This pass attempts to eliminate layout-specific padding branches by overcomputing the values for the padded region. Eliminating the branch helps vectorize code and improves the performance of element-wise ops.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.VectorizeLoop(enable_vectorize: bool = True)

Lower vectorization loops.

Parameters:

enable_vectorize (bool) -- Whether vectorization is enabled. When it is turned off, vectorized loops are lowered to scalar loops.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.VerifyMemory()

Verify whether the function contains illegal host-side direct memory access.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.VerifyVTCMLimit(limit: int)

Verify whether the size of the allocated VTCM memory satisfies the limit.

Returns:

fpass -- The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.prim_func_pass(pass_func=None, opt_level: int = None, name: str | None = None, required: List[str] | None = None, traceable=False) → Callable | PrimFuncPass

Decorate a function pass.

This function returns a decorator when pass_func is not provided. Otherwise, it returns the PrimFuncPass created from the given optimization function.

Parameters:
  • pass_func (Optional[Callable[(tvm.tir.PrimFunc, IRModule, PassContext) -> tvm.tir.PrimFunc]]) -- The transformation function or class.

  • opt_level (int) -- The optimization level of this function pass.

  • name (Optional[str]) -- The name of the function pass. The name can be empty, in which case the name of the optimization function is used as the pass name.

  • required (Optional[List[str]]) -- The list of passes that the function pass depends on.

Returns:

create_function_pass -- A decorator is returned if pass_func is not provided; otherwise, the decorated result is returned. The returned decorator behaves differently depending on its input: a new PrimFuncPass is returned when decorating a pass function, and a new PrimFuncPass class is returned when decorating a class type.

Return type:

Union[Callable, PrimFuncPass]

Example

The following code block decorates a function pass class.

import tvm

@tvm.tir.transform.prim_func_pass(opt_level=1)
class TestReplaceFunc:
    def __init__(self, new_func):
        self.new_func = new_func

    def transform_function(self, func, mod, ctx):
        # Just for demo purposes: replace every PrimFunc with new_func.
        return self.new_func

The following code creates a function pass by decorating a user-defined transform function.

import tvm

@tvm.tir.transform.prim_func_pass(opt_level=2)
def transform(func, mod, ctx):
    # my transformations here.
    return func

function_pass = transform
assert isinstance(function_pass, tvm.tir.transform.PrimFuncPass)
assert function_pass.info.opt_level == 2

# Given a module m, the optimization could be invoked as follows:
updated_mod = function_pass(m)
# The transformation (an identity here) has now been applied to every PrimFunc
# in the provided module m, and the updated module is returned.