张量程序抽象#
在深入了解 TensorIR 的细节之前,首先介绍一下什么是元张量函数。元张量函数是指对应于单一计算操作“单元”的函数。例如,卷积运算可以是元张量函数,而融合了卷积和 ReLU 算子也可以是元张量函数。通常,对元张量函数实现的典型抽象包含以下元素:多维缓冲区、驱动张量计算的循环嵌套结构,以及最终的计算语句本身。
from tvm.script import tir as T
@T.prim_func
def main(
A: T.Buffer((128,), "float32"),
B: T.Buffer((128,), "float32"),
C: T.Buffer((128,), "float32"),
) -> None:
for i in range(128):
with T.block("C"):
vi = T.axis.spatial(128, i)
C[vi] = A[vi] + B[vi]
张量程序的关键要素#
展示的元张量函数计算两个向量的元素级和。该函数:
该函数接受三个 多维缓冲区 作为参数,并生成一个 多维缓冲区 作为输出。
包含单独的 循环嵌套
i
,这有助于进行计算。包含独特的 计算语句,它计算两个向量的逐元素和。
TensorIR 中的额外结构#
至关重要的是,无法对程序执行任意变换,因为某些计算依赖于循环的顺序。幸运的是,我们关注的主要元张量函数具有有利的属性,例如循环迭代之间的独立性。例如,上述程序包括块和迭代注解:
with T.block("C")
的 块注解 表示该块是指定用于调度的基本计算单元。块可能包含单个计算语句,多个带有循环的计算语句,或者像张量核心指令(Tensor Core instructions)那样的不透明内建函数。迭代注解
T.axis.spatial
表明变量vi
映射到i
,且所有迭代都是独立的。
尽管这些信息对于执行特定的程序来说并非至关重要,但它们在程序变换过程中却显得十分有用。因此,只要遍历从 0 到 128 的所有索引元素,就可以自信地并行化或重新排序与 vi
相关的循环。