%%shell
# Installs the latest dev build of TVM from PyPI. If you wish to build
# from source, see https://tvm.apache.org/docs/install/from_source.html
pip install apache-tvm --pre
Transformation#
In this section, we will dive into the transformation of Relax programs. Transformation is one of the key ingredients of the compilation flow for optimizing and integrating with hardware backends.
Let's first create a simple Relax program, as we did in the previous section.
import tvm
from tvm import IRModule, relax
from tvm.relax.frontend import nn
class NNModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x

origin_mod, params = NNModule().export_tvm(
    {"forward": {"x": nn.spec.Tensor(("n", 784), "float32")}}
)
origin_mod.show()
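The exported parameters are returned separately from the module. As a quick check we can list them; this is a minimal sketch that assumes params is a list of (name, parameter) pairs whose entries expose shape and dtype attributes.
# Minimal sketch: list the exported parameters.
# Assumes `params` is a list of (name, parameter) pairs with shape/dtype attributes.
for name, param in params:
    print(name, param.shape, param.dtype)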
Apply transformations#
Passes are the main way to apply transformations to a program. As a first step, let's apply the built-in pass LegalizeOps, which lowers the high-level operators into low-level operators.
mod = tvm.relax.transform.LegalizeOps()(origin_mod)
mod.show()
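If you want to verify the result programmatically rather than by reading the printed IR, one option is to check that the legalized module is still well formed. This is a minimal sketch that assumes tvm.relax.analysis.well_formed is available and accepts an IRModule.
# Minimal sketch: verify the legalized module is still well formed.
# Assumes tvm.relax.analysis.well_formed(mod) returns a bool.
assert tvm.relax.analysis.well_formed(mod)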
As we can see from the output, the high-level operators (i.e., relax.op) in the program are replaced by their corresponding low-level operators (i.e., relax.call_tir).
Next, let's apply operator fusion, a widely used optimization technique in ML compilers. Note that in Relax, fusion is done through the collaboration of a set of passes, which we can apply in sequence.
mod = tvm.ir.transform.Sequential(
    [
        tvm.relax.transform.AnnotateTIROpPattern(),
        tvm.relax.transform.FuseOps(),
        tvm.relax.transform.FuseTIR(),
    ]
)(mod)
mod.show()
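The same pipeline can also be run under an explicit PassContext, which uses the optimization level to gate which passes are enabled. The following is a minimal sketch using the standard tvm.transform.PassContext interface; it reproduces the result above starting again from the legalized module.
# Minimal sketch: run the fusion pipeline under an explicit PassContext.
# opt_level gates which passes in the pipeline are enabled.
seq = tvm.ir.transform.Sequential(
    [
        tvm.relax.transform.AnnotateTIROpPattern(),
        tvm.relax.transform.FuseOps(),
        tvm.relax.transform.FuseTIR(),
    ]
)
with tvm.transform.PassContext(opt_level=3):
    mod = seq(tvm.relax.transform.LegalizeOps()(origin_mod))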
As a result, we can see that the matmul, add, and relu operators are fused into one kernel (i.e., one call_tir).
For all built-in passes, please refer to relax.transform.
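For example, assuming tvm.relax.transform.DeadCodeElimination is among the built-in passes available in your TVM build, dropping unused bindings takes a single call; this is a sketch, not a required step in the flow.
# Minimal sketch: apply another built-in pass to drop unused bindings.
# Assumes tvm.relax.transform.DeadCodeElimination is available in this TVM build.
mod = tvm.relax.transform.DeadCodeElimination()(mod)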
Custom Passes#
We can also define our own passes. Let's take the example of rewriting the relu operator into the gelu operator.
First, we need to write a Relax IR Mutator to do the rewriting.
from tvm.relax.expr_functor import PyExprMutator, mutator


@mutator
class ReluRewriter(PyExprMutator):
    def __init__(self, mod):
        super().__init__(mod)

    def visit_call_(self, call: relax.Call) -> relax.Expr:
        # visit the relax.Call expr, and only handle the case when op is relax.nn.relu
        if call.op.name == "relax.nn.relu":
            return relax.op.nn.gelu(call.args[0])

        return super().visit_call_(call)
Then we can write a pass to apply the mutator to the whole module.
@tvm.transform.module_pass(opt_level=0, name="ReluToGelu")
class ReluToGelu:  # pylint: disable=too-few-public-methods
    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        """IRModule-level transformation"""
        rewriter = ReluRewriter(mod)
        for g_var, func in mod.functions_items():
            if isinstance(func, relax.Function):
                func = rewriter.visit_expr(func)
                rewriter.builder_.update_func(g_var, func)
        return rewriter.builder_.get()
mod = ReluToGelu()(origin_mod)
mod.show()
The printed output shows that the relax.nn.relu operator is rewritten to the relax.nn.gelu operator.
For details of the mutator, please refer to relax.expr_functor.PyExprMutator.
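Since ReluToGelu is an ordinary module pass, it composes with the built-in passes shown earlier. The following is a minimal sketch that chains it with the legalization and fusion pipeline from above.
# Minimal sketch: compose the custom pass with built-in passes in one pipeline.
pipeline = tvm.ir.transform.Sequential(
    [
        ReluToGelu(),
        tvm.relax.transform.LegalizeOps(),
        tvm.relax.transform.AnnotateTIROpPattern(),
        tvm.relax.transform.FuseOps(),
        tvm.relax.transform.FuseTIR(),
    ]
)
mod = pipeline(origin_mod)
mod.show()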
Summary#
In this section, we have shown how to apply transformations to the Relax program. We have also shown how to define and apply custom transformations.