TVM 中的调度原语#
原作者: Ziheng Jiang
TVM 是用于高效构建 kernel 的领域特定语言。
在本教程中,将向您展示如何通过 TVM 提供的各种原语来调度计算。
import tvm
from tvm import te
import numpy as np
通常有几种方法可以计算出相同的结果,但是,不同的方法会导致不同的局部性(locality)和性能。因此 TVM 要求用户提供称为 Schedule(调度)的描述,说明应如何执行计算。
Schedule 是一组程序变换,用于对计算中的循环进行变换。
# Declare some symbolic variables for use later.
# te.var creates a symbolic (runtime-determined) dimension size.
n = te.var("n")
m = te.var("m")
p = te.var("p")
调度可以从 ops 列表中创建,默认情况下,调度以 row-major 顺序的串行方式计算张量。
# Declare an element-wise matrix multiplication: C[i, j] = A[i, j] * B[i, j]
A = te.placeholder((m, n), name="A")
B = te.placeholder((m, n), name="B")
C = te.compute((m, n), lambda i, j: A[i, j] * B[i, j], name="C")
# Create a schedule from the list of output ops; by default the schedule
# computes the tensor serially in row-major order.
s = te.create_schedule([C.op])
lower
将计算从定义转换为实际可调用的函数。使用 simple_mode=True
参数时,它会返回可读的类 C 语句;这里用它来打印调度结果。
tvm.lower(s, [A, B, C], simple_mode=True).show()
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(A: T.handle, B: T.handle, C: T.handle):
T.func_attr(
{
"from_legacy_te_schedule": True,
"global_symbol": "main",
"tir.noalias": True,
}
)
m = T.int32()
n = T.int32()
stride = T.int32()
stride_1 = T.int32()
A_1 = T.match_buffer(A, (m, n), strides=(stride, stride_1), type="auto")
stride_2 = T.int32()
stride_3 = T.int32()
B_1 = T.match_buffer(B, (m, n), strides=(stride_2, stride_3), type="auto")
stride_4 = T.int32()
stride_5 = T.int32()
C_1 = T.match_buffer(C, (m, n), strides=(stride_4, stride_5), type="auto")
for i, j in T.grid(m, n):
C_2 = T.Buffer((stride_4 * m,), data=C_1.data, type="auto")
A_2 = T.Buffer((stride * m,), data=A_1.data, type="auto")
B_2 = T.Buffer((stride_2 * m,), data=B_1.data, type="auto")
C_2[i * stride_4 + j * stride_5] = (
A_2[i * stride + j * stride_1] * B_2[i * stride_2 + j * stride_3]
)
每个调度由多个阶段(Stage)组成,每个阶段表示一个运算的调度。
下面提供各种方法来调度每个阶段。
split#
split
可以通过 factor
将指定的轴分裂(split)为两个轴。
# split with factor: divide axis i (extent m) into an outer axis of
# ceil(m / 32) iterations and an inner axis of fixed extent 32.
m = te.var("m")
A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda i: A[i] * 2, name="B")
s = te.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=32)
tvm.lower(s, [A, B]).show()
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(A: T.handle, B: T.handle):
T.func_attr(
{
"from_legacy_te_schedule": True,
"global_symbol": "main",
"tir.noalias": True,
}
)
m = T.int32()
stride = T.int32()
A_1 = T.match_buffer(A, (m,), strides=(stride,), type="auto")
stride_1 = T.int32()
B_1 = T.match_buffer(B, (m,), strides=(stride_1,), type="auto")
B_2 = T.Buffer((stride_1 * m,), data=B_1.data, type="auto")
A_2 = T.Buffer((stride * m,), data=A_1.data, type="auto")
for i_outer, i_inner in T.grid(m // 32, 32):
cse_var_1: T.int32 = i_outer * 32 + i_inner
B_2[cse_var_1 * stride_1] = A_2[cse_var_1 * stride] * T.float32(2)
for i_outer, i_inner in T.grid((m % 32 + 31) // 32, 32):
if m // 32 * 32 + i_outer * 32 + i_inner < m:
B_2[(m // 32 * 32 + i_outer * 32 + i_inner) * stride_1] = A_2[
(m // 32 * 32 + i_outer * 32 + i_inner) * stride
] * T.float32(2)
你也可以通过 nparts
来分裂轴:它固定外层迭代的份数,与 factor
固定内层长度的方式正好相对。
m = te.var("m")
A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda i: A[i], name="B")
s = te.create_schedule(B.op)
# split with nparts: fix the number of OUTER iterations to 32; the inner
# extent becomes ceil(m / 32) — the opposite of factor= above.
bx, tx = s[B].split(B.op.axis[0], nparts=32)
tvm.lower(s, [A, B], simple_mode=True).show()
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(A: T.handle, B: T.handle):
T.func_attr(
{
"from_legacy_te_schedule": True,
"global_symbol": "main",
"tir.noalias": True,
}
)
m = T.int32()
stride = T.int32()
A_1 = T.match_buffer(A, (m,), strides=(stride,), type="auto")
stride_1 = T.int32()
B_1 = T.match_buffer(B, (m,), strides=(stride_1,), type="auto")
for i_outer, i_inner in T.grid(32, (m + 31) // 32):
if T.likely(i_inner + i_outer * ((m + 31) // 32) < m):
B_2 = T.Buffer((stride_1 * m,), data=B_1.data, type="auto")
A_2 = T.Buffer((stride * m,), data=A_1.data, type="auto")
B_2[(i_inner + i_outer * ((m + 31) // 32)) * stride_1] = A_2[
(i_inner + i_outer * ((m + 31) // 32)) * stride
]
tile#
tile
帮助你在两个轴上逐块(tile by tile)执行计算。
A = te.placeholder((m, n), name="A")
B = te.compute((m, n), lambda i, j: A[i, j], name="B")
s = te.create_schedule(B.op)
# tile: split both axes at once (10x5 blocks), producing four loops
# ordered (i.outer, j.outer, i.inner, j.inner).
xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
tvm.lower(s, [A, B], simple_mode=True).show()
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(A: T.handle, B: T.handle):
T.func_attr(
{
"from_legacy_te_schedule": True,
"global_symbol": "main",
"tir.noalias": True,
}
)
m = T.int32()
n = T.int32()
stride = T.int32()
stride_1 = T.int32()
A_1 = T.match_buffer(A, (m, n), strides=(stride, stride_1), type="auto")
stride_2 = T.int32()
stride_3 = T.int32()
B_1 = T.match_buffer(B, (m, n), strides=(stride_2, stride_3), type="auto")
for i_outer, j_outer, i_inner in T.grid((m + 9) // 10, (n + 4) // 5, 10):
if T.likely(i_outer * 10 + i_inner < m):
for j_inner in range(5):
if T.likely(j_outer * 5 + j_inner < n):
cse_var_2: T.int32 = j_outer * 5 + j_inner
cse_var_1: T.int32 = i_outer * 10 + i_inner
B_2 = T.Buffer((stride_2 * m,), data=B_1.data, type="auto")
A_2 = T.Buffer((stride * m,), data=A_1.data, type="auto")
B_2[cse_var_1 * stride_2 + cse_var_2 * stride_3] = A_2[
cse_var_1 * stride + cse_var_2 * stride_1
]
fuse#
fuse
可以融合一个计算的两个连续轴。
A = te.placeholder((m, n), name="A")
B = te.compute((m, n), lambda i, j: A[i, j], name="B")
s = te.create_schedule(B.op)
# tile to four axes first: (i.outer, j.outer, i.inner, j.inner)
xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
# then fuse (i.inner, j.inner) into one axis: (i.inner.j.inner.fused)
# of extent 10 * 5 = 50.
fused = s[B].fuse(xi, yi)
tvm.lower(s, [A, B], simple_mode=True).show()
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(A: T.handle, B: T.handle):
T.func_attr(
{
"from_legacy_te_schedule": True,
"global_symbol": "main",
"tir.noalias": True,
}
)
m = T.int32()
n = T.int32()
stride = T.int32()
stride_1 = T.int32()
A_1 = T.match_buffer(A, (m, n), strides=(stride, stride_1), type="auto")
stride_2 = T.int32()
stride_3 = T.int32()
B_1 = T.match_buffer(B, (m, n), strides=(stride_2, stride_3), type="auto")
for i_outer, j_outer, i_inner_j_inner_fused in T.grid(
(m + 9) // 10, (n + 4) // 5, 50
):
if T.likely(i_outer * 10 + i_inner_j_inner_fused // 5 < m):
if T.likely(j_outer * 5 + i_inner_j_inner_fused % 5 < n):
cse_var_2: T.int32 = j_outer * 5 + i_inner_j_inner_fused % 5
cse_var_1: T.int32 = i_outer * 10 + i_inner_j_inner_fused // 5
B_2 = T.Buffer((stride_2 * m,), data=B_1.data, type="auto")
A_2 = T.Buffer((stride * m,), data=A_1.data, type="auto")
B_2[cse_var_1 * stride_2 + cse_var_2 * stride_3] = A_2[
cse_var_1 * stride + cse_var_2 * stride_1
]
reorder#
reorder
可以按指定的顺序重新排列坐标轴。
A = te.placeholder((m, n), name="A")
B = te.compute((m, n), lambda i, j: A[i, j], name="B")
s = te.create_schedule(B.op)
# tile to four axes first: (i.outer, j.outer, i.inner, j.inner)
xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
# then reorder the axes: (i.inner, j.outer, i.outer, j.inner)
s[B].reorder(xi, yo, xo, yi)
tvm.lower(s, [A, B], simple_mode=True).show()
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(A: T.handle, B: T.handle):
T.func_attr(
{
"from_legacy_te_schedule": True,
"global_symbol": "main",
"tir.noalias": True,
}
)
m = T.int32()
n = T.int32()
stride = T.int32()
stride_1 = T.int32()
A_1 = T.match_buffer(A, (m, n), strides=(stride, stride_1), type="auto")
stride_2 = T.int32()
stride_3 = T.int32()
B_1 = T.match_buffer(B, (m, n), strides=(stride_2, stride_3), type="auto")
for i_inner, j_outer, i_outer in T.grid(10, (n + 4) // 5, (m + 9) // 10):
if T.likely(i_outer * 10 + i_inner < m):
for j_inner in range(5):
if T.likely(j_outer * 5 + j_inner < n):
cse_var_2: T.int32 = j_outer * 5 + j_inner
cse_var_1: T.int32 = i_outer * 10 + i_inner
B_2 = T.Buffer((stride_2 * m,), data=B_1.data, type="auto")
A_2 = T.Buffer((stride * m,), data=A_1.data, type="auto")
B_2[cse_var_1 * stride_2 + cse_var_2 * stride_3] = A_2[
cse_var_1 * stride + cse_var_2 * stride_1
]
bind#
bind
可以将指定的轴与线程轴绑定,通常用于 GPU 编程。
A = te.placeholder((n,), name="A")
B = te.compute(A.shape, lambda i: A[i] * 2, name="B")
s = te.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], factor=64)
# bind: map the outer axis to CUDA blocks and the inner axis (extent 64)
# to threads within each block.
s[B].bind(bx, te.thread_axis("blockIdx.x"))
s[B].bind(tx, te.thread_axis("threadIdx.x"))
tvm.lower(s, [A, B], simple_mode=True).show()
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(A: T.handle, B: T.handle):
T.func_attr(
{
"from_legacy_te_schedule": True,
"global_symbol": "main",
"tir.noalias": True,
}
)
n = T.int32()
stride = T.int32()
A_1 = T.match_buffer(A, (n,), strides=(stride,), type="auto")
stride_1 = T.int32()
B_1 = T.match_buffer(B, (n,), strides=(stride_1,), type="auto")
blockIdx_x = T.env_thread("blockIdx.x")
T.launch_thread(blockIdx_x, (n + 63) // 64)
threadIdx_x = T.env_thread("threadIdx.x")
T.launch_thread(threadIdx_x, 64)
if T.likely(blockIdx_x * 64 + threadIdx_x < n):
B_2 = T.Buffer((stride_1 * n,), data=B_1.data, type="auto")
A_2 = T.Buffer((stride * n,), data=A_1.data, type="auto")
B_2[(blockIdx_x * 64 + threadIdx_x) * stride_1] = A_2[
(blockIdx_x * 64 + threadIdx_x) * stride
] * T.float32(2)
compute_at#
对于由多个算子组成的调度,默认情况下 TVM 将分别计算根节点上的张量。
# Two chained ops; by default each is computed in its own root-level loop.
A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda i: A[i] + 1, name="B")
C = te.compute((m,), lambda i: B[i] * 2, name="C")
s = te.create_schedule(C.op)
tvm.lower(s, [A, B, C], simple_mode=True).show()
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(A: T.handle, B: T.handle, C: T.handle):
T.func_attr(
{
"from_legacy_te_schedule": True,
"global_symbol": "main",
"tir.noalias": True,
}
)
m = T.int32()
stride = T.int32()
A_1 = T.match_buffer(A, (m,), strides=(stride,), type="auto")
stride_1 = T.int32()
B_1 = T.match_buffer(B, (m,), strides=(stride_1,), type="auto")
stride_2 = T.int32()
C_1 = T.match_buffer(C, (m,), strides=(stride_2,), type="auto")
B_2 = T.Buffer((stride_1 * m,), data=B_1.data, type="auto")
for i in range(m):
A_2 = T.Buffer((stride * m,), data=A_1.data, type="auto")
B_2[i * stride_1] = A_2[i * stride] + T.float32(1)
for i in range(m):
C_2 = T.Buffer((stride_2 * m,), data=C_1.data, type="auto")
C_2[i * stride_2] = B_2[i * stride_1] * T.float32(2)
compute_at
可以将 B
的计算移到 C
的计算的第一个轴上。
A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda i: A[i] + 1, name="B")
C = te.compute((m,), lambda i: B[i] * 2, name="C")
s = te.create_schedule(C.op)
# compute_at: move B's computation inside the first axis of C's loop,
# so both are produced in a single fused loop nest.
s[B].compute_at(s[C], C.op.axis[0])
tvm.lower(s, [A, B, C], simple_mode=True).show()
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(A: T.handle, B: T.handle, C: T.handle):
T.func_attr(
{
"from_legacy_te_schedule": True,
"global_symbol": "main",
"tir.noalias": True,
}
)
m = T.int32()
stride = T.int32()
A_1 = T.match_buffer(A, (m,), strides=(stride,), type="auto")
stride_1 = T.int32()
B_1 = T.match_buffer(B, (m,), strides=(stride_1,), type="auto")
stride_2 = T.int32()
C_1 = T.match_buffer(C, (m,), strides=(stride_2,), type="auto")
for i in range(m):
B_2 = T.Buffer((stride_1 * m,), data=B_1.data, type="auto")
A_2 = T.Buffer((stride * m,), data=A_1.data, type="auto")
B_2[i * stride_1] = A_2[i * stride] + T.float32(1)
C_2 = T.Buffer((stride_2 * m,), data=C_1.data, type="auto")
C_2[i * stride_2] = B_2[i * stride_1] * T.float32(2)
compute_inline#
compute_inline
可以将一个阶段标记为内联,然后将该计算体展开,并插入到需要该张量的位置。
A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda i: A[i] + 1, name="B")
C = te.compute((m,), lambda i: B[i] * 2, name="C")
s = te.create_schedule(C.op)
# compute_inline: expand B's expression directly into C's body, so no
# separate loop (or buffer write) for B is emitted.
s[B].compute_inline()
tvm.lower(s, [A, B, C], simple_mode=True).show()
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(A: T.handle, B: T.handle, C: T.handle):
T.func_attr(
{
"from_legacy_te_schedule": True,
"global_symbol": "main",
"tir.noalias": True,
}
)
m = T.int32()
stride = T.int32()
A_1 = T.match_buffer(A, (m,), strides=(stride,), type="auto")
stride_1 = T.int32()
B_1 = T.match_buffer(B, (m,), strides=(stride_1,), type="auto")
stride_2 = T.int32()
C_1 = T.match_buffer(C, (m,), strides=(stride_2,), type="auto")
for i in range(m):
C_2 = T.Buffer((stride_2 * m,), data=C_1.data, type="auto")
A_2 = T.Buffer((stride * m,), data=A_1.data, type="auto")
C_2[i * stride_2] = (A_2[i * stride] + T.float32(1)) * T.float32(2)
compute_root#
compute_root
可以将一个阶段的计算移到 root。
A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda i: A[i] + 1, name="B")
C = te.compute((m,), lambda i: B[i] * 2, name="C")
s = te.create_schedule(C.op)
s[B].compute_at(s[C], C.op.axis[0])
# compute_root: undo the compute_at above, moving B's computation back
# to a root-level loop of its own.
s[B].compute_root()
tvm.lower(s, [A, B, C], simple_mode=True).show()
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(A: T.handle, B: T.handle, C: T.handle):
T.func_attr(
{
"from_legacy_te_schedule": True,
"global_symbol": "main",
"tir.noalias": True,
}
)
m = T.int32()
stride = T.int32()
A_1 = T.match_buffer(A, (m,), strides=(stride,), type="auto")
stride_1 = T.int32()
B_1 = T.match_buffer(B, (m,), strides=(stride_1,), type="auto")
stride_2 = T.int32()
C_1 = T.match_buffer(C, (m,), strides=(stride_2,), type="auto")
B_2 = T.Buffer((stride_1 * m,), data=B_1.data, type="auto")
for i in range(m):
A_2 = T.Buffer((stride * m,), data=A_1.data, type="auto")
B_2[i * stride_1] = A_2[i * stride] + T.float32(1)
for i in range(m):
C_2 = T.Buffer((stride_2 * m,), data=C_1.data, type="auto")
C_2[i * stride_2] = B_2[i * stride_1] * T.float32(2)
小结#
本教程介绍了 tvm 中的调度原语,允许用户轻松灵活地调度计算。
为了得到性能良好的 kernel 实现,一般的工作流程往往是:
通过一系列的运算来描述你的计算。
试着用原语来调度计算。
编译并运行以查看性能差异。
根据运行时的结果调整你的调度。