理解 TensorIR 抽象#
TensorIR 是 Apache TVM 中的张量程序抽象,它是标准机器学习编译框架之一。张量程序抽象的主要目标是描述循环及其相关的硬件加速选项,包括线程化、应用专用硬件指令以及内存访问。
为了帮助我们的解释,使用以下张量计算序列作为启发性的例子。具体来说,对于两个 \(128 \times 128\) 的矩阵 A
和 B
,执行以下两步张量计算。
上述计算与神经网络中常见的基本张量函数相似,即带有 ReLU 激活的线性层。使用 TensorIR 来描述上述计算,如下所示。
在调用 TensorIR 之前,先用原生的 Python 代码结合 NumPy 来展示计算过程:
def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray):
Y = np.empty((128, 128), dtype="float32")
for i in range(128):
for j in range(128):
for k in range(128):
if k == 0:
Y[i, j] = 0
Y[i, j] = Y[i, j] + A[i, k] * B[k, j]
for i in range(128):
for j in range(128):
C[i, j] = max(Y[i, j], 0)
在记住了低级别的 NumPy 示例之后,现在准备介绍 TensorIR。下面的代码块展示了 mm_relu
的TensorIR实现。这段特定的代码是用一种名为 TVMScript 的语言编写的,这是一种嵌入在 Python AST 中的特定领域方言。
@tvm.script.ir_module
class MyModule:
@T.prim_func
def mm_relu(A: T.Buffer((128, 128), "float32"),
B: T.Buffer((128, 128), "float32"),
C: T.Buffer((128, 128), "float32")):
Y = T.alloc_buffer((128, 128), dtype="float32")
for i, j, k in T.grid(128, 128, 128):
with T.block("Y"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j)
vk = T.axis.reduce(128, k)
with T.init():
Y[vi, vj] = T.float32(0)
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
for i, j in T.grid(128, 128):
with T.block("C"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j)
C[vi, vj] = T.max(Y[vi, vj], T.float32(0))
接下来,将分析上述 TensorIR 程序中的各元素。
函数参数与缓冲区#
函数参数与 numpy 函数上的同一组参数相对应。
# TensorIR
def mm_relu(A: T.Buffer[(128, 128), "float32"],
B: T.Buffer[(128, 128), "float32"],
C: T.Buffer[(128, 128), "float32"]):
...
# NumPy
def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray):
...
在此,变量 A
、B
和 C
采用了名为 T.Buffer
的类型,该类型具有形状参数 (128, 128)
和数据类型 float32
。这些额外的信息有助于可能的 MLC 处理过程生成专门针对该形状和数据类型的代码。
同样,TensorIR 在中间结果的分配中也采用了缓冲区类型。
# TensorIR
Y = T.alloc_buffer((128, 128), dtype="float32")
# NumPy
Y = np.empty((128, 128), dtype="float32")
循环迭代#
循环迭代也存在直接的对应关系。
T.grid
是 TensorIR 中的语法糖,它允许编写多个嵌套迭代器。
# TensorIR with `T.grid`
for i, j, k in T.grid(128, 128, 128):
...
# TensorIR with `range`
for i in range(128):
for j in range(128):
for k in range(128):
...
# NumPy
for i in range(128):
for j in range(128):
for k in range(128):
...
计算块#
显著的区别在于计算语句:TensorIR 引入了额外的构造,称为 T.block
。
# TensorIR
with T.block("Y"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j)
vk = T.axis.reduce(128, k)
with T.init():
Y[vi, vj] = T.float32(0)
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
# NumPy
vi, vj, vk = i, j, k
if vk == 0:
Y[vi, vj] = 0
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
在 TensorIR 中,块 表示基础的计算单元。重要的是,块包含的信息比标准的 NumPy 代码多。它包括一组块轴 (vi, vj, vk)
和围绕它们的计算。
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j)
vk = T.axis.reduce(128, k)
上述三行声明了关于块轴的 关键属性,在以下语法中。"
[block_axis] = T.axis.[axis_type]([axis_range], [mapped_value])
这三行文字传达了以下细节:
它们指定了
vi
、vj
和vk
(在本例中分别对应于i
、j
和k
)的绑定关系。他们声明最初的范围是为
vi
,vj
,vk
(即T.axis.spatial(128, i)
中的 128)准备的。他们声明了迭代器的属性(spatial, reduce)。
块轴属性#
Let's delve deeper into the properties of the block axis. These properties signify the axis's
relationship to the computation in progress. The block comprises three axes vi
, vj
, and
vk
, meanwhile the block reads the buffer A[vi, vk]
, B[vk, vj]
and writes the buffer
Y[vi, vj]
. Strictly speaking, the block performs (reduction) updates to Y, which we label
as write for the time being, as we don't require the value of Y from another block.
重要的是,对于固定的 vi
和 vj
值,计算块会在空间位置 Y
(即 Y[vi, vj]
)处产生点值,这个点值与其他位置在“Y”上的值(具有不同的 vi
,vj
值)是独立的。可以将 vi
、vj
称为 空间轴,因为它们直接对应于该块写入的缓冲区的空间区域的起始位置。参与归约操作的轴(vk
)被指定为 归约轴。
为什么 Block 中需要额外信息#
一个关键观察是,额外的信息(块轴的范围及其属性)使得该块在它需要执行的迭代中 自包含,独立于外部循环嵌套 i, j, k
。
块轴信息还提供了其他属性,帮助验证用于执行计算的外部循环是否正确。例如,上述代码块会导致错误,因为循环期望大小为128的迭代器,而只将其绑定到大小为 127 的 for 循环上。
# wrong program due to loop and block iteration mismatch
for i in range(127):
with T.block("C"):
vi = T.axis.spatial(128, i)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
error here due to iterator size mismatch
...
块轴绑定的语法糖#
在每个块轴直接映射到外部循环迭代器的情况下,可以使用 T.axis.remap
来在一行中声明块轴。
# SSR means the properties of each axes are "spatial", "spatial", "reduce"
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
等同于
vi = T.axis.spatial(range_of_i, i)
vj = T.axis.spatial(range_of_j, j)
vk = T.axis.reduce (range_of_k, k)
因此,也可以如下形式编写程序。
@tvm.script.ir_module
class MyModuleWithAxisRemapSugar:
@T.prim_func
def mm_relu(A: T.Buffer((128, 128), "float32"),
B: T.Buffer((128, 128), "float32"),
C: T.Buffer((128, 128), "float32")):
Y = T.alloc_buffer((128, 128), dtype="float32")
for i, j, k in T.grid(128, 128, 128):
with T.block("Y"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
Y[vi, vj] = T.float32(0)
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
C[vi, vj] = T.max(Y[vi, vj], T.float32(0))