# 张量内积

In [1]:
%cd ../..
import set_env

/media/pc/data/4tb/lxw/home/lxw/tvm-book/doc/tutorials


In [2]:
import numpy as np
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T
from tvm import te, topi

## 矩阵乘法分块

有 $\mathbf{x} = (x_1, \cdots, x_w)^T, \mathbf{y} = (y_1, \cdots, y_w)^T \in \mathbb{R}^w$，则它们的内积为

$$
\langle \mathbf{x}, \mathbf{y} \rangle = \sum_{i=1}^w x_i y_i = \mathbf{x}^T \cdot \mathbf{y} = \mathbf{y}^T \cdot \mathbf{x} \in \mathbb{R}
$$

进一步，设 $\mathbf{X} = (\mathbf{x}_1, \cdots, \mathbf{x}_h)^T \in \mathbb{R}^{h \times w}$, $\mathbf{Y} = (\mathbf{y}_1, \cdots, \mathbf{y}_{h_o})^T \in \mathbb{R}^{h_o \times w}$，则有

$$
\langle \mathbf{X}, \mathbf{Y} \rangle = \mathbf{X} \cdot \mathbf{Y}^T =  (\langle \mathbf{x}_i, \mathbf{y}_j \rangle)_{i=1, j=1}^{i=h, j=h_o} \in \mathbb{R}^{h \times h_o}
$$

In [3]:
# Build the two operands with an explicit int64 dtype. np.arange otherwise
# uses the platform's default integer type, which is not int64 everywhere,
# while the TE placeholders in this notebook are declared as "int64".
a_np = np.arange(24, dtype="int64").reshape(3, 8)  # 3 rows of length 8
b_np = np.arange(16, dtype="int64").reshape(2, 8)  # 2 rows of length 8
print(f"a_np:\n{a_np}\nb_np:\n{b_np}")

a_np:
[[ 0  1  2  3  4  5  6  7]
 [ 8  9 10 11 12 13 14 15]
 [16 17 18 19 20 21 22 23]]
b_np:
[[ 0  1  2  3  4  5  6  7]
 [ 8  9 10 11 12 13 14 15]]


内积参考结果：

In [4]:
# NumPy reference: entry (i, j) is the inner product of row i of a_np with
# row j of b_np, i.e. a_np times the transpose of b_np.
c_np = np.matmul(a_np, b_np.T)
c_np

array([[ 140,  364],
       [ 364, 1100],
       [ 588, 1836]])

tvm 数组：

In [5]:
# Copy both NumPy operands into TVM NDArrays, then allocate an output array
# with the shape and dtype of the NumPy reference result (contents unset).
a_nd, b_nd = (tvm.nd.array(operand) for operand in (a_np, b_np))
c_nd = tvm.nd.empty(shape=c_np.shape, dtype=c_np.dtype)

In [6]:
# Symbolic extents: X is (m, n) and Y is (o, n); n is the shared row length
# that the inner product reduces over.
m, n, o = (te.var(name) for name in ("m", "n", "o"))
A = te.placeholder((m, n), dtype="int64", name="X")
B = te.placeholder((o, n), dtype="int64", name="Y")
# Matrix multiplication with B transposed: C[i, j] = <A[i, :], B[j, :]>
C = topi.matmul(A, B, transp_b=True)
te_func = te.create_prim_func([A, B, C])
te_func.show()
# Compile for the LLVM target, run on the NDArrays, and compare the result
# against the NumPy reference.
mod = tvm.build(te_func, target="llvm")
mod(a_nd, b_nd, c_nd)
np.testing.assert_equal(c_nd.numpy(), c_np)

### 三维张量内积

对于三维张量 $\mathsf{X} = (\mathbf{X}_1, \cdots, \mathbf{X}_{c_i})^T \in \mathbb{R}^{c_i \times h \times w}$, $\mathsf{Y} = (\mathbf{Y}_1, \cdots, \mathbf{Y}_{c_o})^T \in \mathbb{R}^{c_o \times {h_o} \times w}$，有

$$
\langle \mathsf{X}, \mathsf{Y} \rangle = \mathsf{X} \cdot \mathsf{Y}^T =  (\langle \mathbf{X}_i, \mathbf{Y}_j \rangle)_{i=1, j=1}^{i=c_i, j=c_o} \in \mathbb{R}^{c_i \times h \times h_o \times c_o}
$$

In [7]:
# Wrap the TE-generated PrimFunc in an IRModule under the name "mm" and split
# its reduction loop k into an outer loop and an inner loop of extent 4.
ir_mod = IRModule({"mm": te_func})
sch = tvm.tir.Schedule(ir_mod)
block_Z = sch.get_block("T_matmul", func_name="mm")
ax0, ax1, k = sch.get_loops(block_Z)
k0, k1 = sch.split(k, factors=[None, 4])
sch.mod.show()
mod = tvm.build(sch.mod, target="llvm")
# Use a fresh, zero-filled output buffer. Reusing the previous c_nd would let
# the check below pass on stale data even if this kernel wrote nothing,
# because c_nd already holds the correct result from the earlier run.
c_nd = tvm.nd.array(np.zeros_like(c_np))
mod(a_nd, b_nd, c_nd)
np.testing.assert_equal(c_nd.numpy(), c_np)

### 四维张量内积

对于四维张量 $\mathop{X} = (\mathsf{X}_1, \cdots, \mathsf{X}_{b_i})^T \in \mathbb{R}^{b_i \times c_i \times h \times w}$, $\mathop{Y} = (\mathsf{Y}_1, \cdots, \mathsf{Y}_{b_o})^T \in \mathbb{R}^{b_o \times c_o \times {h_o} \times w}$，有

$$
\langle \mathop{X}, \mathop{Y} \rangle = X \cdot Y^T =  (\langle \mathsf{X}_i, \mathsf{Y}_j \rangle)_{i=1, j=1}^{i=b_i, j=b_o} \in \mathbb{R}^{b_i \times c_i \times h \times h_o \times c_o \times b_o}
$$

In [9]:
# Wrap the TE-generated PrimFunc in an IRModule under the name "mm" and split
# its reduction loop k three ways: an outer loop plus two inner loops of
# extent 2 each.
ir_mod = IRModule({"mm": te_func})
sch = tvm.tir.Schedule(ir_mod)
block_Z = sch.get_block("T_matmul", func_name="mm")
ax0, ax1, k = sch.get_loops(block_Z)
k0, k1, k2 = sch.split(k, factors=[None, 2, 2])
sch.mod.show()
mod = tvm.build(sch.mod, target="llvm")
# Use a fresh, zero-filled output buffer. Reusing the previous c_nd would let
# the check below pass on stale data even if this kernel wrote nothing,
# because c_nd already holds the correct result from the earlier run.
c_nd = tvm.nd.array(np.zeros_like(c_np))
mod(a_nd, b_nd, c_nd)
np.testing.assert_equal(c_nd.numpy(), c_np)

In [10]:
# TVMScript module holding one PrimFunc, `main`, that computes
# C[i, j] = sum_k A[i, k] * B[j, k] on 1024x1024 float32 buffers.
# B is indexed as [j, k], so this is A times the transpose of B: every output
# element is the inner product of a row of A with a row of B.
@tvm.script.ir_module
class MatmulModule:
    @T.prim_func
    def main(
        A: T.Buffer[(1024, 1024), "float32"],
        B: T.Buffer[(1024, 1024), "float32"],
        C: T.Buffer[(1024, 1024), "float32"],
    ) -> None:
        # Function attributes: the exported symbol name and the no-alias flag.
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        for i, j, k in T.grid(1024, 1024, 1024):
            with T.block("matmul"):
                # "SSR": vi and vj are bound as spatial axes, vk as the reduction axis.
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                # Initialisation of the accumulator for this output element.
                with T.init():
                    C[vi, vj] = T.float32(0)
                # Accumulate one product of the row-by-row inner product.
                C[vi, vj] += A[vi, vk] * B[vj, vk]

In [11]:
# Tile each of the three loops of the "matmul" block by 16, then reorder so
# the three outer loops come first and the three inner loops last.
sch = tvm.tir.Schedule(MatmulModule)
matmul_block = sch.get_block("matmul")
i, j, k = sch.get_loops(matmul_block)
# factors=[None, 16]: the inner extent is fixed at 16 and the outer extent is
# left for the schedule to derive.
i, ii = sch.split(loop=i, factors=[None, 16])
j, ji = sch.split(loop=j, factors=[None, 16])
k, ki = sch.split(loop=k, factors=[None, 16])
outer_to_inner = (i, j, k, ii, ji, ki)
sch.reorder(*outer_to_inner)
sch.mod.show()

In [12]:
# Turn the loop nest rooted at `ii` into a block of its own. `ii`, `ji` and
# `ki` are the extent-16 inner loops produced by the splits and reorder in the
# previous cell, so `block_mm` stands for the inner tile of the matmul.
block_mm = sch.blockize(ii)
# Print the module to inspect the IR after blockize.
sch.mod.show()

In [13]:
# Stage the first two read buffers of the blockized tile (presumably A and B,
# going by the variable names) through cache stages with dedicated storage
# scopes, and place both stages under the outer reduction loop `k`.
A_reg = sch.cache_read(block=block_mm, read_buffer_index=0, storage_scope="global.A_reg")
B_reg = sch.cache_read(block=block_mm, read_buffer_index=1, storage_scope="global.B_reg")
sch.compute_at(block=A_reg, loop=k)
sch.compute_at(block=B_reg, loop=k)

# Route the tile's first written buffer through an accumulator-scoped cache
# stage and attach the resulting write-back block under loop `j`.
write_back_block = sch.cache_write(
    block=block_mm, write_buffer_index=0, storage_scope="global.accumulator"
)
sch.reverse_compute_at(block=write_back_block, loop=j)
sch.mod.show()