Relax Creation

This tutorial demonstrates how to create Relax functions and programs.

It covers several approaches to defining Relax functions, including TVMScript and the Relax NNModule API.

Creating Relax Programs Using TVMScript

TVMScript is a domain-specific language for representing the intermediate representation (IR) of Apache TVM. It is a Python dialect that can be used to define an IRModule containing both TensorIR and Relax functions.

In this section, we show how to define a simple multilayer perceptron (MLP) model using only high-level Relax operators via TVMScript.

from tvm import relax, topi
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T


@I.ir_module
class RelaxModule:
    @R.function
    def forward(
        data: R.Tensor(("n", 784), dtype="float32"),
        w0: R.Tensor((128, 784), dtype="float32"),
        b0: R.Tensor((128,), dtype="float32"),
        w1: R.Tensor((10, 128), dtype="float32"),
        b1: R.Tensor((10,), dtype="float32"),
    ) -> R.Tensor(("n", 10), dtype="float32"):
        with R.dataflow():
            lv0 = R.matmul(data, R.permute_dims(w0)) + b0
            lv1 = R.nn.relu(lv0)
            lv2 = R.matmul(lv1, R.permute_dims(w1)) + b1
            R.output(lv2)
        return lv2


RelaxModule.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def forward(data: R.Tensor(("n", 784), dtype="float32"), w0: R.Tensor((128, 784), dtype="float32"), b0: R.Tensor((128,), dtype="float32"), w1: R.Tensor((10, 128), dtype="float32"), b1: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
        n = T.int64()
        with R.dataflow():
            lv: R.Tensor((784, 128), dtype="float32") = R.permute_dims(w0, axes=None)
            lv1: R.Tensor((n, 128), dtype="float32") = R.matmul(data, lv, out_dtype="void")
            lv0: R.Tensor((n, 128), dtype="float32") = R.add(lv1, b0)
            lv1_1: R.Tensor((n, 128), dtype="float32") = R.nn.relu(lv0)
            lv4: R.Tensor((128, 10), dtype="float32") = R.permute_dims(w1, axes=None)
            lv5: R.Tensor((n, 10), dtype="float32") = R.matmul(lv1_1, lv4, out_dtype="void")
            lv2: R.Tensor((n, 10), dtype="float32") = R.add(lv5, b1)
            R.output(lv2)
        return lv2

Relax is not only a graph-level IR, but it also supports cross-level representations and transformations. Specifically, we can call TensorIR functions directly from a Relax function.

@I.ir_module
class RelaxModuleWithTIR:
    @T.prim_func
    def relu(x: T.handle, y: T.handle):
        n, m = T.int64(), T.int64()
        X = T.match_buffer(x, (n, m), "float32")
        Y = T.match_buffer(y, (n, m), "float32")
        for i, j in T.grid(n, m):
            with T.block("relu"):
                vi, vj = T.axis.remap("SS", [i, j])
                Y[vi, vj] = T.max(X[vi, vj], T.float32(0))

    @R.function
    def forward(
        data: R.Tensor(("n", 784), dtype="float32"),
        w0: R.Tensor((128, 784), dtype="float32"),
        b0: R.Tensor((128,), dtype="float32"),
        w1: R.Tensor((10, 128), dtype="float32"),
        b1: R.Tensor((10,), dtype="float32"),
    ) -> R.Tensor(("n", 10), dtype="float32"):
        n = T.int64()
        cls = RelaxModuleWithTIR
        with R.dataflow():
            lv0 = R.matmul(data, R.permute_dims(w0)) + b0
            lv1 = R.call_tir(cls.relu, lv0, R.Tensor((n, 128), dtype="float32"))
            lv2 = R.matmul(lv1, R.permute_dims(w1)) + b1
            R.output(lv2)
        return lv2


RelaxModuleWithTIR.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func
    def relu(x: T.handle, y: T.handle):
        n, m = T.int64(), T.int64()
        X = T.match_buffer(x, (n, m))
        Y = T.match_buffer(y, (n, m))
        # with T.block("root"):
        for i, j in T.grid(n, m):
            with T.block("relu"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(X[vi, vj])
                T.writes(Y[vi, vj])
                Y[vi, vj] = T.max(X[vi, vj], T.float32(0.0))

    @R.function
    def forward(data: R.Tensor(("n", 784), dtype="float32"), w0: R.Tensor((128, 784), dtype="float32"), b0: R.Tensor((128,), dtype="float32"), w1: R.Tensor((10, 128), dtype="float32"), b1: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
        n = T.int64()
        cls = Module
        with R.dataflow():
            lv: R.Tensor((784, 128), dtype="float32") = R.permute_dims(w0, axes=None)
            lv1: R.Tensor((n, 128), dtype="float32") = R.matmul(data, lv, out_dtype="void")
            lv0: R.Tensor((n, 128), dtype="float32") = R.add(lv1, b0)
            lv1_1 = R.call_tir(cls.relu, (lv0,), out_sinfo=R.Tensor((n, 128), dtype="float32"))
            lv4: R.Tensor((128, 10), dtype="float32") = R.permute_dims(w1, axes=None)
            lv5: R.Tensor((n, 10), dtype="float32") = R.matmul(lv1_1, lv4, out_dtype="void")
            lv2: R.Tensor((n, 10), dtype="float32") = R.add(lv5, b1)
            R.output(lv2)
        return lv2
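
Here R.call_tir follows the destination-passing style: an output tensor matching the given struct info (R.Tensor((n, 128), "float32")) is allocated, and the PrimFunc writes its result into it. As a rough, framework-free analogy (a NumPy sketch for intuition only, not TVM lowering code):

import numpy as np

# Rough analogy of the destination-passing convention behind R.call_tir:
# the callee fills a pre-allocated output buffer instead of returning a new
# tensor. This is a sketch for intuition, not TVM code.
def relu_dps(x, out):
    np.maximum(x, 0.0, out=out)

x = np.random.randn(4, 128).astype("float32")
out = np.empty_like(x)  # corresponds to out_sinfo=R.Tensor((n, 128), dtype="float32")
relu_dps(x, out)        # corresponds to calling cls.relu through R.call_tir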

Note

You may notice that the printed output differs from the TVMScript code we wrote. This is because the IRModule is always printed in a standard, normalized format, while syntactic sugar is accepted on input.

For instance, we can combine multiple operations into a single line, as follows:

lv0 = R.matmul(data, R.permute_dims(w0)) + b0

However, the normalized form allows only one operation per binding. Therefore, the printed output differs from the TVMScript code we wrote, as follows:

lv: R.Tensor((784, 128), dtype="float32") = R.permute_dims(w0, axes=None)
lv1: R.Tensor((n, 128), dtype="float32") = R.matmul(data, lv, out_dtype="void")
lv0: R.Tensor((n, 128), dtype="float32") = R.add(lv1, b0)

Creating Relax Programs Using the NNModule API

Besides TVMScript, we also provide a PyTorch-like API for defining neural networks. It is designed to be more intuitive and easier to use.

In this section, we show how to define the same MLP model using the Relax NNModule API.

from tvm.relax.frontend import nn


class NNModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x

After defining the NNModule, we can export it to a TVM IRModule via export_tvm.

mod, params = NNModule().export_tvm({"forward": {"x": nn.spec.Tensor(("n", 784), "float32")}})
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def forward(x: R.Tensor(("n", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
        n = T.int64()
        R.func_attr({"num_input": 1})
        with R.dataflow():
            permute_dims: R.Tensor((784, 128), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
            matmul: R.Tensor((n, 128), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
            add: R.Tensor((n, 128), dtype="float32") = R.add(matmul, fc1_bias)
            relu: R.Tensor((n, 128), dtype="float32") = R.nn.relu(add)
            permute_dims1: R.Tensor((128, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            matmul1: R.Tensor((n, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype="void")
            add1: R.Tensor((n, 10), dtype="float32") = R.add(matmul1, fc2_bias)
            gv: R.Tensor((n, 10), dtype="float32") = add1
            R.output(gv)
        return gv
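
Besides the IRModule, export_tvm also returns the model parameters, which correspond to the extra arguments of forward (fc1_weight, fc1_bias, and so on). A minimal sketch for inspecting them, assuming the parameters come back as (name, parameter) pairs that expose .shape and .dtype:

# Minimal sketch: print the spec of each exported parameter.
# Assumes `params` is a list of (name, parameter) pairs whose entries
# expose `.shape` and `.dtype`, matching the extra arguments of `forward`.
for name, param in params:
    print(name, param.shape, param.dtype)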

We can also insert customized function calls into an NNModule, such as Tensor Expression (TE) functions, TensorIR functions, or other TVM packed functions.

@T.prim_func
def tir_linear(x: T.handle, w: T.handle, b: T.handle, z: T.handle):
    M, N, K = T.int64(), T.int64(), T.int64()
    X = T.match_buffer(x, (M, K), "float32")
    W = T.match_buffer(w, (N, K), "float32")
    B = T.match_buffer(b, (N,), "float32")
    Z = T.match_buffer(z, (M, N), "float32")
    for i, j, k in T.grid(M, N, K):
        with T.block("linear"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                Z[vi, vj] = 0
            Z[vi, vj] = Z[vi, vj] + X[vi, vk] * W[vj, vk]
    for i, j in T.grid(M, N):
        with T.block("add"):
            vi, vj = T.axis.remap("SS", [i, j])
            Z[vi, vj] = Z[vi, vj] + B[vj]


class NNModuleWithTIR(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        n = x.shape[0]
        # We can call external functions using nn.extern
        x = nn.extern(
            "env.linear",
            [x, self.fc1.weight, self.fc1.bias],
            out=nn.Tensor.placeholder((n, 128), "float32"),
        )
        # We can also call TensorIR via Tensor Expression API in TOPI
        x = nn.tensor_expr_op(topi.nn.relu, "relu", [x])
        # We can also call TensorIR functions directly
        x = nn.tensor_ir_op(
            tir_linear,
            "tir_linear",
            [x, self.fc2.weight, self.fc2.bias],
            out=nn.Tensor.placeholder((n, 10), "float32"),
        )
        return x


mod, params = NNModuleWithTIR().export_tvm(
    {"forward": {"x": nn.spec.Tensor(("n", 784), "float32")}}
)
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def relu(var_env_linear: T.handle, var_compute: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        env_linear = T.match_buffer(var_env_linear, (n, T.int64(128)))
        compute = T.match_buffer(var_compute, (n, T.int64(128)))
        # with T.block("root"):
        for i0, i1 in T.grid(n, T.int64(128)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(env_linear[v_i0, v_i1])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.max(env_linear[v_i0, v_i1], T.float32(0.0))

    @T.prim_func
    def tir_linear(x: T.handle, w: T.handle, b: T.handle, z: T.handle):
        M, K = T.int64(), T.int64()
        X = T.match_buffer(x, (M, K))
        N = T.int64()
        W = T.match_buffer(w, (N, K))
        B = T.match_buffer(b, (N,))
        Z = T.match_buffer(z, (M, N))
        # with T.block("root"):
        for i, j, k in T.grid(M, N, K):
            with T.block("linear"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                T.reads(X[vi, vk], W[vj, vk])
                T.writes(Z[vi, vj])
                with T.init():
                    Z[vi, vj] = T.float32(0.0)
                Z[vi, vj] = Z[vi, vj] + X[vi, vk] * W[vj, vk]
        for i, j in T.grid(M, N):
            with T.block("add"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(Z[vi, vj], B[vj])
                T.writes(Z[vi, vj])
                Z[vi, vj] = Z[vi, vj] + B[vj]

    @R.function
    def forward(x: R.Tensor(("n", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
        n = T.int64()
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            env_linear = R.call_dps_packed("env.linear", (x, fc1_weight, fc1_bias), out_sinfo=R.Tensor((n, 128), dtype="float32"))
            lv = R.call_tir(cls.relu, (env_linear,), out_sinfo=R.Tensor((n, 128), dtype="float32"))
            lv1 = R.call_tir(cls.tir_linear, (lv, fc2_weight, fc2_bias), out_sinfo=R.Tensor((n, 10), dtype="float32"))
            gv: R.Tensor((n, 10), dtype="float32") = lv1
            R.output(gv)
        return gv
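
In the exported module, nn.extern("env.linear", ...) becomes R.call_dps_packed("env.linear", ...), so a packed function registered under that global name must be available before the module can run. Below is a hypothetical sketch of how such a function could be supplied; the NumPy body is an illustrative assumption, not part of the tutorial. In destination-passing style, the packed function receives the inputs followed by the pre-allocated output buffer.

import numpy as np
import tvm

# Hypothetical sketch: supply the external packed function expected by
# R.call_dps_packed("env.linear", ...). The final argument is the
# pre-allocated output buffer that the callee writes into.
@tvm.register_func("env.linear")
def env_linear(x, w, b, out):
    # x: (n, 784), w: (128, 784), b: (128,), out: (n, 128)
    res = x.numpy() @ w.numpy().T + b.numpy()
    out.copyfrom(res.astype("float32"))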

Creating Relax Programs Using the Block Builder API

In addition to the APIs above, we also provide a Block Builder API for creating Relax programs. It is a lower-level IR builder API that is widely used in TVM's internals, for example when writing customized passes.

bb = relax.BlockBuilder()
n = T.int64()
x = relax.Var("x", R.Tensor((n, 784), "float32"))
fc1_weight = relax.Var("fc1_weight", R.Tensor((128, 784), "float32"))
fc1_bias = relax.Var("fc1_bias", R.Tensor((128,), "float32"))
fc2_weight = relax.Var("fc2_weight", R.Tensor((10, 128), "float32"))
fc2_bias = relax.Var("fc2_bias", R.Tensor((10,), "float32"))
with bb.function("forward", [x, fc1_weight, fc1_bias, fc2_weight, fc2_bias]):
    with bb.dataflow():
        lv0 = bb.emit(relax.op.matmul(x, relax.op.permute_dims(fc1_weight)) + fc1_bias)
        lv1 = bb.emit(relax.op.nn.relu(lv0))
        gv = bb.emit(relax.op.matmul(lv1, relax.op.permute_dims(fc2_weight)) + fc2_bias)
        bb.emit_output(gv)
    bb.emit_func_output(gv)

mod = bb.get()
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def forward(x: R.Tensor(("v", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("v", 10), dtype="float32"):
        v = T.int64()
        with R.dataflow():
            lv: R.Tensor((784, 128), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
            lv1: R.Tensor((v, 128), dtype="float32") = R.matmul(x, lv, out_dtype="void")
            lv2: R.Tensor((v, 128), dtype="float32") = R.add(lv1, fc1_bias)
            lv3: R.Tensor((v, 128), dtype="float32") = R.nn.relu(lv2)
            lv4: R.Tensor((128, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            lv5: R.Tensor((v, 10), dtype="float32") = R.matmul(lv3, lv4, out_dtype="void")
            lv6: R.Tensor((v, 10), dtype="float32") = R.add(lv5, fc2_bias)
            gv: R.Tensor((v, 10), dtype="float32") = lv6
            R.output(gv)
        return lv6

Moreover, the Block Builder API supports building cross-level IRModules that contain Relax functions, TensorIR functions, and other TVM packed functions.

bb = relax.BlockBuilder()
with bb.function("forward", [x, fc1_weight, fc1_bias, fc2_weight, fc2_bias]):
    with bb.dataflow():
        lv0 = bb.emit(
            relax.call_dps_packed(
                "env.linear",
                [x, fc1_weight, fc1_bias],
                out_sinfo=relax.TensorStructInfo((n, 128), "float32"),
            )
        )
        lv1 = bb.emit_te(topi.nn.relu, lv0)
        tir_gv = bb.add_func(tir_linear, "tir_linear")
        gv = bb.emit(
            relax.call_tir(
                tir_gv,
                [lv1, fc2_weight, fc2_bias],
                out_sinfo=relax.TensorStructInfo((n, 10), "float32"),
            )
        )
        bb.emit_output(gv)
    bb.emit_func_output(gv)
mod = bb.get()
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def relu(var_lv: T.handle, var_compute: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        v = T.int64()
        lv = T.match_buffer(var_lv, (v, T.int64(128)))
        compute = T.match_buffer(var_compute, (v, T.int64(128)))
        # with T.block("root"):
        for i0, i1 in T.grid(v, T.int64(128)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(lv[v_i0, v_i1])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.max(lv[v_i0, v_i1], T.float32(0.0))

    @T.prim_func
    def tir_linear(x: T.handle, w: T.handle, b: T.handle, z: T.handle):
        M, K = T.int64(), T.int64()
        X = T.match_buffer(x, (M, K))
        N = T.int64()
        W = T.match_buffer(w, (N, K))
        B = T.match_buffer(b, (N,))
        Z = T.match_buffer(z, (M, N))
        # with T.block("root"):
        for i, j, k in T.grid(M, N, K):
            with T.block("linear"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                T.reads(X[vi, vk], W[vj, vk])
                T.writes(Z[vi, vj])
                with T.init():
                    Z[vi, vj] = T.float32(0.0)
                Z[vi, vj] = Z[vi, vj] + X[vi, vk] * W[vj, vk]
        for i, j in T.grid(M, N):
            with T.block("add"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(Z[vi, vj], B[vj])
                T.writes(Z[vi, vj])
                Z[vi, vj] = Z[vi, vj] + B[vj]

    @R.function
    def forward(x: R.Tensor(("v", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("v", 10), dtype="float32"):
        v = T.int64()
        cls = Module
        with R.dataflow():
            lv = R.call_dps_packed("env.linear", (x, fc1_weight, fc1_bias), out_sinfo=R.Tensor((v, 128), dtype="float32"))
            lv1 = R.call_tir(cls.relu, (lv,), out_sinfo=R.Tensor((v, 128), dtype="float32"))
            lv2 = R.call_tir(cls.tir_linear, (lv1, fc2_weight, fc2_bias), out_sinfo=R.Tensor((v, 10), dtype="float32"))
            gv: R.Tensor((v, 10), dtype="float32") = lv2
            R.output(gv)
        return lv2

Note that the Block Builder API is not as user-friendly as the APIs above, but it is the lowest-level API and works closely with the IR definition. We recommend the APIs above for users who only want to define and transform a machine-learning model; for users who want to build more complex transformations, the Block Builder API is the more flexible choice.
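
For instance, because the builder is driven by ordinary Python control flow, an entire module can be constructed programmatically. Below is a minimal sketch, where the helper build_mlp and its layer-size argument are our own illustration rather than part of the tutorial, that builds an MLP of arbitrary depth with the Block Builder API:

import tvm
from tvm import relax
from tvm.script import relax as R


def build_mlp(layer_sizes):
    """Sketch: emit a matmul + add (+ relu) chain for each layer."""
    bb = relax.BlockBuilder()
    n = tvm.tir.Var("n", "int64")  # symbolic batch dimension
    x = relax.Var("x", R.Tensor((n, layer_sizes[0]), "float32"))
    params = []
    for i, (in_dim, out_dim) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
        params.append(relax.Var(f"w{i}", R.Tensor((out_dim, in_dim), "float32")))
        params.append(relax.Var(f"b{i}", R.Tensor((out_dim,), "float32")))
    with bb.function("forward", [x] + params):
        with bb.dataflow():
            out = x
            for i in range(len(layer_sizes) - 1):
                w, b = params[2 * i], params[2 * i + 1]
                out = bb.emit(relax.op.matmul(out, relax.op.permute_dims(w)) + b)
                if i < len(layer_sizes) - 2:  # no ReLU after the last layer
                    out = bb.emit(relax.op.nn.relu(out))
            gv = bb.emit_output(out)
        bb.emit_func_output(gv)
    return bb.get()


build_mlp([784, 256, 128, 10]).show()

Each layer becomes one matmul and one add binding, with ReLU in between, mirroring the fixed two-layer example above.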

Summary

This tutorial has demonstrated how to create Relax programs for different use cases using TVMScript, the NNModule API, the Block Builder API, and the PackedFunc API.