理解 TensorIR 抽象#

TensorIR 是 Apache TVM 中的张量程序抽象,它是标准机器学习编译框架之一。张量程序抽象的主要目标是描述循环及其相关的硬件加速选项,包括线程化、应用专用硬件指令以及内存访问。

为了帮助我们的解释,使用以下张量计算序列作为启发性的例子。具体来说,对于两个 128×128 的矩阵 AB,执行以下两步张量计算。

Yi,j=kAi,k×Bk,jCi,j=relu(Yi,j)=max(Yi,j,0)

上述计算与神经网络中常见的基本张量函数相似,即带有 ReLU 激活的线性层。使用 TensorIR 来描述上述计算,如下所示。

在调用 TensorIR 之前,先用原生的 Python 代码结合 NumPy 来展示计算过程:

def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    Y = np.empty((128, 128), dtype="float32")
    for i in range(128):
        for j in range(128):
            for k in range(128):
                if k == 0:
                    Y[i, j] = 0
                Y[i, j] = Y[i, j] + A[i, k] * B[k, j]
    for i in range(128):
        for j in range(128):
            C[i, j] = max(Y[i, j], 0)

在记住了低级别的 NumPy 示例之后,现在准备介绍 TensorIR。下面的代码块展示了 mm_relu 的TensorIR实现。这段特定的代码是用一种名为 TVMScript 的语言编写的,这是一种嵌入在 Python AST 中的特定领域方言。

@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def mm_relu(A: T.Buffer((128, 128), "float32"),
                B: T.Buffer((128, 128), "float32"),
                C: T.Buffer((128, 128), "float32")):
        Y = T.alloc_buffer((128, 128), dtype="float32")
        for i, j, k in T.grid(128, 128, 128):
            with T.block("Y"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                vk = T.axis.reduce(128, k)
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))

接下来,将分析上述 TensorIR 程序中的各元素。

函数参数与缓冲区#

函数参数与 numpy 函数上的同一组参数相对应。

# TensorIR
def mm_relu(A: T.Buffer((128, 128), "float32"),
            B: T.Buffer((128, 128), "float32"),
            C: T.Buffer((128, 128), "float32")):
    ...
# NumPy
def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    ...

在此,变量 ABC 采用了名为 T.Buffer 的类型,该类型具有形状参数 (128, 128) 和数据类型 float32。这些额外的信息有助于可能的 MLC 处理过程生成专门针对该形状和数据类型的代码。

同样,TensorIR 在中间结果的分配中也采用了缓冲区类型。

# TensorIR
Y = T.alloc_buffer((128, 128), dtype="float32")
# NumPy
Y = np.empty((128, 128), dtype="float32")

循环迭代#

循环迭代也存在直接的对应关系。

T.grid 是 TensorIR 中的语法糖,它允许编写多个嵌套迭代器。

# TensorIR with `T.grid`
for i, j, k in T.grid(128, 128, 128):
    ...
# TensorIR with `range`
for i in range(128):
    for j in range(128):
        for k in range(128):
            ...
# NumPy
for i in range(128):
    for j in range(128):
        for k in range(128):
            ...

计算块#

显著的区别在于计算语句:TensorIR 引入了额外的构造,称为 T.block

# TensorIR
with T.block("Y"):
    vi = T.axis.spatial(128, i)
    vj = T.axis.spatial(128, j)
    vk = T.axis.reduce(128, k)
    with T.init():
        Y[vi, vj] = T.float32(0)
    Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
# NumPy
vi, vj, vk = i, j, k
if vk == 0:
    Y[vi, vj] = 0
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]

在 TensorIR 中, 表示基础的计算单元。重要的是,块包含的信息比标准的 NumPy 代码多。它包括一组块轴 (vi, vj, vk) 和围绕它们的计算。

vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j)
vk = T.axis.reduce(128, k)

上述三行声明了关于块轴的 关键属性,在以下语法中。"

[block_axis] = T.axis.[axis_type]([axis_range], [mapped_value])

这三行文字传达了以下细节:

  • 它们指定了 vivjvk (在本例中分别对应于 ijk)的绑定关系。

  • 他们声明最初的范围是为 vivjvk (即 T.axis.spatial(128, i) 中的 128)准备的。

  • 他们声明了迭代器的属性(spatial, reduce)。

块轴属性#

更深入地探讨块轴的属性。这些属性代表了轴与进行中的计算之间的关系。该块包含三个轴 vivjvk,同时块读取缓冲区 A[vi, vk]B[vk, vj],并写入缓冲区 Y[vi, vj]。严格来说,块对 Y 执行(归约)更新操作,暂时将其标记为写操作,因为不需要来自另一个块的 Y 的值。

重要的是,对于固定的 vivj 值,计算块会在空间位置 Y (即 Y[vi, vj])处产生点值,这个点值与其他位置在“Y”上的值(具有不同的 vivj 值)是独立的。可以将 vivj 称为 空间轴,因为它们直接对应于该块写入的缓冲区的空间区域的起始位置。参与归约操作的轴(vk)被指定为 归约轴

为什么 Block 中需要额外信息#

一个关键观察是,额外的信息(块轴的范围及其属性)使得该块在它需要执行的迭代中 自包含,独立于外部循环嵌套 i, j, k

块轴信息还提供了其他属性,帮助验证用于执行计算的外部循环是否正确。例如,上述代码块会导致错误,因为循环期望大小为128的迭代器,而只将其绑定到大小为 127 的 for 循环上。

# wrong program due to loop and block iteration mismatch
for i in range(127):
    with T.block("C"):
        vi = T.axis.spatial(128, i)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^
        error here due to iterator size mismatch
        ...

块轴绑定的语法糖#

在每个块轴直接映射到外部循环迭代器的情况下,可以使用 T.axis.remap 来在一行中声明块轴。

# SSR means the properties of each axes are "spatial", "spatial", "reduce"
vi, vj, vk = T.axis.remap("SSR", [i, j, k])

等同于

vi = T.axis.spatial(range_of_i, i)
vj = T.axis.spatial(range_of_j, j)
vk = T.axis.reduce (range_of_k, k)

因此,也可以如下形式编写程序。

@tvm.script.ir_module
class MyModuleWithAxisRemapSugar:
    @T.prim_func
    def mm_relu(A: T.Buffer((128, 128), "float32"),
                B: T.Buffer((128, 128), "float32"),
                C: T.Buffer((128, 128), "float32")):
        Y = T.alloc_buffer((128, 128), dtype="float32")
        for i, j, k in T.grid(128, 128, 128):
            with T.block("Y"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))