张量内积#
%cd ../..
import set_env
/media/pc/data/4tb/lxw/home/lxw/tvm-book/doc/tutorials
import numpy as np
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T
from tvm import te, topi
矩阵乘法分块#
有 \(\mathbf{x} = (x_1, \cdots, x_w)^T, \mathbf{y} = (y_1, \cdots, y_w)^T \in \mathbb{R}^w\),则它们的内积为
\[
\langle \mathbf{x}, \mathbf{y} \rangle = \sum_{i=1}^{w} x_i y_i = \mathbf{x}^T \mathbf{y} = \mathbf{y}^T \mathbf{x} \in \mathbb{R}
\]
进一步有 \(\mathbf{X} = (\mathbf{x}_1, \cdots, \mathbf{x}_h)^T \in \mathbb{R}^{h \times w}\), \(\mathbf{Y} = (\mathbf{y}_1, \cdots, \mathbf{y}_{h_o})^T \in \mathbb{R}^{h_o \times w}\),有
\[
\langle \mathbf{X}, \mathbf{Y} \rangle = \mathbf{X} \cdot \mathbf{Y}^T = (\langle \mathbf{x}_i, \mathbf{y}_j \rangle)_{i=1, j=1}^{i=h, j=h_o} \in \mathbb{R}^{h \times h_o}
\]
# Two small integer test matrices sharing the inner dimension (8):
# a_np is 3x8, b_np is 2x8, both filled with consecutive integers.
a_np = np.arange(3 * 8).reshape(3, 8)
b_np = np.arange(2 * 8).reshape(2, 8)
print(f"a_np:\n{a_np}\nb_np:\n{b_np}")
a_np:
[[ 0 1 2 3 4 5 6 7]
[ 8 9 10 11 12 13 14 15]
[16 17 18 19 20 21 22 23]]
b_np:
[[ 0 1 2 3 4 5 6 7]
[ 8 9 10 11 12 13 14 15]]
内积参考结果:
# NumPy reference result: row-by-row inner products, i.e. a_np @ b_np.T
# (for 2-D arrays, ndarray.dot is the same matrix product).
c_np = a_np.dot(b_np.T)
c_np
array([[ 140, 364],
[ 364, 1100],
[ 588, 1836]])
tvm 数组:
# Wrap the NumPy inputs as TVM NDArrays and allocate an empty output buffer.
a_nd = tvm.nd.array(a_np)
b_nd = tvm.nd.array(b_np)
c_nd = tvm.nd.empty(c_np.shape, dtype=c_np.dtype)

# Symbolic extents so the generated kernel works for any (m, n) x (o, n) pair.
m, n, o = te.var("m"), te.var("n"), te.var("o")
A = te.placeholder(shape=(m, n), dtype="int64", name="X")
B = te.placeholder(shape=(o, n), dtype="int64", name="Y")
C = topi.matmul(A, B, transp_b=True)  # C = A @ B.T

# Lower the TE computation to a TIR PrimFunc and inspect it.
te_func = te.create_prim_func([A, B, C])
te_func.show()

# Compile for CPU, run, and verify against the NumPy reference.
mod = tvm.build(te_func, target="llvm")
mod(a_nd, b_nd, c_nd)
np.testing.assert_equal(c_nd.numpy(), c_np)
# from tvm.script import tir as T
@T.prim_func
def func(var_X: T.handle, var_Y: T.handle, var_T_matmul: T.handle):
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
m = T.var("int32")
n = T.var("int32")
o = T.var("int32")
X = T.match_buffer(var_X, [m, n], dtype="int64")
Y = T.match_buffer(var_Y, [o, n], dtype="int64")
T_matmul = T.match_buffer(var_T_matmul, [m, o], dtype="int64")
# body
# with T.block("root")
for i0, i1, i2 in T.grid(m, o, n):
with T.block("T_matmul"):
ax0, ax1, k = T.axis.remap("SSR", [i0, i1, i2])
T.reads(X[ax0, k], Y[ax1, k])
T.writes(T_matmul[ax0, ax1])
with T.init():
T_matmul[ax0, ax1] = T.int64(0)
T_matmul[ax0, ax1] = T_matmul[ax0, ax1] + X[ax0, k] * Y[ax1, k]
三维张量内积#
对于三维张量 \(\mathsf{X} = (\mathbf{X}_1, \cdots, \mathbf{X}_{c_i})^T \in \mathbb{R}^{c_i \times h \times w}\), \(\mathsf{Y} = (\mathbf{Y}_1, \cdots, \mathbf{Y}_{c_o})^T \in \mathbb{R}^{c_o \times {h_o} \times w}\),有
\[
\langle \mathsf{X}, \mathsf{Y} \rangle = \mathsf{X} \cdot \mathsf{Y}^T = (\langle \mathbf{X}_i, \mathbf{Y}_j \rangle)_{i=1, j=1}^{i=c_i, j=c_o} \in \mathbb{R}^{c_i \times h \times h_o \times c_o}
\]
# Wrap the TE-generated PrimFunc in an IRModule and schedule it:
# split the reduction loop of "T_matmul" into (outer, inner=4) chunks.
sch = tvm.tir.Schedule(IRModule({"mm": te_func}))
blk = sch.get_block("T_matmul", func_name="mm")
_, _, red = sch.get_loops(blk)
red_outer, red_inner = sch.split(red, factors=[None, 4])
sch.mod.show()

# Rebuild and confirm the schedule did not change the numerical result.
mod = tvm.build(sch.mod, target="llvm")
mod(a_nd, b_nd, c_nd)
np.testing.assert_equal(c_nd.numpy(), c_np)
# from tvm.script import tir as T
@tvm.script.ir_module
class Module:
@T.prim_func
def mm(var_X: T.handle, var_Y: T.handle, var_T_matmul: T.handle):
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
m = T.var("int32")
n = T.var("int32")
o = T.var("int32")
X = T.match_buffer(var_X, [m, n], dtype="int64")
Y = T.match_buffer(var_Y, [o, n], dtype="int64")
T_matmul = T.match_buffer(var_T_matmul, [m, o], dtype="int64")
# body
# with T.block("root")
for i0, i1, i2_0, i2_1 in T.grid(m, o, (n + 3) // 4, 4):
with T.block("T_matmul"):
T.where(i2_0 * 4 + i2_1 < n)
ax0, ax1 = T.axis.remap("SS", [i0, i1])
k = T.axis.reduce(n, i2_0 * 4 + i2_1)
T.reads(X[ax0, k], Y[ax1, k])
T.writes(T_matmul[ax0, ax1])
with T.init():
T_matmul[ax0, ax1] = T.int64(0)
T_matmul[ax0, ax1] = T_matmul[ax0, ax1] + X[ax0, k] * Y[ax1, k]
四维张量内积#
对于四维张量 \(\mathcal{X} = (\mathsf{X}_1, \cdots, \mathsf{X}_{b_i})^T \in \mathbb{R}^{b_i \times c_i \times h \times w}\), \(\mathcal{Y} = (\mathsf{Y}_1, \cdots, \mathsf{Y}_{b_o})^T \in \mathbb{R}^{b_o \times c_o \times {h_o} \times w}\),有
\[
\langle \mathcal{X}, \mathcal{Y} \rangle = \mathcal{X} \cdot \mathcal{Y}^T = (\langle \mathsf{X}_i, \mathsf{Y}_j \rangle)_{i=1, j=1}^{i=b_i, j=b_o} \in \mathbb{R}^{b_i \times c_i \times h \times h_o \times c_o \times b_o}
\]
# Same kernel, but the reduction axis is split twice: n -> (n/4, 2, 2).
sch = tvm.tir.Schedule(IRModule({"mm": te_func}))
blk = sch.get_block("T_matmul", func_name="mm")
_, _, red = sch.get_loops(blk)
red0, red1, red2 = sch.split(red, factors=[None, 2, 2])
sch.mod.show()

# Rebuild and confirm the schedule did not change the numerical result.
mod = tvm.build(sch.mod, target="llvm")
mod(a_nd, b_nd, c_nd)
np.testing.assert_equal(c_nd.numpy(), c_np)
# from tvm.script import tir as T
@tvm.script.ir_module
class Module:
@T.prim_func
def mm(var_X: T.handle, var_Y: T.handle, var_T_matmul: T.handle):
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
m = T.var("int32")
n = T.var("int32")
o = T.var("int32")
X = T.match_buffer(var_X, [m, n], dtype="int64")
Y = T.match_buffer(var_Y, [o, n], dtype="int64")
T_matmul = T.match_buffer(var_T_matmul, [m, o], dtype="int64")
# body
# with T.block("root")
for i0, i1, i2_0, i2_1, i2_2 in T.grid(m, o, (n + 3) // 4, 2, 2):
with T.block("T_matmul"):
T.where((i2_0 * 2 + i2_1) * 2 + i2_2 < n)
ax0, ax1 = T.axis.remap("SS", [i0, i1])
k = T.axis.reduce(n, i2_0 * 4 + i2_1 * 2 + i2_2)
T.reads(X[ax0, k], Y[ax1, k])
T.writes(T_matmul[ax0, ax1])
with T.init():
T_matmul[ax0, ax1] = T.int64(0)
T_matmul[ax0, ax1] = T_matmul[ax0, ax1] + X[ax0, k] * Y[ax1, k]
@tvm.script.ir_module
class MatmulModule:
    # 1024x1024 float32 matmul; B is indexed as B[vj, vk], so this
    # computes C = A @ B.T (same access pattern as the TE kernel above).
    @T.prim_func
    def main(
        A: T.Buffer[(1024, 1024), "float32"],
        B: T.Buffer[(1024, 1024), "float32"],
        C: T.Buffer[(1024, 1024), "float32"],
    ) -> None:
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        for i, j, k in T.grid(1024, 1024, 1024):
            with T.block("matmul"):
                # "SSR": i and j are spatial axes, k is the reduction axis.
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    # Zero-initialize the accumulator on the first k iteration.
                    C[vi, vj] = T.float32(0)
                C[vi, vj] += A[vi, vk] * B[vj, vk]
sch = tvm.tir.Schedule(MatmulModule)
# Fix: Schedule.get_loops expects a block handle (BlockRV), not a bare
# name string — fetch the block first, consistent with the earlier cells
# that call get_block before get_loops.
matmul_block = sch.get_block("matmul")
i, j, k = sch.get_loops(matmul_block)
# Tile each 1024-extent axis into 64 outer x 16 inner iterations.
i, ii = sch.split(i, factors=[None, 16])
j, ji = sch.split(j, factors=[None, 16])
k, ki = sch.split(k, factors=[None, 16])
# Outer loops first, then the 16x16x16 micro-kernel loops.
sch.reorder(i, j, k, ii, ji, ki)
sch.mod.show()
# from tvm.script import tir as T
@tvm.script.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer[(1024, 1024), "float32"], B: T.Buffer[(1024, 1024), "float32"], C: T.Buffer[(1024, 1024), "float32"]):
# function attr dict
T.func_attr({"tir.noalias": True, "global_symbol": "main"})
# body
# with T.block("root")
for i_0, j_0, k_0, i_1, j_1, k_1 in T.grid(64, 64, 64, 16, 16, 16):
with T.block("matmul"):
vi = T.axis.spatial(1024, i_0 * 16 + i_1)
vj = T.axis.spatial(1024, j_0 * 16 + j_1)
vk = T.axis.reduce(1024, k_0 * 16 + k_1)
T.reads(A[vi, vk], B[vj, vk])
T.writes(C[vi, vj])
with T.init():
C[vi, vj] = T.float32(0)
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
# Collapse the loop nest rooted at the inner loop `ii` into a single
# outer block ("matmul_o") whose body is the 16x16x16 micro-kernel.
block_mm = sch.blockize(ii)
sch.mod.show()
# from tvm.script import tir as T
@tvm.script.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer[(1024, 1024), "float32"], B: T.Buffer[(1024, 1024), "float32"], C: T.Buffer[(1024, 1024), "float32"]):
# function attr dict
T.func_attr({"tir.noalias": True, "global_symbol": "main"})
# body
# with T.block("root")
for i_0, j_0, k_0 in T.grid(64, 64, 64):
with T.block("matmul_o"):
vi_o, vj_o, vk_o = T.axis.remap("SSR", [i_0, j_0, k_0])
T.reads(A[vi_o * 16 : vi_o * 16 + 16, vk_o * 16 : vk_o * 16 + 16], B[vj_o * 16 : vj_o * 16 + 16, vk_o * 16 : vk_o * 16 + 16])
T.writes(C[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16])
with T.init():
for i_1, j_1 in T.grid(16, 16):
with T.block("matmul_init"):
vi_i_init, vj_i_init = T.axis.remap("SS", [i_1, j_1])
T.reads()
T.writes(C[vi_o * 16 + vi_i_init, vj_o * 16 + vj_i_init])
C[vi_o * 16 + vi_i_init, vj_o * 16 + vj_i_init] = T.float32(0)
for i_1, j_1, k_1 in T.grid(16, 16, 16):
with T.block("matmul"):
vi_i, vj_i, vk_i = T.axis.remap("SSR", [i_1, j_1, k_1])
T.reads(C[vi_o * 16 + vi_i, vj_o * 16 + vj_i], A[vi_o * 16 + vi_i, vk_o * 16 + vk_i], B[vj_o * 16 + vj_i, vk_o * 16 + vk_i])
T.writes(C[vi_o * 16 + vi_i, vj_o * 16 + vj_i])
C[vi_o * 16 + vi_i, vj_o * 16 + vj_i] = C[vi_o * 16 + vi_i, vj_o * 16 + vj_i] + A[vi_o * 16 + vi_i, vk_o * 16 + vk_i] * B[vj_o * 16 + vj_i, vk_o * 16 + vk_i]
# Stage the A and B input tiles into custom storage scopes, accumulate into
# a separate buffer, and copy the result back after the reduction finishes.
A_reg = sch.cache_read(block_mm, 0, storage_scope="global.A_reg")  # read buffer 0 (A)
B_reg = sch.cache_read(block_mm, 1, storage_scope="global.B_reg")  # read buffer 1 (B)
sch.compute_at(A_reg, k)  # load the A tile inside the k_0 loop
sch.compute_at(B_reg, k)  # load the B tile inside the k_0 loop
write_back_block = sch.cache_write(block_mm, 0, storage_scope="global.accumulator")
sch.reverse_compute_at(write_back_block, j)  # write-back once per (i_0, j_0) tile
sch.mod.show()
# from tvm.script import tir as T
@tvm.script.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer[(1024, 1024), "float32"], B: T.Buffer[(1024, 1024), "float32"], C: T.Buffer[(1024, 1024), "float32"]):
# function attr dict
T.func_attr({"tir.noalias": True, "global_symbol": "main"})
# body
# with T.block("root")
A_global_A_reg = T.alloc_buffer([1024, 1024], dtype="float32", scope="global.A_reg")
B_global_B_reg = T.alloc_buffer([1024, 1024], dtype="float32", scope="global.B_reg")
C_global_accumulator = T.alloc_buffer([1024, 1024], dtype="float32", scope="global.accumulator")
for i_0, j_0 in T.grid(64, 64):
for k_0 in T.serial(64):
for ax0, ax1 in T.grid(16, 16):
with T.block("A_global.A_reg"):
v0 = T.axis.spatial(1024, i_0 * 16 + ax0)
v1 = T.axis.spatial(1024, k_0 * 16 + ax1)
T.reads(A[v0, v1])
T.writes(A_global_A_reg[v0, v1])
A_global_A_reg[v0, v1] = A[v0, v1]
for ax0, ax1 in T.grid(16, 16):
with T.block("B_global.B_reg"):
v0 = T.axis.spatial(1024, j_0 * 16 + ax0)
v1 = T.axis.spatial(1024, k_0 * 16 + ax1)
T.reads(B[v0, v1])
T.writes(B_global_B_reg[v0, v1])
B_global_B_reg[v0, v1] = B[v0, v1]
with T.block("matmul_o"):
vi_o, vj_o, vk_o = T.axis.remap("SSR", [i_0, j_0, k_0])
T.reads(A_global_A_reg[vi_o * 16 : vi_o * 16 + 16, vk_o * 16 : vk_o * 16 + 16], B_global_B_reg[vj_o * 16 : vj_o * 16 + 16, vk_o * 16 : vk_o * 16 + 16])
T.writes(C_global_accumulator[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16])
with T.init():
for i_1, j_1 in T.grid(16, 16):
with T.block("matmul_init"):
vi_i_init, vj_i_init = T.axis.remap("SS", [i_1, j_1])
T.reads()
T.writes(C_global_accumulator[vi_o * 16 + vi_i_init, vj_o * 16 + vj_i_init])
C_global_accumulator[vi_o * 16 + vi_i_init, vj_o * 16 + vj_i_init] = T.float32(0)
for i_1, j_1, k_1 in T.grid(16, 16, 16):
with T.block("matmul"):
vi_i, vj_i, vk_i = T.axis.remap("SSR", [i_1, j_1, k_1])
T.reads(C_global_accumulator[vi_o * 16 + vi_i, vj_o * 16 + vj_i], A_global_A_reg[vi_o * 16 + vi_i, vk_o * 16 + vk_i], B_global_B_reg[vj_o * 16 + vj_i, vk_o * 16 + vk_i])
T.writes(C_global_accumulator[vi_o * 16 + vi_i, vj_o * 16 + vj_i])
C_global_accumulator[vi_o * 16 + vi_i, vj_o * 16 + vj_i] = C_global_accumulator[vi_o * 16 + vi_i, vj_o * 16 + vj_i] + A_global_A_reg[vi_o * 16 + vi_i, vk_o * 16 + vk_i] * B_global_B_reg[vj_o * 16 + vj_i, vk_o * 16 + vk_i]
for ax0, ax1 in T.grid(16, 16):
with T.block("C_global.accumulator"):
v0 = T.axis.spatial(1024, i_0 * 16 + ax0)
v1 = T.axis.spatial(1024, j_0 * 16 + ax1)
T.reads(C_global_accumulator[v0, v1])
T.writes(C[v0, v1])
C[v0, v1] = C_global_accumulator[v0, v1]