张量程序抽象#
备注
张量 (Tensor) 是表示神经网络模型执行的输入、输出和中间结果的多维数组。
张量函数 (Tensor functions) 是神经网络的“知识”的计算序列集合。
机器学习编译(MLC)的过程可以看作是张量函数的变换过程。
元张量函数(primitive tensor function)是指模型执行中的单个计算单元。MLC 过程可以选择变换元张量函数的实现方式。
张量程序(tensor program)是表示元张量函数的有效抽象。
关键成分包括: 存储数据的多维数组,驱动张量计算的循环嵌套,计算语句。
程序变换可以被用于加速张量程序的执行。
张量程序中额外的结构能够为程序变换提供更多的信息。
TensorIR#
下面着重接受一类张量程序抽象:TensorIR
。TensorIR
常用来表示循环和相关的硬件加速选择,如多线程、特殊硬件指令的使用和内存访问。
探讨如下数学表示:
import numpy as np
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T
print(f'tvm 版本: {tvm.__version__}')
tvm 版本: 0.14.dev0
使用 NumPy 实现:
dtype = "float32"
a_np = np.random.rand(128, 128).astype(dtype)
b_np = np.random.rand(128, 128).astype(dtype)
# a @ b 等价于 np.matmul(a, b)
c_mm_relu = np.maximum(a_np @ b_np, 0)
为了更清晰理解张量程序的变换底层逻辑,约定使用低级 NumPy 实现:
在必要时使用循环而不是数组函数来展示可能的循环计算。
如果可能,总是通过
numpy.empty()
显式地分配数组并传递它们。
def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray):
Y = np.empty((128, 128), dtype="float32")
for i in range(128):
for j in range(128):
for k in range(128):
if k == 0:
Y[i, j] = 0
Y[i, j] = Y[i, j] + A[i, k] * B[k, j]
for i in range(128):
for j in range(128):
C[i, j] = max(Y[i, j], 0)
将 lnumpy_mm_relu()
翻译为 TVMScript:
@tvm.script.ir_module
class MyModule:
@T.prim_func
def mm_relu(A: T.Buffer((128, 128), "float32"),
B: T.Buffer((128, 128), "float32"),
C: T.Buffer((128, 128), "float32")):
T.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})
Y = T.alloc_buffer((128, 128), dtype="float32")
for i, j, k in T.grid(128, 128, 128):
with T.block("Y"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j)
vk = T.axis.reduce(128, k)
with T.init():
Y[vi, vj] = T.float32(0)
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
for i, j in T.grid(128, 128):
with T.block("C"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j)
C[vi, vj] = T.max(Y[vi, vj], T.float32(0))
MyModule.show()
/media/pc/data/lxw/ai/tvm/xinetzone/__pypackages__/3.10/lib/tvm/script/highlight.py:117: UserWarning: No module named 'black'
To print formatted TVM script, please install the formatter 'Black':
/media/pc/data/tmp/cache/conda/envs/tvmz/bin/python -m pip install "black==22.3.0" --upgrade --user
warnings.warn(
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def mm_relu(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
T.func_attr({"global_symbol": "mm_relu", "tir.noalias": T.bool(True)})
# with T.block("root"):
Y = T.alloc_buffer((128, 128))
for i, j, k in T.grid(128, 128, 128):
with T.block("Y"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
T.reads(A[vi, vk], B[vk, vj])
T.writes(Y[vi, vj])
with T.init():
Y[vi, vj] = T.float32(0)
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
T.reads(Y[vi, vj])
T.writes(C[vi, vj])
C[vi, vj] = T.max(Y[vi, vj], T.float32(0))
T.Buffer
(输入与输出)与T.alloc_buffer
(中间结果) 定义了缓冲区;T.grid
用于循环嵌套;T.block
管理着计算块。
块 是 TensorIR 中的基本计算单位。
T.block
包含一组块轴(vi、vj、vk)和围绕它们定义的计算。
[block_axis] = T.axis.[axis_type]([axis_range], [mapped_value])
T.axis.spatial
标注的信息指示循环变量是独立的,其对应的轴称为 空间轴。T.axis.reduce
对应的轴称为 归约轴。
可以使用 T.axis.remap
简写块轴语法,比如:
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
vi = T.axis.spatial(range_of_i, i)
vj = T.axis.spatial(range_of_j, j)
vk = T.axis.reduce(range_of_k, k)
备注
tir.noalias
表示所有的缓冲存储器不重叠。@tvm.script.ir_module
表示MyModule
是IRModule
。IRModule
是在机器学习编译中保存张量函数集合的容器对象。
type(MyModule)
tvm.ir.module.IRModule
type(MyModule["mm_relu"])
tvm.tir.function.PrimFunc
变换#
lnumpy_mm_relu
可以改写为如下实现:
def lnumpy_mm_relu_v2(A: np.ndarray, B: np.ndarray, C: np.ndarray):
Y = np.empty((128, 128), dtype="float32")
for i in range(128):
for j0 in range(32):
for k in range(128):
for j1 in range(4):
j = j0 * 4 + j1
if k == 0:
Y[i, j] = 0
Y[i, j] = Y[i, j] + A[i, k] * B[k, j]
for i in range(128):
for j in range(128):
C[i, j] = max(Y[i, j], 0)
c_np = np.empty((128, 128), dtype=dtype)
lnumpy_mm_relu_v2(a_np, b_np, c_np)
np.testing.assert_allclose(c_mm_relu, c_np, rtol=1e-5)
TVMScript 无需重写代码,只需要:
sch = tvm.tir.Schedule(MyModule)
block_Y = sch.get_block("Y", func_name="mm_relu")
i, j, k = sch.get_loops(block_Y)
j0, j1 = sch.split(j, factors=[None, 4])
sch.reorder(j0, k, j1) # 改变循环的顺序
sch.mod.show()
/media/pc/data/lxw/ai/tvm/xinetzone/__pypackages__/3.10/lib/tvm/script/highlight.py:117: UserWarning: No module named 'black'
To print formatted TVM script, please install the formatter 'Black':
/media/pc/data/tmp/cache/conda/envs/tvmz/bin/python -m pip install "black==22.3.0" --upgrade --user
warnings.warn(
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def mm_relu(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
T.func_attr({"global_symbol": "mm_relu", "tir.noalias": T.bool(True)})
# with T.block("root"):
Y = T.alloc_buffer((128, 128))
for i, j_0, k, j_1 in T.grid(128, 32, 128, 4):
with T.block("Y"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 4 + j_1)
vk = T.axis.reduce(128, k)
T.reads(A[vi, vk], B[vk, vj])
T.writes(Y[vi, vj])
with T.init():
Y[vi, vj] = T.float32(0)
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
T.reads(Y[vi, vj])
T.writes(C[vi, vj])
C[vi, vj] = T.max(Y[vi, vj], T.float32(0))
reverse_compute_at
的原语将块 C 移动到 Y 的内循环里:
block_C = sch.get_block("C", "mm_relu")
sch.reverse_compute_at(block_C, j0)
sch.mod.show()
/media/pc/data/lxw/ai/tvm/xinetzone/__pypackages__/3.10/lib/tvm/script/highlight.py:117: UserWarning: No module named 'black'
To print formatted TVM script, please install the formatter 'Black':
/media/pc/data/tmp/cache/conda/envs/tvmz/bin/python -m pip install "black==22.3.0" --upgrade --user
warnings.warn(
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def mm_relu(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
T.func_attr({"global_symbol": "mm_relu", "tir.noalias": T.bool(True)})
# with T.block("root"):
Y = T.alloc_buffer((128, 128))
for i, j_0 in T.grid(128, 32):
for k, j_1 in T.grid(128, 4):
with T.block("Y"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 4 + j_1)
vk = T.axis.reduce(128, k)
T.reads(A[vi, vk], B[vk, vj])
T.writes(Y[vi, vj])
with T.init():
Y[vi, vj] = T.float32(0)
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
for ax0 in range(4):
with T.block("C"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 4 + ax0)
T.reads(Y[vi, vj])
T.writes(C[vi, vj])
C[vi, vj] = T.max(Y[vi, vj], T.float32(0))
在循环变换之后,我们可以将 Y 元素的初始化与归约更新分开。我们可以通过 decompose_reduction
原语来做到这一点。(注意:这也是 TVM 在以后编译的时候隐式做的,所以这一步的主要目的是让它显式,看看最终效果)。
block_Y = sch.get_block("Y", "mm_relu")
sch.decompose_reduction(block_Y, k)
sch.mod.show()
/media/pc/data/lxw/ai/tvm/xinetzone/__pypackages__/3.10/lib/tvm/script/highlight.py:117: UserWarning: No module named 'black'
To print formatted TVM script, please install the formatter 'Black':
/media/pc/data/tmp/cache/conda/envs/tvmz/bin/python -m pip install "black==22.3.0" --upgrade --user
warnings.warn(
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def mm_relu(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
T.func_attr({"global_symbol": "mm_relu", "tir.noalias": T.bool(True)})
# with T.block("root"):
Y = T.alloc_buffer((128, 128))
for i, j_0 in T.grid(128, 32):
for j_1_init in range(4):
with T.block("Y_init"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 4 + j_1_init)
T.reads()
T.writes(Y[vi, vj])
Y[vi, vj] = T.float32(0)
for k, j_1 in T.grid(128, 4):
with T.block("Y_update"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 4 + j_1)
vk = T.axis.reduce(128, k)
T.reads(Y[vi, vj], A[vi, vk], B[vk, vj])
T.writes(Y[vi, vj])
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
for ax0 in range(4):
with T.block("C"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 4 + ax0)
T.reads(Y[vi, vj])
T.writes(C[vi, vj])
C[vi, vj] = T.max(Y[vi, vj], T.float32(0))
构建与运行#
调用构建函数将 IRModule
变换为 runtime.Module
(表示可运行函数的集合)。 这里 target
指定了部署环境的详细信息。对于现在这种特殊情况,我们将使用 llvm
,它可以帮助我们编译到本机 CPU 平台。
rt_lib = tvm.build(MyModule, target="llvm")
创建三个用于保存输入和输出的 TVM NDArray:
a_nd = tvm.nd.array(a_np)
b_nd = tvm.nd.array(b_np)
c_nd = tvm.nd.empty((128, 128), dtype="float32")
type(c_nd)
tvm.runtime.ndarray.NDArray
构建变换后的程序:
rt_lib_after = tvm.build(sch.mod, target="llvm")
rt_lib_after["mm_relu"](a_nd, b_nd, c_nd)
np.testing.assert_allclose(c_mm_relu, c_nd.numpy(), rtol=1e-5)
比较两者的时间差:
f_timer_before = rt_lib.time_evaluator("mm_relu", tvm.cpu())
print("Time cost of MyModule %g sec" % f_timer_before(a_nd, b_nd, c_nd).mean)
f_timer_after = rt_lib_after.time_evaluator("mm_relu", tvm.cpu())
print("Time cost of transformed sch.mod %g sec" % f_timer_after(a_nd, b_nd, c_nd).mean)
Time cost of MyModule 0.00184356 sec
Time cost of transformed sch.mod 0.000445022 sec