Schedule Primitives in TVM#

Original author: Ziheng Jiang

TVM is a domain-specific language for building kernels efficiently.

In this tutorial, we will show you how to schedule a computation with the various primitives provided by TVM.

import tvm
from tvm import te
import numpy as np

There are often several ways to compute the same result; however, different ways lead to different locality and different performance. TVM therefore asks the user to provide a description of how to execute the computation, called a Schedule.

A Schedule is a set of transformations of a computation that change how the computation's loops execute in the program.

# declare some variables for use later
n = te.var("n")
m = te.var("m")
p = te.var("p")

A schedule can be created from a list of ops. By default, the schedule computes the tensors serially in row-major order.

# declare an elementwise matrix multiplication
A = te.placeholder((m, n), name="A")
B = te.placeholder((m, n), name="B")
C = te.compute((m, n), lambda i, j: A[i, j] * B[i, j], name="C")
s = te.create_schedule([C.op])

lower transforms the computation from a definition into a real callable function. With the argument simple_mode=True, it returns readable C-like statements; we use it here to print the schedule result.

tvm.lower(s, [A, B, C], simple_mode=True).show()
# from tvm.script import ir as I
# from tvm.script import tir as T


@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle, C: T.handle):
        T.func_attr(
            {
                "from_legacy_te_schedule": True,
                "global_symbol": "main",
                "tir.noalias": True,
            }
        )
        m = T.int32()
        n = T.int32()
        stride = T.int32()
        stride_1 = T.int32()
        A_1 = T.match_buffer(A, (m, n), strides=(stride, stride_1), type="auto")
        stride_2 = T.int32()
        stride_3 = T.int32()
        B_1 = T.match_buffer(B, (m, n), strides=(stride_2, stride_3), type="auto")
        stride_4 = T.int32()
        stride_5 = T.int32()
        C_1 = T.match_buffer(C, (m, n), strides=(stride_4, stride_5), type="auto")
        for i, j in T.grid(m, n):
            C_2 = T.Buffer((stride_4 * m,), data=C_1.data, type="auto")
            A_2 = T.Buffer((stride * m,), data=A_1.data, type="auto")
            B_2 = T.Buffer((stride_2 * m,), data=B_1.data, type="auto")
            C_2[i * stride_4 + j * stride_5] = (
                A_2[i * stride + j * stride_1] * B_2[i * stride_2 + j * stride_3]
            )

A schedule is composed of multiple Stages, and each Stage represents the schedule for one operation.

Below we provide various methods to schedule each Stage.
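
Each primitive below is invoked on a Stage object, which is obtained by indexing the schedule with a tensor or with its op. A minimal sketch, using the s and C defined above:

# s[C] returns the Stage that schedules C's operation (s[C.op] is equivalent);
# primitives such as split, tile, fuse and reorder are methods on this object.
stage = s[C]
print(stage)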

split#

split can split a specified axis into two axes by factor.

m = te.var("m")
A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda i: A[i] * 2, name="B")

s = te.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=32)
tvm.lower(s, [A, B]).show()
# from tvm.script import ir as I
# from tvm.script import tir as T


@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle):
        T.func_attr(
            {
                "from_legacy_te_schedule": True,
                "global_symbol": "main",
                "tir.noalias": True,
            }
        )
        m = T.int32()
        stride = T.int32()
        A_1 = T.match_buffer(A, (m,), strides=(stride,), type="auto")
        stride_1 = T.int32()
        B_1 = T.match_buffer(B, (m,), strides=(stride_1,), type="auto")
        B_2 = T.Buffer((stride_1 * m,), data=B_1.data, type="auto")
        A_2 = T.Buffer((stride * m,), data=A_1.data, type="auto")
        for i_outer, i_inner in T.grid(m // 32, 32):
            cse_var_1: T.int32 = i_outer * 32 + i_inner
            B_2[cse_var_1 * stride_1] = A_2[cse_var_1 * stride] * T.float32(2)
        for i_outer, i_inner in T.grid((m % 32 + 31) // 32, 32):
            if m // 32 * 32 + i_outer * 32 + i_inner < m:
                B_2[(m // 32 * 32 + i_outer * 32 + i_inner) * stride_1] = A_2[
                    (m // 32 * 32 + i_outer * 32 + i_inner) * stride
                ] * T.float32(2)

You can also split an axis by nparts, which splits the axis the opposite way from factor: nparts fixes the extent of the outer loop, while factor fixes the extent of the inner loop.

m = te.var("m")
A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda i: A[i], name="B")

s = te.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], nparts=32)
tvm.lower(s, [A, B], simple_mode=True).show()
# from tvm.script import ir as I
# from tvm.script import tir as T


@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle):
        T.func_attr(
            {
                "from_legacy_te_schedule": True,
                "global_symbol": "main",
                "tir.noalias": True,
            }
        )
        m = T.int32()
        stride = T.int32()
        A_1 = T.match_buffer(A, (m,), strides=(stride,), type="auto")
        stride_1 = T.int32()
        B_1 = T.match_buffer(B, (m,), strides=(stride_1,), type="auto")
        for i_outer, i_inner in T.grid(32, (m + 31) // 32):
            if T.likely(i_inner + i_outer * ((m + 31) // 32) < m):
                B_2 = T.Buffer((stride_1 * m,), data=B_1.data, type="auto")
                A_2 = T.Buffer((stride * m,), data=A_1.data, type="auto")
                B_2[(i_inner + i_outer * ((m + 31) // 32)) * stride_1] = A_2[
                    (i_inner + i_outer * ((m + 31) // 32)) * stride
                ]

tile#

tile helps you execute the computation tile by tile over two axes.

A = te.placeholder((m, n), name="A")
B = te.compute((m, n), lambda i, j: A[i, j], name="B")

s = te.create_schedule(B.op)
xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
tvm.lower(s, [A, B], simple_mode=True).show()
# from tvm.script import ir as I
# from tvm.script import tir as T


@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle):
        T.func_attr(
            {
                "from_legacy_te_schedule": True,
                "global_symbol": "main",
                "tir.noalias": True,
            }
        )
        m = T.int32()
        n = T.int32()
        stride = T.int32()
        stride_1 = T.int32()
        A_1 = T.match_buffer(A, (m, n), strides=(stride, stride_1), type="auto")
        stride_2 = T.int32()
        stride_3 = T.int32()
        B_1 = T.match_buffer(B, (m, n), strides=(stride_2, stride_3), type="auto")
        for i_outer, j_outer, i_inner in T.grid((m + 9) // 10, (n + 4) // 5, 10):
            if T.likely(i_outer * 10 + i_inner < m):
                for j_inner in range(5):
                    if T.likely(j_outer * 5 + j_inner < n):
                        cse_var_2: T.int32 = j_outer * 5 + j_inner
                        cse_var_1: T.int32 = i_outer * 10 + i_inner
                        B_2 = T.Buffer((stride_2 * m,), data=B_1.data, type="auto")
                        A_2 = T.Buffer((stride * m,), data=A_1.data, type="auto")
                        B_2[cse_var_1 * stride_2 + cse_var_2 * stride_3] = A_2[
                            cse_var_1 * stride + cse_var_2 * stride_1
                        ]

fuse#

fuse can fuse two consecutive axes of one computation into a single axis.

A = te.placeholder((m, n), name="A")
B = te.compute((m, n), lambda i, j: A[i, j], name="B")

s = te.create_schedule(B.op)
# tile to four axes first: (i.outer, j.outer, i.inner, j.inner)
xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
# then fuse (i.inner, j.inner) into one axis: (i.inner.j.inner.fused)
fused = s[B].fuse(xi, yi)
tvm.lower(s, [A, B], simple_mode=True).show()
# from tvm.script import ir as I
# from tvm.script import tir as T


@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle):
        T.func_attr(
            {
                "from_legacy_te_schedule": True,
                "global_symbol": "main",
                "tir.noalias": True,
            }
        )
        m = T.int32()
        n = T.int32()
        stride = T.int32()
        stride_1 = T.int32()
        A_1 = T.match_buffer(A, (m, n), strides=(stride, stride_1), type="auto")
        stride_2 = T.int32()
        stride_3 = T.int32()
        B_1 = T.match_buffer(B, (m, n), strides=(stride_2, stride_3), type="auto")
        for i_outer, j_outer, i_inner_j_inner_fused in T.grid(
            (m + 9) // 10, (n + 4) // 5, 50
        ):
            if T.likely(i_outer * 10 + i_inner_j_inner_fused // 5 < m):
                if T.likely(j_outer * 5 + i_inner_j_inner_fused % 5 < n):
                    cse_var_2: T.int32 = j_outer * 5 + i_inner_j_inner_fused % 5
                    cse_var_1: T.int32 = i_outer * 10 + i_inner_j_inner_fused // 5
                    B_2 = T.Buffer((stride_2 * m,), data=B_1.data, type="auto")
                    A_2 = T.Buffer((stride * m,), data=A_1.data, type="auto")
                    B_2[cse_var_1 * stride_2 + cse_var_2 * stride_3] = A_2[
                        cse_var_1 * stride + cse_var_2 * stride_1
                    ]

reorder#

reorder can rearrange the axes in the specified order.

A = te.placeholder((m, n), name="A")
B = te.compute((m, n), lambda i, j: A[i, j], name="B")

s = te.create_schedule(B.op)
# tile to four axes first: (i.outer, j.outer, i.inner, j.inner)
xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
# then reorder the axes: (i.inner, j.outer, i.outer, j.inner)
s[B].reorder(xi, yo, xo, yi)
tvm.lower(s, [A, B], simple_mode=True).show()
# from tvm.script import ir as I
# from tvm.script import tir as T


@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle):
        T.func_attr(
            {
                "from_legacy_te_schedule": True,
                "global_symbol": "main",
                "tir.noalias": True,
            }
        )
        m = T.int32()
        n = T.int32()
        stride = T.int32()
        stride_1 = T.int32()
        A_1 = T.match_buffer(A, (m, n), strides=(stride, stride_1), type="auto")
        stride_2 = T.int32()
        stride_3 = T.int32()
        B_1 = T.match_buffer(B, (m, n), strides=(stride_2, stride_3), type="auto")
        for i_inner, j_outer, i_outer in T.grid(10, (n + 4) // 5, (m + 9) // 10):
            if T.likely(i_outer * 10 + i_inner < m):
                for j_inner in range(5):
                    if T.likely(j_outer * 5 + j_inner < n):
                        cse_var_2: T.int32 = j_outer * 5 + j_inner
                        cse_var_1: T.int32 = i_outer * 10 + i_inner
                        B_2 = T.Buffer((stride_2 * m,), data=B_1.data, type="auto")
                        A_2 = T.Buffer((stride * m,), data=A_1.data, type="auto")
                        B_2[cse_var_1 * stride_2 + cse_var_2 * stride_3] = A_2[
                            cse_var_1 * stride + cse_var_2 * stride_1
                        ]

bind#

bind can bind a specified axis to a thread axis; this is often used in GPU programming.

A = te.placeholder((n,), name="A")
B = te.compute(A.shape, lambda i: A[i] * 2, name="B")

s = te.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], factor=64)
s[B].bind(bx, te.thread_axis("blockIdx.x"))
s[B].bind(tx, te.thread_axis("threadIdx.x"))
tvm.lower(s, [A, B], simple_mode=True).show()
# from tvm.script import ir as I
# from tvm.script import tir as T


@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle):
        T.func_attr(
            {
                "from_legacy_te_schedule": True,
                "global_symbol": "main",
                "tir.noalias": True,
            }
        )
        n = T.int32()
        stride = T.int32()
        A_1 = T.match_buffer(A, (n,), strides=(stride,), type="auto")
        stride_1 = T.int32()
        B_1 = T.match_buffer(B, (n,), strides=(stride_1,), type="auto")
        blockIdx_x = T.env_thread("blockIdx.x")
        T.launch_thread(blockIdx_x, (n + 63) // 64)
        threadIdx_x = T.env_thread("threadIdx.x")
        T.launch_thread(threadIdx_x, 64)
        if T.likely(blockIdx_x * 64 + threadIdx_x < n):
            B_2 = T.Buffer((stride_1 * n,), data=B_1.data, type="auto")
            A_2 = T.Buffer((stride * n,), data=A_1.data, type="auto")
            B_2[(blockIdx_x * 64 + threadIdx_x) * stride_1] = A_2[
                (blockIdx_x * 64 + threadIdx_x) * stride
            ] * T.float32(2)
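
With both axes bound, the kernel can be compiled for a GPU target and invoked like an ordinary function. Below is a minimal usage sketch, assuming a CUDA-enabled TVM build and an attached GPU; the kernel name "mul2" is illustrative:

# Build and run the bound schedule on a GPU (skipped when no device is present).
if tvm.cuda().exist:
    fmul2 = tvm.build(s, [A, B], target="cuda", name="mul2")
    dev = tvm.cuda(0)
    a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), dev)
    b = tvm.nd.array(np.zeros(1024, dtype=B.dtype), dev)
    fmul2(a, b)  # launch dimensions come from the blockIdx.x/threadIdx.x binding
    np.testing.assert_allclose(b.numpy(), a.numpy() * 2)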

compute_at#

For a schedule composed of multiple operators, TVM will by default compute each tensor at the root, in its own loop nest.

A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda i: A[i] + 1, name="B")
C = te.compute((m,), lambda i: B[i] * 2, name="C")

s = te.create_schedule(C.op)
tvm.lower(s, [A, B, C], simple_mode=True).show()
# from tvm.script import ir as I
# from tvm.script import tir as T


@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle, C: T.handle):
        T.func_attr(
            {
                "from_legacy_te_schedule": True,
                "global_symbol": "main",
                "tir.noalias": True,
            }
        )
        m = T.int32()
        stride = T.int32()
        A_1 = T.match_buffer(A, (m,), strides=(stride,), type="auto")
        stride_1 = T.int32()
        B_1 = T.match_buffer(B, (m,), strides=(stride_1,), type="auto")
        stride_2 = T.int32()
        C_1 = T.match_buffer(C, (m,), strides=(stride_2,), type="auto")
        B_2 = T.Buffer((stride_1 * m,), data=B_1.data, type="auto")
        for i in range(m):
            A_2 = T.Buffer((stride * m,), data=A_1.data, type="auto")
            B_2[i * stride_1] = A_2[i * stride] + T.float32(1)
        for i in range(m):
            C_2 = T.Buffer((stride_2 * m,), data=C_1.data, type="auto")
            C_2[i * stride_2] = B_2[i * stride_1] * T.float32(2)

compute_at can move the computation of B into the first axis of the computation of C.

A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda i: A[i] + 1, name="B")
C = te.compute((m,), lambda i: B[i] * 2, name="C")

s = te.create_schedule(C.op)
s[B].compute_at(s[C], C.op.axis[0])
tvm.lower(s, [A, B, C], simple_mode=True).show()
# from tvm.script import ir as I
# from tvm.script import tir as T


@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle, C: T.handle):
        T.func_attr(
            {
                "from_legacy_te_schedule": True,
                "global_symbol": "main",
                "tir.noalias": True,
            }
        )
        m = T.int32()
        stride = T.int32()
        A_1 = T.match_buffer(A, (m,), strides=(stride,), type="auto")
        stride_1 = T.int32()
        B_1 = T.match_buffer(B, (m,), strides=(stride_1,), type="auto")
        stride_2 = T.int32()
        C_1 = T.match_buffer(C, (m,), strides=(stride_2,), type="auto")
        for i in range(m):
            B_2 = T.Buffer((stride_1 * m,), data=B_1.data, type="auto")
            A_2 = T.Buffer((stride * m,), data=A_1.data, type="auto")
            B_2[i * stride_1] = A_2[i * stride] + T.float32(1)
            C_2 = T.Buffer((stride_2 * m,), data=C_1.data, type="auto")
            C_2[i * stride_2] = B_2[i * stride_1] * T.float32(2)

compute_inline#

compute_inline can mark a stage as inline; the body of the computation is then expanded and inserted at every location where the tensor is required.

A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda i: A[i] + 1, name="B")
C = te.compute((m,), lambda i: B[i] * 2, name="C")

s = te.create_schedule(C.op)
s[B].compute_inline()
tvm.lower(s, [A, B, C], simple_mode=True).show()
# from tvm.script import ir as I
# from tvm.script import tir as T


@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle, C: T.handle):
        T.func_attr(
            {
                "from_legacy_te_schedule": True,
                "global_symbol": "main",
                "tir.noalias": True,
            }
        )
        m = T.int32()
        stride = T.int32()
        A_1 = T.match_buffer(A, (m,), strides=(stride,), type="auto")
        stride_1 = T.int32()
        B_1 = T.match_buffer(B, (m,), strides=(stride_1,), type="auto")
        stride_2 = T.int32()
        C_1 = T.match_buffer(C, (m,), strides=(stride_2,), type="auto")
        for i in range(m):
            C_2 = T.Buffer((stride_2 * m,), data=C_1.data, type="auto")
            A_2 = T.Buffer((stride * m,), data=A_1.data, type="auto")
            C_2[i * stride_2] = (A_2[i * stride] + T.float32(1)) * T.float32(2)

compute_root#

compute_root can move the computation of a stage back to the root.

A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda i: A[i] + 1, name="B")
C = te.compute((m,), lambda i: B[i] * 2, name="C")

s = te.create_schedule(C.op)
s[B].compute_at(s[C], C.op.axis[0])
s[B].compute_root()
tvm.lower(s, [A, B, C], simple_mode=True).show()
# from tvm.script import ir as I
# from tvm.script import tir as T


@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle, C: T.handle):
        T.func_attr(
            {
                "from_legacy_te_schedule": True,
                "global_symbol": "main",
                "tir.noalias": True,
            }
        )
        m = T.int32()
        stride = T.int32()
        A_1 = T.match_buffer(A, (m,), strides=(stride,), type="auto")
        stride_1 = T.int32()
        B_1 = T.match_buffer(B, (m,), strides=(stride_1,), type="auto")
        stride_2 = T.int32()
        C_1 = T.match_buffer(C, (m,), strides=(stride_2,), type="auto")
        B_2 = T.Buffer((stride_1 * m,), data=B_1.data, type="auto")
        for i in range(m):
            A_2 = T.Buffer((stride * m,), data=A_1.data, type="auto")
            B_2[i * stride_1] = A_2[i * stride] + T.float32(1)
        for i in range(m):
            C_2 = T.Buffer((stride_2 * m,), data=C_1.data, type="auto")
            C_2[i * stride_2] = B_2[i * stride_1] * T.float32(2)

Summary#

This tutorial has introduced the schedule primitives in TVM, which allow users to schedule a computation easily and flexibly.

In order to get a well-performing kernel implementation, the general workflow is often (an end-to-end sketch follows this list):

  • Describe your computation via a series of operations.

  • Try to schedule the computation with primitives.

  • Compile and run to see the performance difference.

  • Adjust your schedule according to the running result.
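
A minimal end-to-end sketch of this workflow, assuming an LLVM-enabled TVM build (the shape, split factor, and kernel name "mul2" are illustrative choices):

n = 1024
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] * 2, name="B")

# Steps 1 and 2: describe the computation, then schedule it with primitives.
s = te.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=64)

# Step 3: compile and run.
func = tvm.build(s, [A, B], target="llvm", name="mul2")
dev = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev)
b = tvm.nd.array(np.zeros(n, dtype=B.dtype), dev)
func(a, b)
np.testing.assert_allclose(b.numpy(), a.numpy() * 2)

# Step 4: measure the runtime, then adjust the schedule based on the result.
evaluator = func.time_evaluator(func.entry_name, dev, number=100)
print("mean time: %g s" % evaluator(a, b).mean)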