扫描和循环 Kernel#
原作者: Tianqi Chen
这是关于如何在 TVM 中进行循环计算的介绍材料。
循环计算是神经网络中的一种典型模式。
import tvm
from tvm import te
import numpy as np
TVM 支持用 scan 算子来描述符号循环。下面的 scan 算子计算 X 按列的累加和(cumsum)。
scan 在张量的最高维度上进行。s_state 是一个占位符,描述 scan 的变换状态。s_init 描述了如何初始化前 k 个时间步(timestep):这里由于 s_init 的第一个维度是 1,它描述了如何在第一个时间步初始化状态。s_update 描述了如何在时间步 t 更新值。更新时可以通过状态占位符引用前一个时间步的值。注意,在当前或后续的时间步引用 s_state 是无效的。
扫描包含状态占位符、初始值和更新描述。还建议(尽管不是必需的)列出 scan cell 的输入。
扫描的结果是张量,给出 s_state 在时域更新后的结果。
# Symbolic dimensions: m time steps (rows), n columns.
m = te.var("m")
n = te.var("n")
X = te.placeholder((m, n), name="X")
# Placeholder for the scan's carried state (one row per time step).
s_state = te.placeholder((m, n))
# Initialization: the state at the first time step is the first row of X.
s_init = te.compute((1, n), lambda _, i: X[0, i])
# Update: step t adds row t of X to the state of step t-1 (a column-wise cumsum).
s_update = te.compute((m, n), lambda t, i: s_state[t - 1, i] + X[t, i])
# The scan ties init, update, and state together; listing `inputs` is recommended.
s_scan = tvm.te.scan(s_init, s_update, s_state, inputs=[X])
调度 Scan Cell#
可以通过分别调度更新和初始化部分来调度扫描主体(body)。注意,调度更新部分的第一个迭代维度是无效的。要在时间迭代上进行分割,用户可以使用 scan_op.scan_axis 代替。
# Schedule the scan cell: init and update are scheduled separately but
# identically — split the data axis and bind it to CUDA blocks/threads.
s = te.create_schedule(s_scan.op)
num_thread = 256
block_x = te.thread_axis("blockIdx.x")
thread_x = te.thread_axis("threadIdx.x")
for stage in (s_init, s_update):
    # axis[1] is the data axis; axis[0] (the time axis) must not be scheduled here.
    outer, inner = s[stage].split(stage.op.axis[1], factor=num_thread)
    s[stage].bind(outer, block_x)
    s[stage].bind(inner, thread_x)
tvm.lower(s, [X, s_scan], simple_mode=True).show()
# NOTE(review): the block below is the lowered TIR printed by the .show() call
# above — reference output reproduced in the document, not code meant to be run.
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(X: T.handle, scan: T.handle):
T.func_attr(
{
"from_legacy_te_schedule": True,
"global_symbol": "main",
"tir.noalias": True,
}
)
m = T.int32()
n = T.int32()
stride = T.int32()
stride_1 = T.int32()
X_1 = T.match_buffer(X, (m, n), strides=(stride, stride_1), type="auto")
stride_2 = T.int32()
stride_3 = T.int32()
scan_1 = T.match_buffer(scan, (m, n), strides=(stride_2, stride_3), type="auto")
blockIdx_x = T.env_thread("blockIdx.x")
threadIdx_x = T.env_thread("threadIdx.x")
scan_2 = T.Buffer((stride_2 * m,), data=scan_1.data, type="auto")
X_2 = T.Buffer((stride * m,), data=X_1.data, type="auto")
with T.launch_thread(blockIdx_x, (n + 255) // 256):
T.launch_thread(threadIdx_x, 256)
if T.likely(blockIdx_x * 256 + threadIdx_x < n):
scan_2[(blockIdx_x * 256 + threadIdx_x) * stride_3] = X_2[
(blockIdx_x * 256 + threadIdx_x) * stride_1
]
for scan_idx in range(m - 1):
T.launch_thread(blockIdx_x, (n + 255) // 256)
T.launch_thread(threadIdx_x, 256)
if T.likely(blockIdx_x * 256 + threadIdx_x < n):
cse_var_1: T.int32 = scan_idx + 1
scan_2[
cse_var_1 * stride_2 + (blockIdx_x * 256 + threadIdx_x) * stride_3
] = (
scan_2[
scan_idx * stride_2
+ (blockIdx_x * 256 + threadIdx_x) * stride_3
]
+ X_2[
cse_var_1 * stride + (blockIdx_x * 256 + threadIdx_x) * stride_1
]
)
构建并验证#
可以像其他 TVM 内核一样构建扫描内核,这里使用 numpy 来验证结果的正确性。
# Build the scan kernel for CUDA and check it against NumPy's cumsum.
fscan = tvm.build(s, [X, s_scan], "cuda", name="myscan")
dev = tvm.cuda(0)
n, m = 1024, 10
a_np = np.random.uniform(size=(m, n)).astype(s_scan.dtype)
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(np.zeros_like(a_np), dev)
fscan(a, b)
# The scan is a column-wise running sum, i.e. cumsum along axis 0.
np.testing.assert_allclose(b.numpy(), np.cumsum(a_np, axis=0))
Multi-Stage Scan Cell#
In the above example we described the scan cell using one Tensor computation stage in s_update. It is possible to use multiple Tensor stages in the scan cell.
The following lines demonstrate a scan with two stage operations in the scan cell.
# Multi-stage scan cell: the update is described with two chained compute stages.
m = te.var("m")
n = te.var("n")
X = te.placeholder((m, n), name="X")
# Carried state of the scan (one row per time step).
s_state = te.placeholder((m, n))
# First time step is initialized from the first row of X.
s_init = te.compute((1, n), lambda _, i: X[0, i])
# Stage 1: double the previous step's state.
s_update_s1 = te.compute((m, n), lambda t, i: s_state[t - 1, i] * 2, name="s1")
# Stage 2: add the current row of X; this final stage is passed to scan as the update.
s_update_s2 = te.compute((m, n), lambda t, i: s_update_s1[t, i] + X[t, i], name="s2")
s_scan = tvm.te.scan(s_init, s_update_s2, s_state, inputs=[X])
These intermediate tensors can also be scheduled normally. To ensure correctness, TVM creates a group constraint to forbid the body of scan to be compute_at locations outside the scan loop.
# Schedule the two-stage scan cell: split the update's data axis, then nest
# stage s1 at the outer split point of stage s2.
s = te.create_schedule(s_scan.op)
outer, inner = s[s_update_s2].split(s_update_s2.op.axis[1], factor=32)
# A group constraint keeps the scan body from moving outside the scan loop.
s[s_update_s1].compute_at(s[s_update_s2], outer)
print(tvm.lower(s, [X, s_scan], simple_mode=True))
Multiple States#
For complicated applications like RNN, we might need more than one recurrent state. Scan supports multiple recurrent states. The following example demonstrates how we can build a recurrence with two states.
# Recurrence with two coupled states (the pattern needed for e.g. RNNs).
m, n, l = te.var("m"), te.var("n"), te.var("l")
X = te.placeholder((m, n), name="X")
s_state1 = te.placeholder((m, n))
s_state2 = te.placeholder((m, l))
# State 1 starts from the first row of X; state 2 starts from zeros.
s_init1 = te.compute((1, n), lambda _, j: X[0, j])
s_init2 = te.compute((1, l), lambda _, j: 0.0)
# State 1 accumulates X row by row; state 2 also reads state 1's previous step.
s_update1 = te.compute((m, n), lambda t, j: s_state1[t - 1, j] + X[t, j])
s_update2 = te.compute((m, l), lambda t, j: s_state2[t - 1, j] + s_state1[t - 1, 0])
# Inits, updates, and states are passed as parallel lists; scan returns one
# result tensor per state.
s_scan1, s_scan2 = tvm.te.scan(
    [s_init1, s_init2], [s_update1, s_update2], [s_state1, s_state2], inputs=[X]
)
s = te.create_schedule(s_scan1.op)
print(tvm.lower(s, [X, s_scan1, s_scan2], simple_mode=True))
Summary#
This tutorial provides a walk-through of the scan primitive.
Describe scan with init and update.
Schedule the scan cells as normal schedule.
For complicated workload, use multiple states and steps in scan cell.