nn.module
Example code for creating, compiling, and running a neural network using a PyTorch-like API.
import tvm
from tvm import relax, tir
from tvm.relax.testing import nn
from tvm.script import relax as R
import numpy as np
builder = relax.BlockBuilder()
n = tir.Var("n", "int64")  # symbolic variable representing the minibatch size
input_size = 784
hidden_sizes = [128, 32]
output_size = 10
Build a three-layer linear neural network for a classification task.
with builder.function("main"):
    model = nn.Sequential(
        nn.Linear(input_size, hidden_sizes[0]),
        nn.ReLU(),
        nn.Linear(hidden_sizes[0], hidden_sizes[1]),
        nn.ReLU(),
        nn.Linear(hidden_sizes[1], output_size),
        nn.LogSoftmax(),
    )
    data = nn.Placeholder((n, input_size), name="data")
    output = model(data)
    params = [data] + model.parameters()
    builder.emit_func_output(output, params=params)
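For comparison, here is the same architecture written with PyTorch's torch.nn. This is only an illustrative sketch (it assumes PyTorch is installed, and torch_model is a name introduced here); it is not used by the rest of the example, but it shows how closely the Relax nn API above mirrors the PyTorch style.
import torch  # assumption: PyTorch is available; used only for this comparison

torch_model = torch.nn.Sequential(
    torch.nn.Linear(input_size, hidden_sizes[0]),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden_sizes[0], hidden_sizes[1]),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden_sizes[1], output_size),
    torch.nn.LogSoftmax(dim=1),
)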
Get and print the IRModule being built.
mod = builder.get()
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
    @T.prim_func
    def add(
        var_A: T.handle, B: T.Buffer((T.int64(128),), "float32"), var_T_add: T.handle
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (n, T.int64(128)))
        T_add = T.match_buffer(var_T_add, (n, T.int64(128)))
        # with T.block("root"):
        for ax0, ax1 in T.grid(n, T.int64(128)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(A[v_ax0, v_ax1], B[v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[v_ax1]

    @T.prim_func
    def add1(
        var_A: T.handle, B: T.Buffer((T.int64(32),), "float32"), var_T_add: T.handle
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (n, T.int64(32)))
        T_add = T.match_buffer(var_T_add, (n, T.int64(32)))
        # with T.block("root"):
        for ax0, ax1 in T.grid(n, T.int64(32)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(A[v_ax0, v_ax1], B[v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[v_ax1]

    @T.prim_func
    def add2(
        var_A: T.handle, B: T.Buffer((T.int64(10),), "float32"), var_T_add: T.handle
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (n, T.int64(10)))
        T_add = T.match_buffer(var_T_add, (n, T.int64(10)))
        # with T.block("root"):
        for ax0, ax1 in T.grid(n, T.int64(10)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(A[v_ax0, v_ax1], B[v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[v_ax1]

    @T.prim_func
    def log_softmax(var_A: T.handle, var_compute: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (n, T.int64(10)))
        compute = T.match_buffer(var_compute, (n, T.int64(10)))
        # with T.block("root"):
        T_softmax_maxelem = T.alloc_buffer((n,))
        compute_1 = T.alloc_buffer((n,))
        for i0, k in T.grid(n, T.int64(10)):
            with T.block("T_softmax_maxelem"):
                v_i0, v_k = T.axis.remap("SR", [i0, k])
                T.reads(A[v_i0, v_k])
                T.writes(T_softmax_maxelem[v_i0])
                with T.init():
                    T_softmax_maxelem[v_i0] = T.float32(-3.4028234663852886e38)
                T_softmax_maxelem[v_i0] = T.max(T_softmax_maxelem[v_i0], A[v_i0, v_k])
        for i0, k in T.grid(n, T.int64(10)):
            with T.block("compute"):
                v_i0, v_k = T.axis.remap("SR", [i0, k])
                T.reads(A[v_i0, v_k], T_softmax_maxelem[v_i0])
                T.writes(compute_1[v_i0])
                with T.init():
                    compute_1[v_i0] = T.float32(0)
                compute_1[v_i0] = compute_1[v_i0] + T.exp(
                    A[v_i0, v_k] - T_softmax_maxelem[v_i0]
                )
        for i0, i1 in T.grid(n, T.int64(10)):
            with T.block("compute_1"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(A[v_i0, v_i1], T_softmax_maxelem[v_i0], compute_1[v_i0])
                T.writes(compute[v_i0, v_i1])
                T.block_attr({"axis": 1})
                compute[v_i0, v_i1] = (
                    A[v_i0, v_i1] - T_softmax_maxelem[v_i0] - T.log(compute_1[v_i0])
                )

    @T.prim_func
    def matmul(
        var_A: T.handle,
        B: T.Buffer((T.int64(784), T.int64(128)), "float32"),
        var_T_matmul: T.handle,
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (n, T.int64(784)))
        T_matmul = T.match_buffer(var_T_matmul, (n, T.int64(128)))
        # with T.block("root"):
        for ax0, ax1, k in T.grid(n, T.int64(128), T.int64(784)):
            with T.block("T_matmul"):
                v_ax0, v_ax1, v_k = T.axis.remap("SSR", [ax0, ax1, k])
                T.reads(A[v_ax0, v_k], B[v_k, v_ax1])
                T.writes(T_matmul[v_ax0, v_ax1])
                with T.init():
                    T_matmul[v_ax0, v_ax1] = T.float32(0)
                T_matmul[v_ax0, v_ax1] = (
                    T_matmul[v_ax0, v_ax1] + A[v_ax0, v_k] * B[v_k, v_ax1]
                )

    @T.prim_func
    def matmul1(
        var_A: T.handle,
        B: T.Buffer((T.int64(128), T.int64(32)), "float32"),
        var_T_matmul: T.handle,
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (n, T.int64(128)))
        T_matmul = T.match_buffer(var_T_matmul, (n, T.int64(32)))
        # with T.block("root"):
        for ax0, ax1, k in T.grid(n, T.int64(32), T.int64(128)):
            with T.block("T_matmul"):
                v_ax0, v_ax1, v_k = T.axis.remap("SSR", [ax0, ax1, k])
                T.reads(A[v_ax0, v_k], B[v_k, v_ax1])
                T.writes(T_matmul[v_ax0, v_ax1])
                with T.init():
                    T_matmul[v_ax0, v_ax1] = T.float32(0)
                T_matmul[v_ax0, v_ax1] = (
                    T_matmul[v_ax0, v_ax1] + A[v_ax0, v_k] * B[v_k, v_ax1]
                )

    @T.prim_func
    def matmul2(
        var_A: T.handle,
        B: T.Buffer((T.int64(32), T.int64(10)), "float32"),
        var_T_matmul: T.handle,
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (n, T.int64(32)))
        T_matmul = T.match_buffer(var_T_matmul, (n, T.int64(10)))
        # with T.block("root"):
        for ax0, ax1, k in T.grid(n, T.int64(10), T.int64(32)):
            with T.block("T_matmul"):
                v_ax0, v_ax1, v_k = T.axis.remap("SSR", [ax0, ax1, k])
                T.reads(A[v_ax0, v_k], B[v_k, v_ax1])
                T.writes(T_matmul[v_ax0, v_ax1])
                with T.init():
                    T_matmul[v_ax0, v_ax1] = T.float32(0)
                T_matmul[v_ax0, v_ax1] = (
                    T_matmul[v_ax0, v_ax1] + A[v_ax0, v_k] * B[v_k, v_ax1]
                )

    @T.prim_func
    def relu(var_A: T.handle, var_compute: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (n, T.int64(128)))
        compute = T.match_buffer(var_compute, (n, T.int64(128)))
        # with T.block("root"):
        for i0, i1 in T.grid(n, T.int64(128)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(A[v_i0, v_i1])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.max(A[v_i0, v_i1], T.float32(0))

    @T.prim_func
    def relu1(var_A: T.handle, var_compute: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (n, T.int64(32)))
        compute = T.match_buffer(var_compute, (n, T.int64(32)))
        # with T.block("root"):
        for i0, i1 in T.grid(n, T.int64(32)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(A[v_i0, v_i1])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.max(A[v_i0, v_i1], T.float32(0))

    @R.function
    def main(
        data: R.Tensor(("n", 784), dtype="float32"),
        linear_weight: R.Tensor((784, 128), dtype="float32"),
        linear_bias: R.Tensor((128,), dtype="float32"),
        linear_weight1: R.Tensor((128, 32), dtype="float32"),
        linear_bias1: R.Tensor((32,), dtype="float32"),
        linear_weight2: R.Tensor((32, 10), dtype="float32"),
        linear_bias2: R.Tensor((10,), dtype="float32"),
    ) -> R.Tensor(dtype="float32", ndim=2):
        n = T.int64()
        cls = Module
        gv = R.call_tir(
            cls.matmul,
            (data, linear_weight),
            out_sinfo=R.Tensor((n, 128), dtype="float32"),
        )
        gv1 = R.call_tir(
            cls.add, (gv, linear_bias), out_sinfo=R.Tensor((n, 128), dtype="float32")
        )
        gv2 = R.call_tir(
            cls.relu, (gv1,), out_sinfo=R.Tensor((n, 128), dtype="float32")
        )
        gv3 = R.call_tir(
            cls.matmul1,
            (gv2, linear_weight1),
            out_sinfo=R.Tensor((n, 32), dtype="float32"),
        )
        gv4 = R.call_tir(
            cls.add1, (gv3, linear_bias1), out_sinfo=R.Tensor((n, 32), dtype="float32")
        )
        gv5 = R.call_tir(
            cls.relu1, (gv4,), out_sinfo=R.Tensor((n, 32), dtype="float32")
        )
        gv6 = R.call_tir(
            cls.matmul2,
            (gv5, linear_weight2),
            out_sinfo=R.Tensor((n, 10), dtype="float32"),
        )
        gv7 = R.call_tir(
            cls.add2, (gv6, linear_bias2), out_sinfo=R.Tensor((n, 10), dtype="float32")
        )
        gv8 = R.call_tir(
            cls.log_softmax, (gv7,), out_sinfo=R.Tensor((n, 10), dtype="float32")
        )
        return gv8
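Each layer of the Sequential model has been lowered to a TIR prim_func (matmul, add, relu, log_softmax, plus numbered variants for the different layer shapes), and the Relax main function chains them with R.call_tir, allocating each intermediate tensor according to out_sinfo. As a small sanity check, the global functions in the module can be listed (a minimal sketch; the exact print format is not important):
# List the functions contained in the generated IRModule.
print(sorted(gv.name_hint for gv in mod.get_global_vars()))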
Build the IRModule and create a Relax virtual machine (VM).
target = tvm.target.Target("llvm", host="llvm")
ex = relax.build(mod, target)
vm = relax.VirtualMachine(ex, tvm.cpu())
Run the model on the Relax VM with an input minibatch size of 3.
params = nn.init_params(mod)  # initialize the parameters
data = tvm.nd.array(np.random.rand(3, input_size).astype(np.float32))
res = vm["main"](data, *params)
print(res)
[[-2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851
-2.3025851 -2.3025851 -2.3025851 -2.3025851]
[-2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851
-2.3025851 -2.3025851 -2.3025851 -2.3025851]
[-2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851
 -2.3025851 -2.3025851 -2.3025851 -2.3025851]]
Because nn.init_params fills the parameters with zeros, every class receives the same logit, so the log-softmax output is uniform: log(1/10) ≈ -2.3026 for every entry.
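Because the batch dimension n is symbolic, the compiled VM function also accepts other minibatch sizes without recompilation. A minimal sketch (the names data2 and res2 are introduced here for illustration):
# The compiled artifact is batch-size agnostic: pass a different minibatch.
data2 = tvm.nd.array(np.random.rand(5, input_size).astype(np.float32))
res2 = vm["main"](data2, *params)
print(res2.numpy().shape)  # expected: (5, 10)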