元张量函数变换#
在本节中,将介绍编译流程的主要组成部分——元张量函数变换。
在我们转换函数之前,让我们先评估原始实现的性能。
初始化调度#
使用提供的 MyModule 作为输入来建立 Schedule 辅助类,以此启动代码转换的过程。
sch = tvm.tir.Schedule(MyModule)
循环平铺(Tiling)#
接下来,我们执行必要的操作来获取块 Y 及其关联循环的引用。
block_Y = sch.get_block("Y")
i, j, k = sch.get_loops(block_Y)
我们现在开始执行转换。第一个修改是将循环j
拆分为两个独立的循环,内层循环的长度为4。必须了解的是,转换过程是分步进行的;因此,如果无意中再次执行块,将会出现一个错误,指出变量j
不存在。
j0, j1 = sch.split(j, factors=[None, 8])
转换的结果可以被检查,因为它被保留在sch.mod
中。
sch.mod.show()
在初始转换阶段之后,生成了两个补充循环j_0
和j_1
,它们的范围分别为32和4。接下来的动作是重新排序这两个循环。
sch.reorder(j0, k, j1)
sch.mod.show()
evaluate(sch.mod)
利用局部性#
接下来,我们将执行两个额外的转换步骤以获得不同的变体。首先,使用名为 reverse_compute_at 的原语将块 C 移动到 Y 的内层循环中。
block_C = sch.get_block("C")
sch.reverse_compute_at(block_C, j0)
sch.mod.show()
重写约简#
到目前为止,约简初始化和更新步骤一直保留在同一个块体中。这种混合形式便于循环转换,因为初始化和更新的外部循环i
, j
通常需要保持同步。
在循环转换之后,我们可以使用decompose_reduction原语将Y元素的初始化与约简更新分开。
sch.decompose_reduction(block_Y, k)
sch.mod.show()
evaluate(sch.mod)
跟踪转换#
TensorIR调度是一种过程语言,转换是以逐步方式执行的。我们可以通过打印调度或调度的历史来跟踪转换。
我们已经通过打印sch.mod
看到了调度。我们还可以通过sch.trace
打印调度的历史。
sch.trace.show()
# from tvm import tir
def apply_trace(sch: tir.Schedule) -> None:
b0 = sch.get_block(name="Y", func_name="main")
l1, l2, l3 = sch.get_loops(block=b0)
l4, l5 = sch.split(loop=l2, factors=[None, 8], preserve_unit_iters=True, disable_predication=False)
sch.reorder(l4, l3, l5)
b6 = sch.get_block(name="C", func_name="main")
sch.reverse_compute_at(block=b6, loop=l4, preserve_unit_loops=False, index=-1)
b7 = sch.decompose_reduction(block=b0, loop=l3)
或者,我们可以结合历史跟踪输出IRModule。
sch.show()