nn.module

Example code on how to create, compile, and run a neural network with a PyTorch-like API.

import tvm
from tvm import relax, tir
from tvm.relax.testing import nn
from tvm.script import relax as R
import numpy as np

builder = relax.BlockBuilder()
n = tir.Var("n", "int64")  # a symbolic variable to represent the minibatch size
input_size = 784
hidden_sizes = [128, 32]
output_size = 10

Build a three-layer linear neural network for a classification task.

with builder.function("main"):
    model = nn.Sequential(
        nn.Linear(input_size, hidden_sizes[0]),
        nn.ReLU(),
        nn.Linear(hidden_sizes[0], hidden_sizes[1]),
        nn.ReLU(),
        nn.Linear(hidden_sizes[1], output_size),
        nn.LogSoftmax(),
    )
    data = nn.Placeholder((n, input_size), name="data")
    output = model(data)
    params = [data] + model.parameters()
    builder.emit_func_output(output, params=params)
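
For comparison, here is a rough PyTorch equivalent of the same model. This is a sketch only: it assumes the torch package is installed and is not needed to run the rest of this example.

# A rough PyTorch equivalent of the model above (illustrative only).
import torch.nn as torch_nn

torch_model = torch_nn.Sequential(
    torch_nn.Linear(input_size, hidden_sizes[0]),
    torch_nn.ReLU(),
    torch_nn.Linear(hidden_sizes[0], hidden_sizes[1]),
    torch_nn.ReLU(),
    torch_nn.Linear(hidden_sizes[1], output_size),
    torch_nn.LogSoftmax(dim=1),
)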

Get and print the IRModule being built.

mod = builder.get()
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R


@I.ir_module
class Module:
    @T.prim_func
    def add(
        var_A: T.handle, B: T.Buffer((T.int64(128),), "float32"), var_T_add: T.handle
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (n, T.int64(128)))
        T_add = T.match_buffer(var_T_add, (n, T.int64(128)))
        # with T.block("root"):
        for ax0, ax1 in T.grid(n, T.int64(128)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(A[v_ax0, v_ax1], B[v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[v_ax1]

    @T.prim_func
    def add1(
        var_A: T.handle, B: T.Buffer((T.int64(32),), "float32"), var_T_add: T.handle
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (n, T.int64(32)))
        T_add = T.match_buffer(var_T_add, (n, T.int64(32)))
        # with T.block("root"):
        for ax0, ax1 in T.grid(n, T.int64(32)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(A[v_ax0, v_ax1], B[v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[v_ax1]

    @T.prim_func
    def add2(
        var_A: T.handle, B: T.Buffer((T.int64(10),), "float32"), var_T_add: T.handle
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (n, T.int64(10)))
        T_add = T.match_buffer(var_T_add, (n, T.int64(10)))
        # with T.block("root"):
        for ax0, ax1 in T.grid(n, T.int64(10)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(A[v_ax0, v_ax1], B[v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[v_ax1]

    @T.prim_func
    def log_softmax(var_A: T.handle, var_compute: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (n, T.int64(10)))
        compute = T.match_buffer(var_compute, (n, T.int64(10)))
        # with T.block("root"):
        T_softmax_maxelem = T.alloc_buffer((n,))
        compute_1 = T.alloc_buffer((n,))
        for i0, k in T.grid(n, T.int64(10)):
            with T.block("T_softmax_maxelem"):
                v_i0, v_k = T.axis.remap("SR", [i0, k])
                T.reads(A[v_i0, v_k])
                T.writes(T_softmax_maxelem[v_i0])
                with T.init():
                    T_softmax_maxelem[v_i0] = T.float32(-3.4028234663852886e38)
                T_softmax_maxelem[v_i0] = T.max(T_softmax_maxelem[v_i0], A[v_i0, v_k])
        for i0, k in T.grid(n, T.int64(10)):
            with T.block("compute"):
                v_i0, v_k = T.axis.remap("SR", [i0, k])
                T.reads(A[v_i0, v_k], T_softmax_maxelem[v_i0])
                T.writes(compute_1[v_i0])
                with T.init():
                    compute_1[v_i0] = T.float32(0)
                compute_1[v_i0] = compute_1[v_i0] + T.exp(
                    A[v_i0, v_k] - T_softmax_maxelem[v_i0]
                )
        for i0, i1 in T.grid(n, T.int64(10)):
            with T.block("compute_1"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(A[v_i0, v_i1], T_softmax_maxelem[v_i0], compute_1[v_i0])
                T.writes(compute[v_i0, v_i1])
                T.block_attr({"axis": 1})
                compute[v_i0, v_i1] = (
                    A[v_i0, v_i1] - T_softmax_maxelem[v_i0] - T.log(compute_1[v_i0])
                )

    @T.prim_func
    def matmul(
        var_A: T.handle,
        B: T.Buffer((T.int64(784), T.int64(128)), "float32"),
        var_T_matmul: T.handle,
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (n, T.int64(784)))
        T_matmul = T.match_buffer(var_T_matmul, (n, T.int64(128)))
        # with T.block("root"):
        for ax0, ax1, k in T.grid(n, T.int64(128), T.int64(784)):
            with T.block("T_matmul"):
                v_ax0, v_ax1, v_k = T.axis.remap("SSR", [ax0, ax1, k])
                T.reads(A[v_ax0, v_k], B[v_k, v_ax1])
                T.writes(T_matmul[v_ax0, v_ax1])
                with T.init():
                    T_matmul[v_ax0, v_ax1] = T.float32(0)
                T_matmul[v_ax0, v_ax1] = (
                    T_matmul[v_ax0, v_ax1] + A[v_ax0, v_k] * B[v_k, v_ax1]
                )

    @T.prim_func
    def matmul1(
        var_A: T.handle,
        B: T.Buffer((T.int64(128), T.int64(32)), "float32"),
        var_T_matmul: T.handle,
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (n, T.int64(128)))
        T_matmul = T.match_buffer(var_T_matmul, (n, T.int64(32)))
        # with T.block("root"):
        for ax0, ax1, k in T.grid(n, T.int64(32), T.int64(128)):
            with T.block("T_matmul"):
                v_ax0, v_ax1, v_k = T.axis.remap("SSR", [ax0, ax1, k])
                T.reads(A[v_ax0, v_k], B[v_k, v_ax1])
                T.writes(T_matmul[v_ax0, v_ax1])
                with T.init():
                    T_matmul[v_ax0, v_ax1] = T.float32(0)
                T_matmul[v_ax0, v_ax1] = (
                    T_matmul[v_ax0, v_ax1] + A[v_ax0, v_k] * B[v_k, v_ax1]
                )

    @T.prim_func
    def matmul2(
        var_A: T.handle,
        B: T.Buffer((T.int64(32), T.int64(10)), "float32"),
        var_T_matmul: T.handle,
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (n, T.int64(32)))
        T_matmul = T.match_buffer(var_T_matmul, (n, T.int64(10)))
        # with T.block("root"):
        for ax0, ax1, k in T.grid(n, T.int64(10), T.int64(32)):
            with T.block("T_matmul"):
                v_ax0, v_ax1, v_k = T.axis.remap("SSR", [ax0, ax1, k])
                T.reads(A[v_ax0, v_k], B[v_k, v_ax1])
                T.writes(T_matmul[v_ax0, v_ax1])
                with T.init():
                    T_matmul[v_ax0, v_ax1] = T.float32(0)
                T_matmul[v_ax0, v_ax1] = (
                    T_matmul[v_ax0, v_ax1] + A[v_ax0, v_k] * B[v_k, v_ax1]
                )

    @T.prim_func
    def relu(var_A: T.handle, var_compute: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (n, T.int64(128)))
        compute = T.match_buffer(var_compute, (n, T.int64(128)))
        # with T.block("root"):
        for i0, i1 in T.grid(n, T.int64(128)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(A[v_i0, v_i1])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.max(A[v_i0, v_i1], T.float32(0))

    @T.prim_func
    def relu1(var_A: T.handle, var_compute: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (n, T.int64(32)))
        compute = T.match_buffer(var_compute, (n, T.int64(32)))
        # with T.block("root"):
        for i0, i1 in T.grid(n, T.int64(32)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(A[v_i0, v_i1])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.max(A[v_i0, v_i1], T.float32(0))

    @R.function
    def main(
        data: R.Tensor(("n", 784), dtype="float32"),
        linear_weight: R.Tensor((784, 128), dtype="float32"),
        linear_bias: R.Tensor((128,), dtype="float32"),
        linear_weight1: R.Tensor((128, 32), dtype="float32"),
        linear_bias1: R.Tensor((32,), dtype="float32"),
        linear_weight2: R.Tensor((32, 10), dtype="float32"),
        linear_bias2: R.Tensor((10,), dtype="float32"),
    ) -> R.Tensor(dtype="float32", ndim=2):
        n = T.int64()
        cls = Module
        gv = R.call_tir(
            cls.matmul,
            (data, linear_weight),
            out_sinfo=R.Tensor((n, 128), dtype="float32"),
        )
        gv1 = R.call_tir(
            cls.add, (gv, linear_bias), out_sinfo=R.Tensor((n, 128), dtype="float32")
        )
        gv2 = R.call_tir(
            cls.relu, (gv1,), out_sinfo=R.Tensor((n, 128), dtype="float32")
        )
        gv3 = R.call_tir(
            cls.matmul1,
            (gv2, linear_weight1),
            out_sinfo=R.Tensor((n, 32), dtype="float32"),
        )
        gv4 = R.call_tir(
            cls.add1, (gv3, linear_bias1), out_sinfo=R.Tensor((n, 32), dtype="float32")
        )
        gv5 = R.call_tir(
            cls.relu1, (gv4,), out_sinfo=R.Tensor((n, 32), dtype="float32")
        )
        gv6 = R.call_tir(
            cls.matmul2,
            (gv5, linear_weight2),
            out_sinfo=R.Tensor((n, 10), dtype="float32"),
        )
        gv7 = R.call_tir(
            cls.add2, (gv6, linear_bias2), out_sinfo=R.Tensor((n, 10), dtype="float32")
        )
        gv8 = R.call_tir(
            cls.log_softmax, (gv7,), out_sinfo=R.Tensor((n, 10), dtype="float32")
        )
        return gv8

Build the IRModule and create the Relax VM.

target = tvm.target.Target("llvm", host="llvm")
ex = relax.build(mod, target)
vm = relax.VirtualMachine(ex, tvm.cpu())
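
The compiled executable can also be saved to disk and reloaded later. Below is a minimal sketch, assuming a writable path (the name "model.so" is an arbitrary choice) and that Executable.export_library and tvm.runtime.load_module behave as in recent TVM releases.

# Export the compiled executable as a shared library, then reload it
# into a fresh VM instance.
ex.export_library("model.so")
loaded = tvm.runtime.load_module("model.so")
vm_reloaded = relax.VirtualMachine(loaded, tvm.cpu())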

Run the model on the Relax VM with input data of minibatch size 3.

params = nn.init_params(mod)  # initialize the parameters
data = tvm.nd.array(np.random.rand(3, input_size).astype(np.float32))
res = vm["main"](data, *params)
print(res)
[[-2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851
  -2.3025851 -2.3025851 -2.3025851 -2.3025851]
 [-2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851
  -2.3025851 -2.3025851 -2.3025851 -2.3025851]
 [-2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851
  -2.3025851 -2.3025851 -2.3025851 -2.3025851]]
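
Every entry of the result is log(1/10) ≈ -2.3026, the log of a uniform distribution over the 10 classes: when the initialized parameters make all class logits equal (for instance, zero-initialized weights and biases), log-softmax returns exactly that value for every class. A quick sanity check, using only the names defined above:

# Sanity check: exp(res) should form a probability distribution per row,
# and each entry should match log(1 / output_size).
probs = np.exp(res.numpy())
np.testing.assert_allclose(probs.sum(axis=1), np.ones(3), rtol=1e-5)
print(np.log(1.0 / output_size))  # -2.3025851..., matching every entry above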