理解 TensorIR 抽象#

TensorIR 是 Apache TVM 中的张量程序抽象,它是标准机器学习编译框架之一。张量程序抽象的主要目标是描述循环及其相关的硬件加速选项,包括线程化、应用专用硬件指令以及内存访问。

为了帮助我们的解释,使用以下张量计算序列作为启发性的例子。具体来说,对于两个 \(128 \times 128\) 的矩阵 AB,执行以下两步张量计算。

\[\begin{split}Y_{i, j} &= \sum_k A_{i, k} \times B_{k, j} \\ C_{i, j} &= \mathbb{relu}(Y_{i, j}) = \mathbb{max}(Y_{i, j}, 0)\end{split}\]

上述计算与神经网络中常见的基本张量函数相似,即带有 ReLU 激活的线性层。使用 TensorIR 来描述上述计算,如下所示。

在调用 TensorIR 之前,先用原生的 Python 代码结合 NumPy 来展示计算过程:

def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    Y = np.empty((128, 128), dtype="float32")
    for i in range(128):
        for j in range(128):
            for k in range(128):
                if k == 0:
                    Y[i, j] = 0
                Y[i, j] = Y[i, j] + A[i, k] * B[k, j]
    for i in range(128):
        for j in range(128):
            C[i, j] = max(Y[i, j], 0)

在记住了低级别的 NumPy 示例之后,现在准备介绍 TensorIR。下面的代码块展示了 mm_relu 的TensorIR实现。这段特定的代码是用一种名为 TVMScript 的语言编写的,这是一种嵌入在 Python AST 中的特定领域方言。

@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def mm_relu(A: T.Buffer((128, 128), "float32"),
                B: T.Buffer((128, 128), "float32"),
                C: T.Buffer((128, 128), "float32")):
        Y = T.alloc_buffer((128, 128), dtype="float32")
        for i, j, k in T.grid(128, 128, 128):
            with T.block("Y"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                vk = T.axis.reduce(128, k)
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))

接下来,将分析上述 TensorIR 程序中的各元素。

函数参数与缓冲区#

函数参数与 numpy 函数上的同一组参数相对应。

# TensorIR
def mm_relu(A: T.Buffer[(128, 128), "float32"],
            B: T.Buffer[(128, 128), "float32"],
            C: T.Buffer[(128, 128), "float32"]):
    ...
# NumPy
def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    ...

在此,变量 ABC 采用了名为 T.Buffer 的类型,该类型具有形状参数 (128, 128) 和数据类型 float32。这些额外的信息有助于可能的 MLC 处理过程生成专门针对该形状和数据类型的代码。

同样,TensorIR 在中间结果的分配中也采用了缓冲区类型。

# TensorIR
Y = T.alloc_buffer((128, 128), dtype="float32")
# NumPy
Y = np.empty((128, 128), dtype="float32")

循环迭代#

循环迭代也存在直接的对应关系。

T.grid 是 TensorIR 中的语法糖,它允许编写多个嵌套迭代器。

# TensorIR with `T.grid`
for i, j, k in T.grid(128, 128, 128):
    ...
# TensorIR with `range`
for i in range(128):
    for j in range(128):
        for k in range(128):
            ...
# NumPy
for i in range(128):
    for j in range(128):
        for k in range(128):
            ...

计算块#

显著的区别在于计算语句:TensorIR 引入了额外的构造,称为 T.block

# TensorIR
with T.block("Y"):
    vi = T.axis.spatial(128, i)
    vj = T.axis.spatial(128, j)
    vk = T.axis.reduce(128, k)
    with T.init():
        Y[vi, vj] = T.float32(0)
    Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
# NumPy
vi, vj, vk = i, j, k
if vk == 0:
    Y[vi, vj] = 0
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]

在 TensorIR 中, 表示基础的计算单元。重要的是,块包含的信息比标准的 NumPy 代码多。它包括一组块轴 (vi, vj, vk) 和围绕它们的计算。

vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j)
vk = T.axis.reduce(128, k)

上述三行声明了关于块轴的 关键属性,在以下语法中。"

[block_axis] = T.axis.[axis_type]([axis_range], [mapped_value])

这三行文字传达了以下细节:

  • 它们指定了 vivjvk (在本例中分别对应于 ijk)的绑定关系。

  • 他们声明最初的范围是为 vivjvk (即 T.axis.spatial(128, i) 中的 128)准备的。

  • 他们声明了迭代器的属性(spatial, reduce)。

块轴属性#

Let's delve deeper into the properties of the block axis. These properties signify the axis's relationship to the computation in progress. The block comprises three axes vi, vj, and vk, meanwhile the block reads the buffer A[vi, vk], B[vk, vj] and writes the buffer Y[vi, vj]. Strictly speaking, the block performs (reduction) updates to Y, which we label as write for the time being, as we don't require the value of Y from another block.

重要的是,对于固定的 vivj 值,计算块会在空间位置 Y (即 Y[vi, vj])处产生点值,这个点值与其他位置在“Y”上的值(具有不同的 vivj 值)是独立的。可以将 vivj 称为 空间轴,因为它们直接对应于该块写入的缓冲区的空间区域的起始位置。参与归约操作的轴(vk)被指定为 归约轴

为什么 Block 中需要额外信息#

一个关键观察是,额外的信息(块轴的范围及其属性)使得该块在它需要执行的迭代中 自包含,独立于外部循环嵌套 i, j, k

块轴信息还提供了其他属性,帮助验证用于执行计算的外部循环是否正确。例如,上述代码块会导致错误,因为循环期望大小为128的迭代器,而只将其绑定到大小为 127 的 for 循环上。

# wrong program due to loop and block iteration mismatch
for i in range(127):
    with T.block("C"):
        vi = T.axis.spatial(128, i)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^
        error here due to iterator size mismatch
        ...

块轴绑定的语法糖#

在每个块轴直接映射到外部循环迭代器的情况下,可以使用 T.axis.remap 来在一行中声明块轴。

# SSR means the properties of each axes are "spatial", "spatial", "reduce"
vi, vj, vk = T.axis.remap("SSR", [i, j, k])

等同于

vi = T.axis.spatial(range_of_i, i)
vj = T.axis.spatial(range_of_j, j)
vk = T.axis.reduce (range_of_k, k)

因此,也可以如下形式编写程序。

@tvm.script.ir_module
class MyModuleWithAxisRemapSugar:
    @T.prim_func
    def mm_relu(A: T.Buffer((128, 128), "float32"),
                B: T.Buffer((128, 128), "float32"),
                C: T.Buffer((128, 128), "float32")):
        Y = T.alloc_buffer((128, 128), dtype="float32")
        for i, j, k in T.grid(128, 128, 128):
            with T.block("Y"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))