tvm.relax.transform

tvm.relax.transform#

Relax transformations.

class tvm.relax.transform.AttachExternModules(extern_modules)[源代码]#

Attach a list of external modules (ExternModule) to the IRModule being compiled.

Parameters:

extern_modules (List[ExternModule])

class tvm.relax.transform.DataflowBlockPass[源代码]#

A pass that works on each tvm.relax.DataflowBlock in a module.

class tvm.relax.transform.FastMathTransform(*args, **kwargs)[源代码]#

Pass to convert expensive non-linear functions to their fast but approximate counterparts.

class tvm.relax.transform.FunctionPass[源代码]#

A pass that works on each tvm.relax.Function in a module. A function pass class should be created through function_pass.

class tvm.relax.transform.FuseTransposeMatmul(*args, **kwargs)[源代码]#

A compiler pass that fuses transpose + matmul.

class tvm.relax.transform.FusionPattern(name, pattern, annotation_patterns=None, check=None, attrs_getter=None)[源代码]#

The pattern used by FuseOpsByPattern. It's mainly DFPattern but with other information to help during the fusion pass.

Parameters#

name: str

The name of pattern. Usually it starts with the name of backend, like 'cutlass.matmul'.

pattern: DFPattern

The dataflow pattern that will be used to match expressions that can be handled by external backends.

annotation_patterns: Mapping[str, DFPattern]

The map which is used to extract important expressions from the pattern match result. All DFPattern in this map should be part of the pattern.

check: Callable[[PatternCheckContext], bool]

The function to check whether the match result is accepted.

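
A minimal construction sketch (assuming the dataflow pattern helpers from tvm.relax.dpl; the "cutlass" backend prefix and the annotation keys are illustrative):

from tvm.relax.dpl import is_op, wildcard
from tvm.relax.transform import FusionPattern

# Match a relax.matmul whose operands can be anything; annotation_patterns lets
# the fusion pass recover the lhs/rhs sub-expressions from the match result.
lhs, rhs = wildcard(), wildcard()
matmul = is_op("relax.matmul")(lhs, rhs)
pattern = FusionPattern(
    "cutlass.matmul",  # pattern name, usually prefixed with the backend name
    matmul,
    annotation_patterns={"lhs": lhs, "rhs": rhs, "root": matmul},
)
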
class tvm.relax.transform.IPCAllReduceRewrite(allreduce_strategy)[源代码]#

Rewrite all-reduce operation to customized all-reduce impl with IPC memory.

Parameters:

allreduce_strategy (int)

__init__(allreduce_strategy)[源代码]#

Constructor

Parameters#

allreduce_strategy: int

The all-reduce strategy. Only "1" and "2" are supported. "1" stands for one-shot, and "2" stands for two-shot.

Parameters:

allreduce_strategy (int)

Return type:

None

class tvm.relax.transform.LazyTransformParams(fget_item='get_item', fset_item='set_item', extra_get_item_params=None, extra_set_item_params=None)[源代码]#

Convert transform_params functions into a lazy version. (Load the input to memory on demand, and immediately free it after the last use.)

Note: ToNonDataflow() and RemovePurityChecking() should be invoked before this pass.

Parameters#

fget_item: str

The name of the get_item function.

fset_item: str

The name of the set_item function.

extra_get_item_params: list of relax.Var

The parameters of the get_item function except index. The given parameters will be placed before index. For example, if extra_get_item_params is [param1, param2], then the pass will generate call_packed(fget_item, [param1, param2, index])

extra_set_item_params: list of relax.Var

The parameters of the set_item function except index and value. The given parameters will be placed before index and value. For example, if extra_set_item_params is [param1, param2], then the pass will generate call_packed(fset_item, [param1, param2, index, value])
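
A minimal pipeline sketch of the ordering noted above (assuming `mod` is an existing IRModule that contains a transform_params function):

import tvm
from tvm import relax

# ToNonDataflow and RemovePurityChecking are applied before LazyTransformParams.
seq = tvm.transform.Sequential(
    [
        relax.transform.ToNonDataflow(),
        relax.transform.RemovePurityChecking(),
        relax.transform.LazyTransformParams(),
    ]
)
mod = seq(mod)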

class tvm.relax.transform.LowerGPUIPCAllocStorage(*args, **kwargs)[源代码]#

Lower the storage/tensor allocation on IPC memory.

class tvm.relax.transform.OptimizeLayoutTransform[源代码]#

Pass to remove redundant transform layout operators introduced by AlterOpImpl pass.

class tvm.relax.transform.PatternCheckContext[源代码]#

The input of check function FusionPattern.check.

Parameters#

matched_expr: Expr

The expression that's matched with the FusionPattern.pattern.

annotated_expr: Mapping[str, Expr]

A map which contains all expressions matched by the sub patterns in FusionPattern.annotation_patterns.

matched_bindings: Mapping[Var, Expr]

Map from variable to its value. It contains the variables from bindings that are being fused by FuseOpsByPattern.

var_usages: Mapping[Var, Sequence[Var]]

A map mapping variable definitions to a set of uses. It has all variables used in the function.

value_to_bound_var: Mapping[Expr, Var]

Map from value to its bound variable. It doesn't have variables after the matched expression.

class tvm.relax.transform.RemoveRedundantReshape[源代码]#

Transformation pass to remove redundant reshape operators.

tvm.relax.transform.AdjustMatmulOrder()[源代码]#

Reorder x*(A*B) to (x*A)*B

Useful for optimizing LoRA computations, where matmul(x, LoraA*LoraB) may be computed as matmul(matmul(x, LoraA), LoraB), reducing the total memory usage.

Returns#

ret: tvm.transform.Pass

The corresponding pass.

tvm.relax.transform.AllocateWorkspace()[源代码]#

Allocate a workspace, represented by a tensor of size big enough for all external functions that require a temporary storage, and append it to the arguments of external functions.

An external function can specify its workspace requirement by the kWorkspaceSize attribute.

Returns#

ret: tvm.ir.transform.Pass

The registered pass for allocating workspace.

Return type:

Pass

tvm.relax.transform.AlterOpImpl(op_impl_map, op_buffer_transforms, op_buffer_axis_separators, op_buffer_input_axis_separators)[源代码]#

Replace all PrimFuncs that have a matching 'operator_name' attribute with replacement PrimFuncs that may use different layouts on their i/o buffers. The layout transformations on the i/o buffers are given in the op_buffer_transforms map. Layout transformations are inserted at the call sites of the replaced PrimFuncs so that the i/o tensors are transformed into the layout expected by the new PrimFunc.

Parameters#

op_impl_map: Dict[str, PrimFunc]

op_kind to PrimFunc map

op_buffer_transforms: Dict[str, List[Union[IndexMap, Callable]]]

op_kind to layout transformation map for each of the buffers

op_buffer_axis_separators: Dict[str, List[Union[IndexMap.AXIS_SEPARATOR, Callable]]]

op_kind to axis_separator for each index_map

op_buffer_input_axis_separators: Dict[str, List[Union[IndexMap.AXIS_SEPARATOR, Callable]]]

op_kind to axis_separator for input index_map

Returns#

ret: tvm.ir.transform.Pass

tvm.relax.transform.AnnotateTIROpPattern()[源代码]#

Annotate Op Pattern Kind for TIR functions

Returns#

ret: tvm.ir.transform.Pass

Return type:

Pass

tvm.relax.transform.AttachAttrLayoutFreeBuffers()[源代码]#

Attach layout free buffers to the tir::PrimFunc.

This pass is used to attach layout free buffers to the tir::PrimFunc according to the function usage in the relax function. Currently, the layout free buffers are the model weights and relax constants.

Note that we recommend applying CanonicalizeBindings before this pass.

Returns#

ret: tvm.transform.Pass

The registered pass for attaching layout free buffers.

Return type:

Pass

tvm.relax.transform.AttachGlobalSymbol()[源代码]#

Attach global_symbol to Relax functions and TIR PrimFuncs for codegen.

Returns#

ret: tvm.ir.transform.Pass

Return type:

Pass

tvm.relax.transform.BindParams(func_name, params)[源代码]#

Bind params of function of the module to constant tensors.

Parameters#

func_name: str

The function name to be bound

params: Dict[Union[str,relax.Var], Union[tvm.runtime.NDArray, np.ndarray]]

The map from parameter or parameter name to constant tensors.

Returns#

ret: tvm.ir.transform.Pass

Return type:

Pass
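
A minimal usage sketch (assuming `mod` is an existing IRModule whose "main" function has a parameter named "weight"; the shape is illustrative):

import numpy as np
from tvm import relax

# Replace the "weight" parameter of "main" with a constant tensor.
params = {"weight": np.zeros((16, 16), dtype="float32")}
mod = relax.transform.BindParams("main", params)(mod)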

tvm.relax.transform.BindSymbolicVars(binding_map, func_name=None)[源代码]#

Bind symbolic variables of functions in the module to constant values.

Parameters#

binding_map: Mapping[Union[str, tvm.tir.Var], tvm.tir.PrimExpr]

The map from symbolic varname to integer.

func_name: Optional[str]

The function name to be bound. If None (default), all functions within the module will be updated.

Returns#

ret: tvm.ir.transform.Pass

Return type:

Pass
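
A minimal usage sketch (assuming `mod` is an existing IRModule that uses a symbolic dimension named "n"):

from tvm import relax

# Fix the symbolic variable "n" to 1024 in every function of the module.
mod = relax.transform.BindSymbolicVars({"n": 1024})(mod)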

tvm.relax.transform.BundleModelParams(param_tuple_name=None)[源代码]#

Bundle several model parameters into a single tuple parameter.

For each function, if the function has the attribute "num_input", separate between run-time parameters and compile-time weights. Run-time parameters (e.g. activations) are the first num_input parameters, and the remainder are compile-time weights.

Parameters#

param_tuple_name: Optional[str]

The name of the tuple parameter. If unspecified, defaults to "model_params".

Returns#

ret: tvm.transform.Pass

The registered pass for bundling model parameters.

Parameters:

param_tuple_name (str | None)

Return type:

Pass

tvm.relax.transform.CallTIRRewrite()[源代码]#

Perform explicit tensor allocation for call_tir and call_dps_packed.

Returns#

ret: tvm.ir.transform.Pass

Return type:

Pass

tvm.relax.transform.CanonicalizeBindings()[源代码]#

Canonicalizes variable definitions (e.g., if there is y = x and z = y, it replaces uses of y and z with x). Also simplifies match cast nodes (eliminating redundant checks) and tuple indices.

Best combined with constant folding and the elimination of unused definitions.

Note: If a dataflow var is used only in a binding to the dataflow block output var (i.e., a non-dataflow var), this pass will also remove the dataflow var and replace the output var's binding with the dataflow var's direct definition.

Returns#

ret: tvm.ir.transform.Pass

Return type:

Pass

tvm.relax.transform.CombineParallelMatmul(check=None)[源代码]#

Combine multiple matmul operators sharing the same LHS matrix into one, followed by slicing. When all matmul branches in a tree have the same set of fused ops, the fused ops are applied to the combined matmul output before slicing.

Currently, only a limited set of fused ops is supported. It includes bias add, relu, gelu, gelu_tanh and silu activation.

Parameters#

check: Callable[[Var, List[Var], List[Var], Dict[Var, Expr]], bool]

A function to filter out unwanted branches, with the signature (input, [rhs], [bias], binding) -> bool.

Returns#

ret: tvm.transform.Pass

The corresponding pass.
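
A sketch of a check callback with the signature described above (purely illustrative; it keeps a group of parallel matmuls only when every RHS matrix is 2-D, and assumes `mod` is an existing IRModule):

from tvm import relax

def keep_2d_rhs(input_var, rhs_vars, bias_vars, binding) -> bool:
    # Reject branches whose RHS operand is not a rank-2 tensor.
    return all(rhs.struct_info.ndim == 2 for rhs in rhs_vars)

mod = relax.transform.CombineParallelMatmul(keep_2d_rhs)(mod)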

tvm.relax.transform.ComputePrimValue()[源代码]#

Compute all R.prim_value instances

While high-level relax can include expressions in terms of its symbolic variables, these expressions cannot natively be computed within relax. In order to provide values for symbolic expressions (e.g. R.prim_value(N*N), where N is a symbolic variable), this pass generates a PrimFunc in which the expression can be computed. The relax graph is then updated to include a call to that PrimFunc, in place of the original R.prim_value(expr).

Returns#

ret : tvm.ir.transform.Pass

Return type:

Pass

tvm.relax.transform.ConvertLayout(desired_layouts)[源代码]#

Automatic layout conversion pass.

Parameters#

desired_layouts: Dict[str, List[str]]

A map from the name of the op to the desired layouts of its feature map, weight, and output. For example, to convert the layout of conv2d from NCHW to NHWC, set the desired layout of conv2d to {"relax.nn.conv2d": ["NHWC", "OHWI"]}.

Returns#

ret: tvm.transform.Pass

The registered pass for layout conversion.

Parameters:

desired_layouts (Dict[str, List[str]])

Return type:

Pass
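
A minimal usage sketch matching the example above (assuming `mod` is an existing IRModule containing NCHW conv2d ops):

from tvm import relax

# Convert conv2d from NCHW/OIHW to NHWC feature maps with OHWI weights.
mod = relax.transform.ConvertLayout({"relax.nn.conv2d": ["NHWC", "OHWI"]})(mod)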

tvm.relax.transform.ConvertToDataflow(min_size=2)[源代码]#

A pass that converts consecutive dataflow operations inside binding blocks into dataflow blocks.

Parameters#

min_size: int

The minimum number of consecutive dataflow bindings the pass needs to extract a new block.

Returns#

ret: tvm.ir.transform.Pass

The pass.

Parameters:

min_size (int)

Return type:

Pass

tvm.relax.transform.DataflowUseInplaceCalls()[源代码]#

Pass that changes calls to operators that can be done in-place (generally, these are elementwise operations) into in-place implementations. Supported operators will be replaced by calls to call_tir_inplace that invoke in-place PrimFunc implementations of those operators (which are based on the legalizations of those operators).

Note: ConvertToDataflow may need to be called first to provide dataflow blocks.

Returns#

ret: tvm.ir.transform.Pass

The pass

Return type:

Pass

tvm.relax.transform.DeadCodeElimination(entry_functions=None)[源代码]#

Remove dead code in the IRModule. Currently it removes:

  1. Unused local VarBindings (those where the bound var is unused and the bound expression involves no impure operation).

  2. Unused Relax functions in the module. We detect the call chain from the entry function, and remove all unused functions.

Any binding blocks that are left empty will be removed by the normalizer.

Notes#

For function-wise DCE, use tvm.relax.analysis.remove_all_unused.

Parameters#

entry_functions: Optional[List[str]]

The set of entry functions to start from.

Returns#

ret: tvm.transform.Pass

The registered pass.

Parameters:

entry_functions (List[str] | None)

Return type:

Pass
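
A minimal usage sketch (assuming `mod` is an existing IRModule whose public entry point is "main"):

from tvm import relax

# Remove bindings and functions that are unreachable from "main".
mod = relax.transform.DeadCodeElimination(entry_functions=["main"])(mod)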

tvm.relax.transform.DecomposeOpsForInference(func_name=None)[源代码]#

Decompose composite operators that are composed of other operators, using the decomposition intended for inference. For example, the result of batch norm (a triple) will be simplified. Attention, tensor_to_shape, etc. can also be decomposed into a number of simpler operators.

Parameters#

func_name: Optional[str]

The name of the specified function. If not specified, the pass will run in all functions.

Returns#

ret: tvm.transform.Pass

The registered pass

Parameters:

func_name (str | None)

Return type:

Pass

tvm.relax.transform.DecomposeOpsForTraining(func_name=None)[源代码]#

Decompose composite operators that are composed of other operators, using the decomposition intended for training. For example, the result of batch norm (a triple) will be simplified. Attention, tensor_to_shape, etc. can also be decomposed into a number of simpler operators.

Parameters#

func_name: Optional[str]

The name of the specified function. If not specified, the pass will run in all functions.

Returns#

ret: tvm.transform.Pass

The registered pass

Parameters:

func_name (str | None)

Return type:

Pass

tvm.relax.transform.EliminateCommonSubexpr(call_only=False)[源代码]#

Eliminate common subexpressions within functions.

Note: For nested functions, this pass performs CSE within those functions

Parameters#

call_only: bool

If True, enable eliminating only call nodes.

Returns#

ret: tvm.transform.Pass

The registered pass that eliminates common subexpressions.

Return type:

FunctionPass

tvm.relax.transform.ExpandMatmulOfSum()[源代码]#

Expand matmul(x, A+B) to matmul(x,A) + matmul(x,B)

If either operand can be fully computed at compile-time (only depends on function parameters after kNumInput), this expansion is suppressed.

Useful for optimizing LoRA computations, where matmul(x, Base + LoraA*LoraB) may be expanded to matmul(x, Base) + matmul(x, LoraA*LoraB), allowing it to be optimized with CombineParallelMatmul.

Returns#

ret: tvm.transform.Pass

The corresponding pass.

tvm.relax.transform.ExpandTupleArguments()[源代码]#

Expand tuple arguments to internal functions

Returns#

ret: tvm.ir.transform.Pass

Return type:

Pass

tvm.relax.transform.FewShotTuning(valid_count=1, benchmark=False)[源代码]#

The pass is designed for few shot tuning for static shape PrimFuncs. It examines all the blocks within the PrimFunc and conducts loop fusion, splitting, and other transformations based on MetaSchedule schedule rules but directly samples from the search space instead of using the tuning algorithm. User can specify the number of valid counts to try and whether to use runner for benchmarking.

Parameters#

valid_count: int

The number of valid counts to try.

benchmark: bool

Whether to use runner for benchmarking.

Returns#

ret: tvm.ir.transform.Pass

Parameters:
  • valid_count (int)

  • benchmark (bool)

Return type:

Pass

tvm.relax.transform.FoldConstant()[源代码]#

Fold constant expressions within dataflow blocks.

Note: ConvertToDataflow may need to be called first to provide dataflow blocks.

Returns#

ret: tvm.ir.transform.Pass

Return type:

Pass

tvm.relax.transform.FuseOps(fuse_opt_level=-1)[源代码]#

This pass groups bindings in a dataflow block of Relax functions and generates a new grouped Relax function for each group, according to the fusion algorithm described in the pass implementation. By grouping bindings into new Relax functions, we replace the original bindings in the function being manipulated with calls to the new grouped functions.

A follow-up pass named "FuseTIR" will generate a TIR PrimFunc for each grouped function.

Note: ConvertToDataflow may need to be called first to provide dataflow blocks.

Parameters#

fuse_opt_level: int

The level of fuse optimization. -1 indicates that the level will be inferred from pass context.

Returns#

ret: tvm.transform.Pass

The registered pass for operator fusion.

Return type:

Pass

tvm.relax.transform.FuseOpsByPattern(patterns, bind_constants=True, annotate_codegen=False, entry_functions=None)[源代码]#

Apply pattern matching to each function in the given module, and group matched expressions into a new function.

The end result is similar to FuseOps, but fusion is driven completely by the provided patterns.

Note: Only operates within dataflow blocks. ConvertToDataflow may need to be called first.

Parameters#

patterns: List[Union[FusionPattern, Tuple]]

A list of patterns to be matched. The order of the patterns determines the order of priority in which they are matched. Higher-priority patterns should come earlier in the list.

In addition to a FusionPattern, a tuple can be passed as an item of this list. The pattern will be constructed through FusionPattern(*item).

bind_constants: bool

Whether or not to keep bound constants in the grouped function.

annotate_codegen: bool

If True, wrap each created composite function with another function, whose body consists only of a call to the composite function, and annotate the outer function with "Codegen" and "global_symbol" attributes. The "Codegen" attribute is set as the prefix of the corresponding pattern name. For example, "dnnl" if the pattern name is "dnnl.conv2d_relu".

This must be True if the created composite functions are intended to be offloaded to an external backend without using the MergeCompositeFunctions pass.

entry_functions: Optional[List[str]]

The set of entry functions to start from.

Returns#

ret: tvm.transform.Pass

The registered pass for pattern-based fusion.

Return type:

Pass
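
A minimal usage sketch (the "dnnl" backend prefix is illustrative; it assumes `mod` is an existing IRModule whose dataflow blocks contain conv2d followed by relu):

from tvm import relax
from tvm.relax.dpl import is_op, wildcard

# Offload conv2d followed by relu as a single composite function.
conv = is_op("relax.nn.conv2d")(wildcard(), wildcard())
pattern = ("dnnl.conv2d_relu", is_op("relax.nn.relu")(conv))

mod = relax.transform.FuseOpsByPattern([pattern], annotate_codegen=True)(mod)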

tvm.relax.transform.FuseTIR()[源代码]#

Fuse primitive relax functions into larger TIR functions when possible.

Returns#

ret: tvm.transform.Pass

The registered pass for tir fusion.

Return type:

Pass

tvm.relax.transform.Gradient(func_name, require_grads=None, target_index=0)[源代码]#

Reverse-mode automatic differentiation.

This pass will differentiate one function in the IRModule. Now the input function must have only one dataflow block (ConvertToDataflow may need to be called first).

For a given function specified by func_name, it generates a new function with the name func_name + "_adjoint". The new function computes the gradient of the differentiation target with respect to the arguments specified by require_grads of the original function.

If the function has only one return value, that value is the differentiation target. If the function has more than one return value, the target is the target_index-th return value. The target must be a scalar (0-dim tensor).

The new function will be like:

@R.function
def main_adjoint(original_parameters):
    with R.dataflow():
        # the bindings of the original function
        ...
        # calculating the gradients
        ...
        R.output(original_outputs, grad_1, grad_2, ...)
    return (original_return_value, (grad_1, grad_2, ...))

This AD pass also supports checkpointing as described in "Training deep nets with sublinear memory cost." - Chen, Tianqi, et al. (2016). See tvm.relax.testing.nn.checkpoint for more details.

Parameters#

func_name: str

The name of the specific function.

require_grads: Optional[Union[relax.Var, List[relax.Var]]]

The relax variables whose adjoints are needed. They must be parameters of the given function and must not contain duplicates. If not specified, adjoints of all parameters will be computed.

target_index: int

If the specified function has more than one return value, specify the index of the return value to use as the target. If it is not specified, the first return value will be the target.

Returns#

ret: tvm.ir.transform.Pass

The Pass.

Examples#

The following code shows how to use this pass:

@I.ir_module
class Module:
    @R.function
    def main(
        x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")
    ) -> R.Tensor((), dtype="float32"):
        with R.dataflow():
            lv1: R.Tensor((3, 3), dtype="float32") = R.add(x, y)
            # use R.sum to reduce the tensor to a scalar
            lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, keepdims=False)
            R.output(lv2)
        return lv2

After = relax.transform.Gradient("main")(Module)

The module after the Gradient pass will be:

@I.ir_module
class After:
    @R.function
    def main(
        x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")
    ) -> R.Tensor((), dtype="float32"):
        with R.dataflow():
            lv1: R.Tensor((3, 3), dtype="float32") = R.add(x, y)
            lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, keepdims=False)
            R.output(lv2)
        return lv2

    @R.function
    def main_adjoint(
        x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")
    ) -> R.Tuple(
        R.Tensor((), dtype="float32"),
        R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")),
    ):
        with R.dataflow():
            # original bindings
            lv1: R.Tensor((3, 3), dtype="float32") = R.add(x, y)
            lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, keepdims=False)
            # bindings w.r.t. intermediate variables
            lv2_adjoint: R.Tensor((), dtype="float32") = R.ones((), dtype="float32")
            lv1_adjoint: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(
                lv2_adjoint, (3, 3)
            )
            # bindings w.r.t. parameters
            x_adjoint: R.Tensor((3, 3), dtype="float32") = lv1_adjoint
            y_adjoint: R.Tensor((3, 3), dtype="float32") = lv1_adjoint
            R.output(lv2, x_adjoint, y_adjoint)
        # return value: (orig_return_values, tuple(adjoints))
        return (lv2, (x_adjoint, y_adjoint))

The second example is returning multiple values and specifying the target with target_index:

@I.ir_module
class Module:
    @R.function
    def main(
        x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")
    ) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tensor((), dtype="float32")):
        with R.dataflow():
            lv1: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False)
            lv2: R.Tensor((), dtype="float32") = R.sum(y, axis=None, keepdims=False)
            R.output(lv1, lv2)
        return (lv1, lv2)

After = relax.transform.Gradient("main", target_index=1)(Module)

The module after the Gradient pass will be:

@I.ir_module
class Module:
    @R.function
    def main(
        x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")
    ) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tensor((), dtype="float32")):
        with R.dataflow():
            lv1: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False)
            lv2: R.Tensor((), dtype="float32") = R.sum(y, axis=None, keepdims=False)
            R.output(lv1, lv2)
        return (lv1, lv2)

    @R.function
    def main_adjoint(
        x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")
    ) -> R.Tuple(
        R.Tuple(R.Tensor((), dtype="float32"), R.Tensor((), dtype="float32")),
        R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")),
    ):
        with R.dataflow():
            # original bindings
            lv1: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False)
            lv2: R.Tensor((), dtype="float32") = R.sum(y, axis=None, keepdims=False)
            # bindings w.r.t. intermediate variables
            # gradient of intermediate variables that is not related to the target will not
            # be calculated
            lv2_adjoint: R.Tensor((), dtype="float32") = R.ones((), dtype="float32")
            # bindings w.r.t. parameters
            x_adjoint: R.Tensor((3, 3), dtype="float32") = R.zeros((3, 3), dtype="float32")
            y_adjoint: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(
                lv2_adjoint, (3, 3)
            )
            R.output(lv1, lv2, x_adjoint, y_adjoint)
        # return value: (orig_return_values, tuple(adjoints))
        return ((lv1, lv2), (x_adjoint, y_adjoint))

Return type:

Pass

tvm.relax.transform.InlinePrivateFunctions()[源代码]#

Inline all private relax functions

Returns#

ret: tvm.ir.transform.Pass

Return type:

Pass

tvm.relax.transform.KillAfterLastUse()[源代码]#

Drop all tensor/storage objects after last use

Returns#

ret : tvm.ir.transform.Pass

Return type:

Pass

tvm.relax.transform.LambdaLift()[源代码]#

A pass that lifts local functions into global.

Returns#

ret : tvm.ir.transform.Pass

Return type:

Pass

tvm.relax.transform.LazyGetInput()[源代码]#

A pass that requests inputs lazily.

In many cases, the size of the model weights exceeds the available memory on a GPU. In these cases, a function that accepts all model weights as arguments would not be able to be called. In these cases, parameters must be loaded as they are required by the function, and unloaded once they are no longer needed.

This pass mutates a function such that all model weights (arguments after the first func.attrs["num_input"] arguments) are loaded on demand. Rather than accepting the weights as function arguments, the function accepts a callback argument, which can load each parameter as needed. The callback accepts two arguments, first the index of the model weight, and second the name of the parameter. The callback should return the parameter as specified.

@R.function
def before(A: R.Tensor([16,32],"float32")):
    ...

@R.function
def after(fget_param: R.Callable([R.Prim('int64'), R.Object], R.Object)):
    A_untyped = fget_param(0, R.str('A'))
    A = R.match_cast(A_untyped, R.Tensor([16,32], "float32"))
    ...

Returns#

ret : tvm.ir.transform.Pass

Return type:

Pass

tvm.relax.transform.LazySetOutput()[源代码]#

A pass that sets function outputs when available

In many cases, the size of the model weights exceeds the available memory on a GPU. In these cases, a function that produces all model weights as a single return value would not be able to be called. In these cases, parameters must be returned as they are produced, unloaded from the GPU (or saved to disk), before producing additional outputs.

This pass mutates a function such that all outputs from a function are returned when they are available. The function accepts an additional callback argument, which is called with each output of the function. The callback accepts two arguments, first the index of the output tuple that was produced (or zero if the output is not a tuple), and second the value itself.

@R.function
def before(args):
    ...
    return (A, B)

@R.function
def after(args, fset_param: R.Callable([R.Prim('int64'), R.Object])):
    ...
    fset_param(0, A)
    ...
    fset_param(1, B)
    ...
    return ()

Returns#

ret : tvm.ir.transform.Pass

Return type:

Pass

tvm.relax.transform.LegalizeOps(customize_legalize_map=None, enable_warning=False)[源代码]#

Legalize high-level operator calls in Relax functions to call_tir with corresponding low-level TIR PrimFuncs.

For each high-level operator, we register the way of legalizing it as a function, which takes a context BlockBuilder and the Call being legalized as input, and returns the legalized call. Here the input BlockBuilder is mainly used for adding the PrimFunc created by call_te into the context IRModule.

The legalization function for each operator is registered as an attribute (with attribute key FLegalize) of the operator.

This pass provides customizability for users to use their own legalization function for operators. The pass takes an optional customized map, with the key to be the operator name (str) and value to be the function (LegalizeFunc). The default legalization function will be overridden by the customized one.

Parameters#

customize_legalize_map: Optional[Dict[str, LegalizeFunc]]

The customized operator legalization function map. The customized function will override the default one.

enable_warning: bool

A boolean value indicating whether to print warnings for CallNodes whose op has no registered legalization function. By default we don't print warnings.

Returns#

ret: tvm.transform.Pass

The registered pass

Examples#

The following code shows how to use this pass:

# Define the pass input IRModule
@tvm.script.ir_module
class Module:
    @R.function
    def main(
        x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")
    ) -> R.Tensor((2, 3), "float32"):
        z: R.Tensor((2, 3), "float32") = R.add(x, y)
        r: R.Tensor((2, 3), "float32") = R.multiply(y, z)
        return r

# Define the customized legalization function for "relax.add"
def customize_legalize_add(bb: relax.BlockBuilder, call: relax.Call) -> relax.Expr:
    from tvm import topi
    return bb.call_te(topi.add, call.args[1], call.args[0])

# Apply the pass with the customized function to the module.
mod = LegalizeOps({"relax.add": customize_legalize_add})(Module)

Printing the result with mod.show(), we can see that the IRModule after legalization becomes:

@tvm.script.ir_module
class Module:
    @R.function
    def main(
        x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")
    ) -> R.Tensor((2, 3), "float32"):
        z = R.call_tir(add, (y, x), (2, 3), dtype="float32")
        r = R.call_tir(multiply, (y, z), (2, 3), dtype="float32")
        return r

    @T.prim_func
    def add(
        A: T.Buffer((2, 3), "float32"),
        B: T.Buffer((2, 3), "float32"),
        T_add: T.Buffer((2, 3), "float32"),
    ):
        T.func_attr({"tir.noalias": True})
        for ax0, ax1 in T.grid(2, 3):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[v_ax0, v_ax1]

    @T.prim_func
    def multiply(
        A: T.Buffer((2, 3), "float32"),
        B: T.Buffer((2, 3), "float32"),
        T_multiply: T.Buffer((2, 3), "float32"),
    ):
        T.func_attr({"tir.noalias": True})
        for ax0, ax1 in T.grid(2, 3):
            with T.block("T_multiply"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1])
                T.writes(T_multiply[v_ax0, v_ax1])
                T_multiply[v_ax0, v_ax1] = A[v_ax0, v_ax1] * B[v_ax0, v_ax1]
tvm.relax.transform.LiftTransformParams(shared_transform=False)[源代码]#

Lift transformation of the parameters of a function.

When some inputs of the function are marked as 'parameters' (the model weights), this pass identifies the transformation of the parameters and lifts it to a separate function called transform_params. transform_params takes a tuple of the original parameters as input and returns a tuple of the transformed parameters. The original function will be rewritten to accept a tuple of transformed parameters as input.

Users are expected to invoke the transform_params function in runtime and pass the transformed parameters to the original function as input.

Parameters#

shared_transform: Union[bool, List[str]]

Indicates how the parameter transformation function will be produced

  • False (default): A separate parameter transformation function will be produced for each function with the "num_input" attribute.

  • True: A single parameter transformation function will be produced, containing the preprocessing steps common across all functions with the "num_input" attribute.

  • List[str]: A single parameter transformation function will be produced, containing the preprocessing steps common across each function whose name is in the list. Passing a list of all functions with the "num_input" attribute or an empty list is equivalent to passing True.

Returns#

ret: tvm.transform.Pass

The registered pass for lifting transformation of parameters.

Parameters:

shared_transform (bool | List[str])

Return type:

Pass
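
A minimal usage sketch (assuming `mod` is an existing IRModule whose entry functions carry the "num_input" attribute):

from tvm import relax

# Lift weight preprocessing into a shared `transform_params` function.
mod = relax.transform.LiftTransformParams(shared_transform=True)(mod)
# At run time, `transform_params` is executed once on the original weights and
# its result tuple is passed to the rewritten entry function.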

tvm.relax.transform.LowerAllocTensor()[源代码]#

Lower remaining instances of R.builtin.alloc_tensor

The static memory planner removes static instances of R.builtin.alloc_tensor, replacing with R.memory.alloc_storage and R.memory.alloc_tensor. However, R.builtin.alloc_tensor still remains for any dynamic allocations.

This transform replaces any remaining R.builtin.alloc_tensor instances with R.memory.alloc_storage and R.memory.alloc_tensor. If no R.builtin.alloc_tensor are present, this pass has no effect.

Returns#

ret : tvm.ir.transform.Pass

Return type:

Pass

tvm.relax.transform.LowerRuntimeBuiltin()[源代码]#

Lower generic intrinsics to VM intrinsics.

Returns#

ret: tvm.ir.transform.Pass

Return type:

Pass

tvm.relax.transform.MergeCompositeFunctions()[源代码]#

Group one or multiple composite functions created by FuseOpsByPattern into a new function. The new function will be annotated with "Codegen" and "global_symbol" attributes, and it is intended to be offloaded to an external backend.

Returns#

ret: tvm.transform.Pass

The registered pass for merging composite functions.

Return type:

Pass

tvm.relax.transform.MetaScheduleApplyDatabase(work_dir=None, enable_warning=False)[源代码]#

Apply the best schedule from tuning database.

Parameters#

work_dir: Optional[str]

Work directory used to deduce the default database if one is not provided (ignored when a user passes a database).

enable_warning: bool

A boolean value indicating whether to print warnings for TIR functions not found in the database. By default we don't print warnings.

Returns#

ret: tvm.transform.Pass

The registered pass

Parameters:
  • work_dir (str | None)

  • enable_warning (bool)

Return type:

Pass

tvm.relax.transform.MetaScheduleTuneIRMod(params, work_dir, max_trials_global, max_trials_per_task=None, op_names=None)[源代码]#

Tune Relax IRModule with MetaSchedule.

Parameters#

params: Dict[str, NDArray]

model params

work_dir: str

work directory

max_trials_global: int

maximum number of total trials allowed for tuning

max_trials_per_task: int

maximum number of trials per task

op_names: Optional[List[str]]

A list of operator names to specify which op to tune. When it is None, all operators are tuned.

Returns#

ret: tvm.ir.transform.Pass

Return type:

Pass
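
A minimal tuning sketch (the target, work directory, and trial budget are illustrative; it assumes `mod` is an existing IRModule and that the target is supplied via the surrounding Target context):

import tvm
from tvm import relax

with tvm.target.Target("llvm -num-cores 4"):
    mod = relax.transform.MetaScheduleTuneIRMod(
        params={},                  # no bound model params in this sketch
        work_dir="./tuning_logs",   # illustrative work directory
        max_trials_global=64,
    )(mod)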

tvm.relax.transform.MetaScheduleTuneTIR(work_dir, max_trials_global)[源代码]#

Tune TIR with MetaSchedule.

Parameters#

work_dir: str

work directory

max_trials_global: int

maximum number of total trials allowed for tuning

Returns#

ret: tvm.ir.transform.Pass

Parameters:
  • work_dir (str)

  • max_trials_global (int)

Return type:

Pass

tvm.relax.transform.Normalize()[源代码]#

Transform Relax IR to normal form, i.e., the expressions are normalized (no nesting, hence the AST is in ANF), and all checked_type_ and shape_ of expressions are available.

Returns#

ret: tvm.ir.transform.Pass

Return type:

Pass

tvm.relax.transform.NormalizeGlobalVar()[源代码]#

Possibly rename the GlobalVar in an IRModule to ensure these properties:

  1. (Invariant) Every public function has the same name as its "global_symbol" attribute.

  2. To ensure 1., we may need to rename private functions with conflicting names.

  3. Finally, the name of every GlobalVar is unique in the IRModule.

Returns#

ret: tvm.ir.transform.Pass

Return type:

Pass

tvm.relax.transform.RealizeVDevice()[源代码]#

Propagate virtual device information.

Returns#

ret: tvm.transform.Pass

The registered pass

Return type:

Pass

tvm.relax.transform.RemovePurityChecking()[源代码]#

Activate relax.force_pure on all pure functions in the module and unwrap all pure override ops into the normal versions.

This effectively means that there will be no more purity tracking, useful for low-level code generation.

Returns#

ret: tvm.ir.transform.Pass

The Pass.

Note#

Should be used after ToNonDataflow()

Return type:

Pass

tvm.relax.transform.RemoveUnusedOutputs()[源代码]#

Remove unused outputs from internal functions

Returns#

ret: tvm.ir.transform.Pass

Return type:

Pass

tvm.relax.transform.RemoveUnusedParameters()[源代码]#

Remove unused arguments to internal functions

Returns#

ret: tvm.ir.transform.Pass

Return type:

Pass

tvm.relax.transform.ReorderPermuteDimsAfterConcat()[源代码]#

Reorder concat(permute_dims(A), permute_dims(B)) into permute_dims(concat(A,B))

Useful for optimizing computations after CombineParallelMatmul. The patterns for optimized nn.Linear implementations look for matmul(activations, permute_dims(weights)). After CombineParallelMatmul, the matmul(activations, concat(permute_dims(A), permute_dims(B))) no longer matches this pattern. Rearranging into matmul(activations, permute_dims(concat(A,B))) restores the pattern match.

Returns#

ret: tvm.transform.Pass

The corresponding pass.

tvm.relax.transform.ReorderTakeAfterMatmul()[源代码]#

Reorder matmul(x, take(weights, indices)) to take(matmul(x,weights),indices)

Useful for optimizing LoRA computations, where several LoRAs may be batched together.

Returns#

ret: tvm.transform.Pass

The corresponding pass.

tvm.relax.transform.RewriteCUDAGraph()[源代码]#

Rewrite a Relax module for executing with CUDA graph. This pass identifies the regions that can be executed with CUDA graph and lifts them into new functions for runtime graph capturing.

Returns#

ret: tvm.ir.transform.Pass

The registered pass for rewriting cuda graph

Return type:

Pass

tvm.relax.transform.RewriteDataflowReshape()[源代码]#

Convert all reshape-like call_tir to VM reshape operator call. The VM reshape operator calls will be further lowered to a CreateView operation at runtime, instead of doing real data copy. Here "reshape-like" includes reshape, expand_dims, flatten, etc.

Note: Operates only in dataflow blocks. ConvertToDataflow may need to be called first.

Returns#

ret : tvm.ir.transform.Pass

Return type:

Pass

tvm.relax.transform.RunCodegen(target_options=None, entry_functions=None)[源代码]#

Produce the runtime::Module with an annotated codegen and global symbol.

Parameters#

target_options: Optional[dict]

Pairs of a target name and compilation options

entry_functions: Optional[List[str]]

The set of entry functions to start from.

Returns#

ret: tvm.transform.Pass

The registered pass to remove unused functions.

Parameters:
  • target_options (dict | None)

  • entry_functions (List[str] | None)

Return type:

Pass

tvm.relax.transform.SplitCallTIRByPattern(patterns, fcodegen)[源代码]#

Split a PrimFunc into two parts: the first part is a TIR PrimFunc that is matched with some pattern, and the second part is the rest of the original PrimFunc. It will call fcodegen to generate the code for the matched pattern and replace it with an ExternFunc call.

Parameters#

patterns: List[PrimFunc]

The list of patterns to match.

fcodegen: Callable[[List[MatchResult]], List[Object]]

The function to generate the code for the matched patterns.

Returns#

ret: tvm.transform.Pass

The registered pass for splitting call_tir.

Return type:

Pass

tvm.relax.transform.SplitLayoutRewritePreproc()[源代码]#

Split the TIR layout rewrite into multiple TIR functions. This pass is used in the prepack weight after meta_schedule tuning.

Returns#

ret: tvm.transform.Pass

The registered pass for splitting TIR layout rewrite.

Return type:

Pass

tvm.relax.transform.StaticPlanBlockMemory()[源代码]#

The static memory planning pass on BindingBlock level. The pass will reuse allocated memory to its best effort, in order to reduce the total amount of allocated memory size.

The pass "supports" dynamic shape in the way of TIR variable upper bound annotation. We can optionally annotate the attribute "tir_var_upper_bound" to Relax functions. The attribute value is a dict from strings to integers, denoting the name of TIR variables to the upper bound values of the TIR vars. Note: The annotated upper bound attribute only applies to TIR vars in the function signature for clarity.

For example, we can annotate a Relax function with R.func_attr({"tir_var_upper_bound": {"n": 1024}}). It means the maximum value of variable that names "n" in the function signature will have upper bound 1024. And we will use 1024 as its value during memory planning.

Returns#

ret : tvm.ir.transform.Pass

Return type:

Pass
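
A sketch of the upper-bound annotation described above, written in TVMScript (the shapes and bound value are illustrative; the usual `from tvm.script import ir as I, relax as R` imports are assumed):

@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor(("n", 16), dtype="float32")) -> R.Tensor(("n", 16), dtype="float32"):
        # "n" is a symbolic dim; memory planning will assume it never exceeds 1024.
        R.func_attr({"tir_var_upper_bound": {"n": 1024}})
        y: R.Tensor(("n", 16), dtype="float32") = R.add(x, x)
        return y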

tvm.relax.transform.ToMixedPrecision(out_dtype='float32', fp16_input_names=None)[源代码]#

Automatic mixed precision pass. Currently the pass assumes the input module to be fp32 only, and will automatically cast fp32 to fp16 for certain ops.

Note: Mainly operates within dataflow blocks. ConvertToDataflow may need to be called first.

Parameters#

out_dtype: str

The output data type of gemm/conv, which is the data type of the accumulator.

fp16_input_names: List[str]

The names of function parameters whose dtype should become fp16. The function signature would change accordingly.

Returns#

ret: tvm.transform.Pass

The registered pass for mixed precision.

Parameters:

fp16_input_names (List[str] | None)

Return type:

Pass
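
A minimal usage sketch (assuming `mod` is an existing fp32 IRModule):

from tvm import relax

# Cast eligible fp32 ops to fp16 while keeping fp32 accumulation for gemm/conv.
mod = relax.transform.ToMixedPrecision(out_dtype="float32")(mod)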

tvm.relax.transform.ToNonDataflow()[源代码]#

Transform all dataflow structure to non-dataflow version.

Returns#

ret: tvm.ir.transform.Pass

Return type:

Pass

tvm.relax.transform.TopologicalSort(order='depth-first', direction='from-inputs')[源代码]#

Sort bindings in relax.Dataflow blocks in the order specified

Parameters#

order: str

The order in which bindings should be emitted. Allowed values are "depth-first" and "breadth-first".

direction: str

The direction in which the sort should be performed. Allowed values are "from-inputs" and "from-outputs".

Returns#

ret: tvm.ir.transform.Pass

Return type:

Pass

tvm.relax.transform.UpdateParamStructInfo(sinfo_func)[源代码]#

Update struct info of parameters

Update struct info of parameters. Internal bindings and function return type will be updated using relax's struct inference rules. Errors resulting from struct inference will be propagated to the user.

Parameters#

sinfo_func: Callable[[Var], Optional[StructInfo]]

A function that is called once for each function parameter, and returns the updated struct info to be used for it. If the function returns None, the parameter is not modified.

Returns#

ret: tvm.transform.Pass

The corresponding pass.

Parameters:

sinfo_func (Callable[[Var], StructInfo | None])
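
A sketch of an sinfo_func that converts every tensor parameter to float16 and leaves other parameters untouched (purely illustrative; `mod` is assumed to be an existing IRModule):

from tvm import relax

def to_fp16(param: relax.Var):
    sinfo = param.struct_info
    if isinstance(sinfo, relax.TensorStructInfo):
        return relax.TensorStructInfo(sinfo.shape, "float16")
    return None  # None means: keep this parameter's struct info unchanged

mod = relax.transform.UpdateParamStructInfo(to_fp16)(mod)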

tvm.relax.transform.UpdateVDevice(new_vdevice, index)[源代码]#

Update virtual device.

Parameters#

new_vdevice: tvm.ir.VDevice

The new virtual device.

index: int

The device index indicates the device on which the update will be performed.

Returns#

ret: tvm.ir.transform.Pass

The registered pass that modifies the virtual device.

Parameters:
  • new_vdevice (VDevice)

  • index (int)

Return type:

Pass

tvm.relax.transform.VMBuiltinLower()[源代码]#

Lower generic intrinsics to VM intrinsics.

Returns#

ret: tvm.ir.transform.Pass

Return type:

Pass

tvm.relax.transform.VMShapeLower(*, emit_err_ctx=True)[源代码]#

Lower symbolic shape computation, and argument and match-cast StructInfo matching.

Parameters#

emit_err_ctx: Optional[bool]

Whether to emit the error context string; it can be turned off for testing purposes.

Returns#

ret: tvm.ir.transform.Pass

Parameters:

emit_err_ctx (bool)

Return type:

Pass

tvm.relax.transform.dataflowblock_pass(pass_func=None, opt_level=None, name=None, required=None, traceable=False)[源代码]#

Decorate a dataflowblock pass.

This function returns a callback when pass_func is provided. Otherwise, it returns the created dataflowblock pass using the given optimization function.

Parameters#

pass_func: Optional[Callable[(DataflowBlock, Module, PassContext) -> DataflowBlock]]

The transformation function or class.

opt_level: int

The optimization level of this dataflowblock pass.

name: Optional[str]

The name of the dataflowblock pass. The name could be empty. In this case, the name of the optimization function will be used as the pass name.

required: Optional[List[str]]

The list of passes that the dataflowblock pass is dependent on.

traceable: Boolean

Boolean variable whether the dataflowblock pass is traceable

Returns#

create_dataflowblock_pass : Union[Callable, DataflowBlockPass]

A decorator will be returned if pass_func is not provided, otherwise return the decorated result. The returned decorator has two behaviors depending on the input: A new DataflowBlockPass will be returned when we decorate a pass function. A new DataflowBlockPass class will be returned when we decorate a class type.

Examples#

The following code block decorates a dataflowblock pass class.

@relax.transform.dataflowblock_pass(opt_level=1)
class TestReplaceBinding:
    # Simple test function to replace the first VarBinding to another.

    def __init__(self):
        # create a new VarBinding
        m, n = tir.Var("m", "int64"), tir.Var("n", "int64")
        lv0 = relax.Var("lv1", relax.TensorStructInfo([m, n], "float32"))
        val = relax.const(np.random.rand(24, 56))
        self.new_binding = relax.VarBinding(lv0, val)

    def transform_dataflowblock(self, block, mod, ctx):
        # just for demo purposes
        # Replace the first binding in the DataflowBlock
        new_bindings = [self.new_binding, block.bindings[1]]
        new_block = relax.expr.DataflowBlock(new_bindings, block.span)
        return new_block

@tvm.script.ir_module
class InputMod:
    @R.function
    def f1(x: Tensor[(m, n), "float32"]):
        with relax.dataflow():
            lv0 = relax.multiply(x, x)
            gv0 = relax.add(x, x)
            relax.output(gv0)
        return gv0
# block_pass is now a special pass that replaces every
# first binding to the constant value binding
block_pass = TestReplaceBinding()
# now every first binding in DataflowBlock of InputMod
# is replaced by new_binding
updated_mod = block_pass(InputMod)

The following code creates a dataflowblock pass by decorating a user defined transform function.

@relax.transform.dataflowblock_pass(opt_level=2)
def transform(block, mod, ctx):
    # my transformations here.
    return block

block_pass = transform
assert isinstance(block_pass, relax.transform.DataflowBlockPass)
assert block_pass.info.opt_level == 2

# Given a module m, the optimization could be invoked as follows:
updated_mod = block_pass(m)
# Now transform should have been applied to every DataflowBlock in
# the provided module m. And the updated module will be returned.
Return type:

Callable | DataflowBlockPass

tvm.relax.transform.function_pass(pass_func=None, opt_level=None, name=None, required=None, traceable=False)[源代码]#

Decorate a function pass.

This function returns a callback when pass_func is provided. Otherwise, it returns the created function pass using the given optimization function.

Parameters#

pass_func: Optional[Callable[(Function, Module, PassContext) -> Function]]

The transformation function or class.

opt_level: int

The optimization level of this function pass.

name: Optional[str]

The name of the function pass. The name could be empty. In this case, the name of the optimization function will be used as the pass name.

required: Optional[List[str]]

The list of passes that the function pass is dependent on.

traceable: Boolean

Boolean variable whether the function pass is traceable

Returns#

create_function_pass : Union[Callable, FunctionPass]

A decorator will be returned if pass_func is not provided, otherwise return the decorated result. The returned decorator has two behaviors depending on the input: A new FunctionPass will be returned when we decorate a pass function. A new FunctionPass class will be returned when we decorate a class type.

Examples#

The following code block decorates a function pass class.

@relax.transform.function_pass(opt_level=1)
class TestReplaceFunc:
    def __init__(self, new_func):
        self.new_func = new_func

    def transform_function(self, func, mod, ctx):
        # just for demo purposes
        # transform func to new_func
        return self.new_func

@R.function
def f1(x: Tensor[(m, n), "float32"]):
    return x

@tvm.script.ir_module
class InputMod:
    @R.function
    def f2(x: Tensor[(m, n), "float32"]):
        gv0 = relax.add(x, x)
        return gv0
# fpass is now a special pass that replaces every
# function to f1
fpass = TestReplaceFunc(f1)
# now every function in InputMod is replaced by f1
updated_mod = fpass(InputMod)

The following code creates a function pass by decorating a user defined transform function.

@relax.transform.function_pass(opt_level=2)
def transform(func, mod, ctx):
    # my transformations here.
    return func

function_pass = transform
assert isinstance(function_pass, relax.transform.FunctionPass)
assert function_pass.info.opt_level == 2

# Given a module m, the optimization could be invoked as follows:
updated_mod = function_pass(m)
# Now transform should have been applied to every function in
# the provided module m. And the updated module will be returned.
Return type:

Callable | FunctionPass