Scan and Recurrent Kernel#

Author: Tianqi Chen

This is an introduction to how to do recurrent computing in TVM.

Recurrent computing is a typical pattern in neural networks.

import tvm
from tvm import te
import numpy as np

TVM supports a scan operator to describe symbolic loops. The following scan op computes the cumsum over the columns of X.

The scan is carried over the highest dimension of the tensor. s_state is a placeholder that describes the transition state of the scan. s_init describes how to initialize the first k timesteps; here, since s_init's first dimension is 1, it describes how to initialize the state at the first timestep.

s_update describes how to update the value at timestep t. The update value can refer back to the values of the previous timestep via the state placeholder. Note that it is invalid to refer to s_state at the current or a later timestep.

The scan takes the state placeholder, the initial value, and the update description. It is also recommended (though not necessary) to list the inputs of the scan cell. The result of the scan is a tensor, giving the result of s_state after the updates over the time domain.

m = te.var("m")
n = te.var("n")
X = te.placeholder((m, n), name="X")
s_state = te.placeholder((m, n))
s_init = te.compute((1, n), lambda _, i: X[0, i])
s_update = te.compute((m, n), lambda t, i: s_state[t - 1, i] + X[t, i])
s_scan = tvm.te.scan(s_init, s_update, s_state, inputs=[X])
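
As a reference, the recurrence this scan describes can be written directly in NumPy. The sketch below is illustrative only; scan_cumsum_ref is not part of the TVM API.

def scan_cumsum_ref(x):
    """NumPy reference: state[0] = x[0]; state[t] = state[t-1] + x[t]."""
    state = np.empty_like(x)
    state[0] = x[0]  # corresponds to s_init: initialize the first timestep
    for t in range(1, x.shape[0]):
        state[t] = state[t - 1] + x[t]  # corresponds to s_update at timestep t
    return state  # equals np.cumsum(x, axis=0)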

Schedule the Scan Cell#

We can schedule the body of the scan by scheduling the update and init parts separately. Note that it is invalid to schedule the first iteration dimension of the update part. To split on the time iteration, users can schedule on scan_op.scan_axis instead (see the sketch after the lowered output below).

s = te.create_schedule(s_scan.op)
num_thread = 256
block_x = te.thread_axis("blockIdx.x")
thread_x = te.thread_axis("threadIdx.x")
xo, xi = s[s_init].split(s_init.op.axis[1], factor=num_thread)
s[s_init].bind(xo, block_x)
s[s_init].bind(xi, thread_x)
xo, xi = s[s_update].split(s_update.op.axis[1], factor=num_thread)
s[s_update].bind(xo, block_x)
s[s_update].bind(xi, thread_x)
tvm.lower(s, [X, s_scan], simple_mode=True).show()
# from tvm.script import ir as I
# from tvm.script import tir as T


@I.ir_module
class Module:
    @T.prim_func
    def main(X: T.handle, scan: T.handle):
        T.func_attr(
            {
                "from_legacy_te_schedule": True,
                "global_symbol": "main",
                "tir.noalias": True,
            }
        )
        m = T.int32()
        n = T.int32()
        stride = T.int32()
        stride_1 = T.int32()
        X_1 = T.match_buffer(X, (m, n), strides=(stride, stride_1), type="auto")
        stride_2 = T.int32()
        stride_3 = T.int32()
        scan_1 = T.match_buffer(scan, (m, n), strides=(stride_2, stride_3), type="auto")
        blockIdx_x = T.env_thread("blockIdx.x")
        threadIdx_x = T.env_thread("threadIdx.x")
        scan_2 = T.Buffer((stride_2 * m,), data=scan_1.data, type="auto")
        X_2 = T.Buffer((stride * m,), data=X_1.data, type="auto")
        with T.launch_thread(blockIdx_x, (n + 255) // 256):
            T.launch_thread(threadIdx_x, 256)
            if T.likely(blockIdx_x * 256 + threadIdx_x < n):
                scan_2[(blockIdx_x * 256 + threadIdx_x) * stride_3] = X_2[
                    (blockIdx_x * 256 + threadIdx_x) * stride_1
                ]
        for scan_idx in range(m - 1):
            T.launch_thread(blockIdx_x, (n + 255) // 256)
            T.launch_thread(threadIdx_x, 256)
            if T.likely(blockIdx_x * 256 + threadIdx_x < n):
                cse_var_1: T.int32 = scan_idx + 1
                scan_2[
                    cse_var_1 * stride_2 + (blockIdx_x * 256 + threadIdx_x) * stride_3
                ] = (
                    scan_2[
                        scan_idx * stride_2
                        + (blockIdx_x * 256 + threadIdx_x) * stride_3
                    ]
                    + X_2[
                        cse_var_1 * stride + (blockIdx_x * 256 + threadIdx_x) * stride_1
                    ]
                )
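
As mentioned above, the first axis of s_update is invalid to schedule; the sequential time axis lives on the scan op itself. A minimal sketch of obtaining that handle (no extra scheduling is applied here):

# scan_axis is the IterVar over the m timesteps; time-dimension scheduling
# goes through this handle rather than s_update.op.axis[0].
time_axis = s_scan.op.scan_axis
print(time_axis)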

Build and Verify#

We can build the scan kernel like other TVM kernels; here we use numpy to verify the correctness of the result.

fscan = tvm.build(s, [X, s_scan], "cuda", name="myscan")
dev = tvm.cuda(0)
n = 1024
m = 10
a_np = np.random.uniform(size=(m, n)).astype(s_scan.dtype)
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(np.zeros((m, n), dtype=s_scan.dtype), dev)
fscan(a, b)
np.testing.assert_allclose(b.numpy(), np.cumsum(a_np, axis=0))
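
If no CUDA device is available, the same computation can be verified on the CPU with a fresh default schedule (dropping the GPU thread bindings). A minimal sketch, assuming the llvm target is enabled in your TVM build:

# CPU fallback: default schedule, no thread bindings needed.
s_cpu = te.create_schedule(s_scan.op)
fscan_cpu = tvm.build(s_cpu, [X, s_scan], "llvm", name="myscan_cpu")
dev_cpu = tvm.cpu(0)
a_cpu = tvm.nd.array(a_np, dev_cpu)
b_cpu = tvm.nd.array(np.zeros((m, n), dtype=s_scan.dtype), dev_cpu)
fscan_cpu(a_cpu, b_cpu)
np.testing.assert_allclose(b_cpu.numpy(), np.cumsum(a_np, axis=0))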

Multi-Stage Scan Cell#

In the above example we described the scan cell using one Tensor computation stage in s_update. It is possible to use multiple Tensor stages in the scan cell.

The following lines demonstrate a scan with two stages in the scan cell.

m = te.var("m")
n = te.var("n")
X = te.placeholder((m, n), name="X")
s_state = te.placeholder((m, n))
s_init = te.compute((1, n), lambda _, i: X[0, i])
s_update_s1 = te.compute((m, n), lambda t, i: s_state[t - 1, i] * 2, name="s1")
s_update_s2 = te.compute((m, n), lambda t, i: s_update_s1[t, i] + X[t, i], name="s2")
s_scan = tvm.te.scan(s_init, s_update_s2, s_state, inputs=[X])

These intermediate tensors can also be scheduled normally. To ensure correctness, TVM creates a group constraint that forbids the body of the scan from being computed (via compute_at) at locations outside the scan loop.

s = te.create_schedule(s_scan.op)
xo, xi = s[s_update_s2].split(s_update_s2.op.axis[1], factor=32)
s[s_update_s1].compute_at(s[s_update_s2], xo)
print(tvm.lower(s, [X, s_scan], simple_mode=True))
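
The two-stage cell implements state[t] = 2 * state[t-1] + X[t]. The kernel can be checked against a NumPy loop in the same way as before; a minimal sketch, assuming the llvm target is available (variable names are illustrative):

fscan2 = tvm.build(s, [X, s_scan], "llvm", name="myscan2")
dev = tvm.cpu(0)
m_val, n_val = 10, 64
a_np = np.random.uniform(size=(m_val, n_val)).astype(s_scan.dtype)
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(np.zeros((m_val, n_val), dtype=s_scan.dtype), dev)
fscan2(a, b)
ref = np.zeros_like(a_np)
ref[0] = a_np[0]  # s_init
for t in range(1, m_val):
    ref[t] = 2.0 * ref[t - 1] + a_np[t]  # s1 doubles the state, s2 adds X[t]
np.testing.assert_allclose(b.numpy(), ref, rtol=1e-5)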

Multiple States#

For complicated applications like RNNs, we might need more than one recurrent state. Scan supports multiple recurrent states. The following example demonstrates how we can build a recurrence with two states.

m = te.var("m")
n = te.var("n")
l = te.var("l")
X = te.placeholder((m, n), name="X")
s_state1 = te.placeholder((m, n))
s_state2 = te.placeholder((m, l))
s_init1 = te.compute((1, n), lambda _, i: X[0, i])
s_init2 = te.compute((1, l), lambda _, i: 0.0)
s_update1 = te.compute((m, n), lambda t, i: s_state1[t - 1, i] + X[t, i])
s_update2 = te.compute((m, l), lambda t, i: s_state2[t - 1, i] + s_state1[t - 1, 0])
s_scan1, s_scan2 = tvm.te.scan(
    [s_init1, s_init2], [s_update1, s_update2], [s_state1, s_state2], inputs=[X]
)
s = te.create_schedule(s_scan1.op)
print(tvm.lower(s, [X, s_scan1, s_scan2], simple_mode=True))
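
To make the semantics concrete: s_scan1 is the column-wise cumsum of X, while s_scan2 starts at zero and accumulates the first column of state1 from the previous timestep. A minimal verification sketch, assuming the llvm target (names are illustrative):

fscan = tvm.build(s, [X, s_scan1, s_scan2], "llvm", name="multistate")
dev = tvm.cpu(0)
m_val, n_val, l_val = 10, 16, 8
a_np = np.random.uniform(size=(m_val, n_val)).astype("float32")
a = tvm.nd.array(a_np, dev)
b1 = tvm.nd.array(np.zeros((m_val, n_val), dtype="float32"), dev)
b2 = tvm.nd.array(np.zeros((m_val, l_val), dtype="float32"), dev)
fscan(a, b1, b2)
ref1 = np.cumsum(a_np, axis=0)                    # state1: cumsum of X
ref2 = np.zeros((m_val, l_val), dtype="float32")  # state2: initialized to 0
for t in range(1, m_val):
    ref2[t] = ref2[t - 1] + ref1[t - 1, 0]        # reads state1 at t-1
np.testing.assert_allclose(b1.numpy(), ref1, rtol=1e-5)
np.testing.assert_allclose(b2.numpy(), ref2, rtol=1e-5)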

Summary#

This tutorial provides a walk-through of the scan primitive.

  • Describe scan with init and update.

  • Schedule the scan cells like normal schedules.

  • For complicated workloads, use multiple states and steps in the scan cell.