Transformation#
In this section, we dive into the transformation of Relax programs. Transformations are a key component of the compilation flow, used both for optimization and for integration with hardware backends.
First, let's create a simple Relax program, as we did in the previous section.
import tvm
from tvm import IRModule, relax
from tvm.relax.frontend import nn
class NNModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x

origin_mod, params = NNModule().export_tvm(
    {"forward": {"x": nn.spec.Tensor(("n", 784), "float32")}}
)
origin_mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
    @R.function
    def forward(x: R.Tensor(("n", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
        n = T.int64()
        R.func_attr({"num_input": 1})
        with R.dataflow():
            permute_dims: R.Tensor((784, 128), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
            matmul: R.Tensor((n, 128), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
            add: R.Tensor((n, 128), dtype="float32") = R.add(matmul, fc1_bias)
            relu: R.Tensor((n, 128), dtype="float32") = R.nn.relu(add)
            permute_dims1: R.Tensor((128, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            matmul1: R.Tensor((n, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype="void")
            add1: R.Tensor((n, 10), dtype="float32") = R.add(matmul1, fc2_bias)
            gv: R.Tensor((n, 10), dtype="float32") = add1
            R.output(gv)
        return gv
Apply Transformations#
Passes are the main way to apply transformations to a program. As a first step, let's apply the built-in LegalizeOps pass, which lowers the high-level operators into low-level operators.
mod = tvm.relax.transform.LegalizeOps()(origin_mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(var_matmul: T.handle, fc1_bias: T.Buffer((T.int64(128),), "float32"), var_T_add: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        matmul = T.match_buffer(var_matmul, (n, T.int64(128)))
        T_add = T.match_buffer(var_T_add, (n, T.int64(128)))
        # with T.block("root"):
        for ax0, ax1 in T.grid(n, T.int64(128)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(matmul[v_ax0, v_ax1], fc1_bias[v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = matmul[v_ax0, v_ax1] + fc1_bias[v_ax1]

    @T.prim_func(private=True)
    def add1(var_matmul1: T.handle, fc2_bias: T.Buffer((T.int64(10),), "float32"), var_T_add: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        matmul1 = T.match_buffer(var_matmul1, (n, T.int64(10)))
        T_add = T.match_buffer(var_T_add, (n, T.int64(10)))
        # with T.block("root"):
        for ax0, ax1 in T.grid(n, T.int64(10)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(matmul1[v_ax0, v_ax1], fc2_bias[v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = matmul1[v_ax0, v_ax1] + fc2_bias[v_ax1]

    @T.prim_func(private=True)
    def matmul(var_x: T.handle, permute_dims: T.Buffer((T.int64(784), T.int64(128)), "float32"), var_matmul: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        x = T.match_buffer(var_x, (n, T.int64(784)))
        matmul = T.match_buffer(var_matmul, (n, T.int64(128)))
        # with T.block("root"):
        for i0, i1, k in T.grid(n, T.int64(128), T.int64(784)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(x[v_i0, v_k], permute_dims[v_k, v_i1])
                T.writes(matmul[v_i0, v_i1])
                with T.init():
                    matmul[v_i0, v_i1] = T.float32(0.0)
                matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + x[v_i0, v_k] * permute_dims[v_k, v_i1]

    @T.prim_func(private=True)
    def matmul1(var_relu: T.handle, permute_dims1: T.Buffer((T.int64(128), T.int64(10)), "float32"), var_matmul: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        relu = T.match_buffer(var_relu, (n, T.int64(128)))
        matmul = T.match_buffer(var_matmul, (n, T.int64(10)))
        # with T.block("root"):
        for i0, i1, k in T.grid(n, T.int64(10), T.int64(128)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(relu[v_i0, v_k], permute_dims1[v_k, v_i1])
                T.writes(matmul[v_i0, v_i1])
                with T.init():
                    matmul[v_i0, v_i1] = T.float32(0.0)
                matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + relu[v_i0, v_k] * permute_dims1[v_k, v_i1]

    @T.prim_func(private=True)
    def relu(var_add: T.handle, var_compute: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        add = T.match_buffer(var_add, (n, T.int64(128)))
        compute = T.match_buffer(var_compute, (n, T.int64(128)))
        # with T.block("root"):
        for i0, i1 in T.grid(n, T.int64(128)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(add[v_i0, v_i1])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.max(add[v_i0, v_i1], T.float32(0.0))

    @T.prim_func(private=True)
    def transpose(fc1_weight: T.Buffer((T.int64(128), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(128)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(784), T.int64(128)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(fc1_weight[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = fc1_weight[v_ax1, v_ax0]

    @T.prim_func(private=True)
    def transpose1(fc2_weight: T.Buffer((T.int64(10), T.int64(128)), "float32"), T_transpose: T.Buffer((T.int64(128), T.int64(10)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(128), T.int64(10)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(fc2_weight[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = fc2_weight[v_ax1, v_ax0]

    @R.function
    def forward(x: R.Tensor(("n", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
        n = T.int64()
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            permute_dims = R.call_tir(cls.transpose, (fc1_weight,), out_sinfo=R.Tensor((784, 128), dtype="float32"))
            matmul = R.call_tir(cls.matmul, (x, permute_dims), out_sinfo=R.Tensor((n, 128), dtype="float32"))
            add = R.call_tir(cls.add, (matmul, fc1_bias), out_sinfo=R.Tensor((n, 128), dtype="float32"))
            relu = R.call_tir(cls.relu, (add,), out_sinfo=R.Tensor((n, 128), dtype="float32"))
            permute_dims1 = R.call_tir(cls.transpose1, (fc2_weight,), out_sinfo=R.Tensor((128, 10), dtype="float32"))
            matmul1 = R.call_tir(cls.matmul1, (relu, permute_dims1), out_sinfo=R.Tensor((n, 10), dtype="float32"))
            add1 = R.call_tir(cls.add1, (matmul1, fc2_bias), out_sinfo=R.Tensor((n, 10), dtype="float32"))
            gv: R.Tensor((n, 10), dtype="float32") = add1
            R.output(gv)
        return gv
From the output we can see that the high-level operators (i.e. relax.op) in the program have been replaced by their corresponding low-level operator calls (i.e. relax.call_tir).
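A pass is simply a callable object that takes an IRModule and returns a new IRModule, so origin_mod itself is left untouched. A minimal sketch of how this could be checked (the printed names are only illustrative):

# The original module still holds only the high-level Relax function...
print([gv.name_hint for gv, _ in origin_mod.functions_items()])
# ...while the legalized module also contains the generated TIR PrimFuncs.
print(sorted(gv.name_hint for gv, _ in mod.functions_items()))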
Next, let's apply operator fusion, an optimization technique widely used in machine learning compilers. Note that in Relax, fusion is accomplished through the collaboration of a series of passes, which we can apply in sequence.
mod = tvm.ir.transform.Sequential(
    [
        tvm.relax.transform.AnnotateTIROpPattern(),
        tvm.relax.transform.FuseOps(),
        tvm.relax.transform.FuseTIR(),
    ]
)(mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
    @T.prim_func(private=True)
    def fused_matmul1_add1(p_relu: T.handle, permute_dims1: T.Buffer((T.int64(128), T.int64(10)), "float32"), fc2_bias: T.Buffer((T.int64(10),), "float32"), p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        relu = T.match_buffer(p_relu, (n, T.int64(128)))
        T_add_intermediate = T.match_buffer(p_output0, (n, T.int64(10)))
        # with T.block("root"):
        matmul_intermediate = T.alloc_buffer((n, T.int64(10)))
        for i0, i1, k in T.grid(n, T.int64(10), T.int64(128)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(relu[v_i0, v_k], permute_dims1[v_k, v_i1])
                T.writes(matmul_intermediate[v_i0, v_i1])
                with T.init():
                    matmul_intermediate[v_i0, v_i1] = T.float32(0.0)
                matmul_intermediate[v_i0, v_i1] = matmul_intermediate[v_i0, v_i1] + relu[v_i0, v_k] * permute_dims1[v_k, v_i1]
        for ax0, ax1 in T.grid(n, T.int64(10)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(matmul_intermediate[v_ax0, v_ax1], fc2_bias[v_ax1])
                T.writes(T_add_intermediate[v_ax0, v_ax1])
                T_add_intermediate[v_ax0, v_ax1] = matmul_intermediate[v_ax0, v_ax1] + fc2_bias[v_ax1]

    @T.prim_func(private=True)
    def fused_matmul_add_relu(p_x: T.handle, permute_dims: T.Buffer((T.int64(784), T.int64(128)), "float32"), fc1_bias: T.Buffer((T.int64(128),), "float32"), p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        x = T.match_buffer(p_x, (n, T.int64(784)))
        compute_intermediate = T.match_buffer(p_output0, (n, T.int64(128)))
        # with T.block("root"):
        matmul_intermediate = T.alloc_buffer((n, T.int64(128)))
        T_add_intermediate = T.alloc_buffer((n, T.int64(128)))
        for i0, i1, k in T.grid(n, T.int64(128), T.int64(784)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(x[v_i0, v_k], permute_dims[v_k, v_i1])
                T.writes(matmul_intermediate[v_i0, v_i1])
                with T.init():
                    matmul_intermediate[v_i0, v_i1] = T.float32(0.0)
                matmul_intermediate[v_i0, v_i1] = matmul_intermediate[v_i0, v_i1] + x[v_i0, v_k] * permute_dims[v_k, v_i1]
        for ax0, ax1 in T.grid(n, T.int64(128)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(matmul_intermediate[v_ax0, v_ax1], fc1_bias[v_ax1])
                T.writes(T_add_intermediate[v_ax0, v_ax1])
                T_add_intermediate[v_ax0, v_ax1] = matmul_intermediate[v_ax0, v_ax1] + fc1_bias[v_ax1]
        for i0, i1 in T.grid(n, T.int64(128)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(T_add_intermediate[v_i0, v_i1])
                T.writes(compute_intermediate[v_i0, v_i1])
                compute_intermediate[v_i0, v_i1] = T.max(T_add_intermediate[v_i0, v_i1], T.float32(0.0))

    @T.prim_func(private=True)
    def transpose(fc1_weight: T.Buffer((T.int64(128), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(128)), "float32")):
        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(784), T.int64(128)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(fc1_weight[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = fc1_weight[v_ax1, v_ax0]

    @T.prim_func(private=True)
    def transpose1(fc2_weight: T.Buffer((T.int64(10), T.int64(128)), "float32"), T_transpose: T.Buffer((T.int64(128), T.int64(10)), "float32")):
        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(128), T.int64(10)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(fc2_weight[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = fc2_weight[v_ax1, v_ax0]

    @R.function
    def forward(x: R.Tensor(("n", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
        n = T.int64()
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            permute_dims = R.call_tir(cls.transpose, (fc1_weight,), out_sinfo=R.Tensor((784, 128), dtype="float32"))
            lv = R.call_tir(cls.fused_matmul_add_relu, (x, permute_dims, fc1_bias), out_sinfo=R.Tensor((n, 128), dtype="float32"))
            permute_dims1 = R.call_tir(cls.transpose1, (fc2_weight,), out_sinfo=R.Tensor((128, 10), dtype="float32"))
            gv = R.call_tir(cls.fused_matmul1_add1, (lv, permute_dims1, fc2_bias), out_sinfo=R.Tensor((n, 10), dtype="float32"))
            R.output(gv)
        return gv
The result shows that the matmul, add, and relu operators have been fused into a single kernel (i.e. a single call_tir).
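The AnnotateTIROpPattern pass is what supplies FuseOps with the information it needs: it attaches an op_pattern attribute to each TIR function (for example, the transpose kernels above carry "op_pattern": 2), and FuseOps consults these annotations when deciding which kernels may be grouped together. A small sketch, purely for inspection, of how these attributes could be listed:

# List the attributes attached to each TIR PrimFunc in the module.
for g_var, func in mod.functions_items():
    if isinstance(func, tvm.tir.PrimFunc):
        print(g_var.name_hint, func.attrs)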
For details on all the built-in passes, please refer to tvm.relax.transform.
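To double-check that the fused module is still a complete, runnable program, one option is to compile it and execute it with the Relax virtual machine. The snippet below is only a rough sketch: it assumes a TVM build with the llvm backend and feeds random weights of the right shapes in place of trained parameters.

import numpy as np

# Illustrative only: compile for CPU and run with random input and weights.
ex = relax.build(mod, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())

x = tvm.nd.array(np.random.rand(3, 784).astype("float32"))
weights = [
    tvm.nd.array(np.random.rand(*shape).astype("float32"))
    for shape in [(128, 784), (128,), (10, 128), (10,)]
]
out = vm["forward"](x, *weights)
print(out.numpy().shape)  # expected: (3, 10)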
Custom Passes#
We can also define our own passes. Below is an example that rewrites the relu operator into a gelu operator.
First, we need to write a Relax IR mutator to perform the rewriting.
from tvm.relax.expr_functor import PyExprMutator, mutator
@mutator
class ReluRewriter(PyExprMutator):
    def __init__(self, mod):
        super().__init__(mod)

    def visit_call_(self, call: relax.Call) -> relax.Expr:
        # visit the relax.Call expr, and only handle the case when op is relax.nn.relu
        if call.op.name == "relax.nn.relu":
            return relax.op.nn.gelu(call.args[0])

        return super().visit_call_(call)
Then, we can write a pass that applies the mutator to the whole module.
@tvm.transform.module_pass(opt_level=0, name="ReluToGelu")
class ReluToGelu: # pylint: disable=too-few-public-methods
    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        """IRModule-level transformation"""
        rewriter = ReluRewriter(mod)
        for g_var, func in mod.functions_items():
            if isinstance(func, relax.Function):
                func = rewriter.visit_expr(func)
                rewriter.builder_.update_func(g_var, func)
        return rewriter.builder_.get()
mod = ReluToGelu()(origin_mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
    @R.function
    def forward(x: R.Tensor(("n", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
        n = T.int64()
        R.func_attr({"num_input": 1})
        with R.dataflow():
            permute_dims: R.Tensor((784, 128), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
            matmul: R.Tensor((n, 128), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
            add: R.Tensor((n, 128), dtype="float32") = R.add(matmul, fc1_bias)
            relu: R.Tensor((n, 128), dtype="float32") = R.nn.gelu(add)
            permute_dims1: R.Tensor((128, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            matmul1: R.Tensor((n, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype="void")
            add1: R.Tensor((n, 10), dtype="float32") = R.add(matmul1, fc2_bias)
            gv: R.Tensor((n, 10), dtype="float32") = add1
            R.output(gv)
        return gv
The printed output shows that the relax.nn.relu operator has been rewritten into the relax.nn.gelu operator.
For details on the mutator, please refer to PyExprMutator.
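As an aside, for a local rewrite like relu to gelu, a similar effect can often be achieved with Relax's dataflow pattern language instead of a full mutator. The sketch below is an illustration under assumptions (the rewrite_call helper and its callback signature may vary between TVM versions), not the approach used above:

from tvm.relax.dpl import is_op, rewrite_call, wildcard

# Match any call to relax.nn.relu and capture its input.
inp = wildcard()
relu_pattern = is_op("relax.nn.relu")(inp)

def relu_to_gelu(_expr, matches):
    # Rebuild each matched call as a gelu over the captured input.
    return relax.op.nn.gelu(matches[inp])

rewritten = rewrite_call(relu_pattern, relu_to_gelu, origin_mod["forward"])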
Summary#
In this section, we have shown how to apply transformations to a Relax program, and how to define and apply our own custom transformations.