使用自动调度优化运算#

作者: Lianmin ZhengChengfan Jia

在本教程中,将展示 TVM 的自动调度功能如何在不需要编写自定义模板的情况下找到最佳调度。

与基于模板的 AutoTVM 不同,后者依赖于手动模板来定义搜索空间,而自动调度器不需要任何模板。

用户只需要编写计算声明,而不需要任何调度命令或模板。自动调度器可以自动生成大的搜索空间,并在空间中找到好的调度。

本教程中以矩阵乘法为例。

import numpy as np
import tvm
from tvm import te, auto_scheduler

定义矩阵乘法#

首先,定义带有偏置加法的矩阵乘法。注意,这使用了 TVM 张量表达式语言中的标准运算。主要的区别是在函数定义的顶部使用了 tvm.auto_scheduler.register_workload() 装饰器。该函数应该返回输入/输出张量的列表。从这些张量中,自动调度器可以得到整个计算图。

@auto_scheduler.register_workload  # 注意 auto_scheduler 装饰器
def matmul_add(N, L, M, dtype):
    A = te.placeholder((N, L), name="A", dtype=dtype)
    B = te.placeholder((L, M), name="B", dtype=dtype)
    C = te.placeholder((N, M), name="C", dtype=dtype)

    k = te.reduce_axis((0, L), name="k")
    matmul = te.compute(
        (N, M),
        lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),
        name="matmul",
        attrs={"layout_free_placeholders": [B]},  # 启用张量 B 的自动布局转换
    )
    out = te.compute((N, M), lambda i, j: matmul[i, j] + C[i, j], name="out")
    return [A, B, C, out]

创建搜索任务#

在定义了函数之后,现在可以为 auto_scheduler 创建任务来进行搜索。指定矩阵乘法的特殊参数,在这个例子中,是对 \(1024 \times 1024\) 大小的正方形矩阵的乘法。然后使用 N=L=M=1024dtype="float32" 创建搜索任务。

用自定义目标提高性能

为了使 TVM 能够充分利用特定的硬件平台,手动指定你的 CPU 能力。例如:

  • llvm -mcpu=core-avx2 替换下面的 llvm,以启用 AVX2

  • llvm -mcpu=skylake-avx512 替换下面的 llvm,以启用 AVX-512

target = tvm.target.Target("llvm")
N = L = M = 1024
task = tvm.auto_scheduler.SearchTask(func=matmul_add, args=(N, L, M, "float32"), target=target)

# 检查计算图
print("Computational DAG:")
print(task.compute_dag)
Computational DAG:
A = PLACEHOLDER [1024, 1024]
B = PLACEHOLDER [1024, 1024]
matmul(i, j) += (A[i, k]*B[k, j])
C = PLACEHOLDER [1024, 1024]
out(i, j) = (matmul[i, j] + C[i, j])

为自动调度设置参数#

下一步,为自动调度设置参数。

  • num_measure_trials 是在搜索过程中可以使用的测量试验的数量。为了快速演示,在本教程中只做了 10 次试验。在实践中,1000 是个很好的搜索收敛值。你可以根据你的时间预算做更多的试验。

  • 此外,使用 RecordToFile 来 log 测量记录到 matmul.json 文件中。这些测量记录可以用来查询历史最好的,恢复搜索,并在以后做更多的分析。

  • 查阅 TuningOptions 了解参数的更多信息。

log_file = "matmul.json"
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=10,
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    verbose=2,
)

运行搜索#

现在把所有的输入准备好。很简单,不是吗?可以启动搜索,让自动调度发挥它的魔力。经过一些测量试验后,可以从日志文件中加载最佳调度并加以应用。

# 运行 auto-tuning (search)
task.tune(tune_option)
# 应用最优 schedule
sch, args = task.apply_best(log_file)
----------------------------------------------------------------------
------------------------------  [ Search ]
----------------------------------------------------------------------
Generate Sketches		#s: 3
Sample Initial Population	#s: 2012	fail_ct: 4	Time elapsed: 2.99
GA Iter: 0	Max score: 0.9999	Min score: 0.9356	#Pop: 128	#M+: 0	#M-: 0
GA Iter: 4	Max score: 1.0000	Min score: 0.9877	#Pop: 128	#M+: 1384	#M-: 79
EvolutionarySearch		#s: 128	Time elapsed: 12.76
----------------------------------------------------------------------
------------------------------  [ Measure ]
----------------------------------------------------------------------
Get 10 programs to measure:
..........**********
==================================================
No: 1	GFLOPS: 125.83 / 125.83	results: MeasureResult(cost:[0.0171], error_no:0, all_cost:0.53, Tstamp:1679472852.34)
==================================================
Placeholder: A, B, C
parallel i.0@j.0@i.1@ (0,131072)
  matmul auto_unroll: 64
  for k.0 (0,128)
    for i.2 (0,4)
      for k.1 (0,8)
        for i.3 (0,2)
          matmul = ...
  for i.2 (0,8)
    out = ...

==================================================
No: 2	GFLOPS: 6.94 / 125.83	results: MeasureResult(cost:[0.3098], error_no:0, all_cost:2.33, Tstamp:1679472853.71)
==================================================
Placeholder: A, B, C
parallel i.0@j.0@ (0,2)
  matmul auto_unroll: 64
  for i.1 (0,4)
    for j.1 (0,256)
      for k.0 (0,128)
        for i.2 (0,4)
          for k.1 (0,8)
            for i.3 (0,32)
              vectorize j.3 (0,4)
                matmul = ...
  for i.1 (0,512)
    for j.1 (0,1024)
      out = ...

==================================================
No: 3	GFLOPS: 256.41 / 256.41	results: MeasureResult(cost:[0.0084], error_no:0, all_cost:0.69, Tstamp:1679472854.15)
==================================================
Placeholder: A, B, C
parallel i.0@j.0@ (0,8192)
  for i.1 (0,2)
    matmul auto_unroll: 64
    for k.0 (0,512)
      for k.1 (0,2)
        for i.3 (0,4)
          vectorize j.3 (0,16)
            matmul = ...
    for i.2 (0,4)
      vectorize j.2 (0,16)
        out = ...

==================================================
No: 4	GFLOPS: 82.04 / 256.41	results: MeasureResult(cost:[0.0262], error_no:0, all_cost:0.70, Tstamp:1679472854.52)
==================================================
Placeholder: A, B, C
parallel i.0@j.0@ (0,16)
  matmul auto_unroll: 64
  for i.1 (0,4)
    for j.1 (0,8)
      for k.0 (0,32)
        for j.2 (0,16)
          for k.1 (0,32)
            for i.3 (0,32)
              vectorize j.3 (0,4)
                matmul = ...
  for i.1 (0,128)
    for j.1 (0,512)
      out = ...

==================================================
No: 5	GFLOPS: 250.11 / 256.41	results: MeasureResult(cost:[0.0086], error_no:0, all_cost:0.83, Tstamp:1679472855.12)
==================================================
Placeholder: A, B, C
parallel i.0@j.0@ (0,1024)
  for j.1 (0,4)
    for k.0 (0,256)
      for i.2 (0,8)
        for j.2 (0,4)
          for k.1 (0,4)
            for i.3 (0,2)
              vectorize j.3 (0,4)
                matmul = ...
    for i.2 (0,16)
      vectorize j.2 (0,16)
        out = ...

==================================================
No: 6	GFLOPS: 160.68 / 256.41	results: MeasureResult(cost:[0.0134], error_no:0, all_cost:0.61, Tstamp:1679472855.42)
==================================================
Placeholder: A, B, C
matmul auto_unroll: 16
parallel i.0@j.0@i.1@j.1@ (0,8192)
  for k.0 (0,256)
    for i.2 (0,2)
      for k.1 (0,4)
        for i.3 (0,2)
          for j.3 (0,32)
            matmul = ...
parallel i (0,1024)
  for j (0,1024)
    out = ...

==================================================
No: 7	GFLOPS: 48.42 / 256.41	results: MeasureResult(cost:[0.0444], error_no:0, all_cost:0.81, Tstamp:1679472855.74)
==================================================
Placeholder: A, B, C
parallel i.0@j.0@ (0,128)
  matmul auto_unroll: 512
  for i.1 (0,8)
    for k.0 (0,32)
      for i.2 (0,32)
        for j.2 (0,32)
          for k.1 (0,32)
            matmul = ...
  for i.1 (0,256)
    for j.1 (0,32)
      out = ...

==================================================
No: 8	GFLOPS: 20.51 / 256.41	results: MeasureResult(cost:[0.1047], error_no:0, all_cost:0.69, Tstamp:1679472856.31)
==================================================
Placeholder: A, B, C
parallel i.0@j.0@i.1@j.1@ (0,128)
  for k.0 (0,64)
    for j.2 (0,1024)
      for k.1 (0,16)
        for i.3 (0,8)
          matmul = ...
parallel i (0,1024)
  for j (0,1024)
    out = ...

==================================================
No: 9	GFLOPS: 90.08 / 256.41	results: MeasureResult(cost:[0.0239], error_no:0, all_cost:0.52, Tstamp:1679472856.66)
==================================================
Placeholder: A, B, C
matmul auto_unroll: 16
parallel i.0@j.0@i.1@ (0,4096)
  for j.1 (0,8)
    for k.0 (0,32)
      for i.2 (0,4)
        for k.1 (0,32)
          for i.3 (0,8)
            matmul = ...
parallel i (0,1024)
  for j (0,1024)
    out = ...

==================================================
No: 10	GFLOPS: 4.71 / 256.41	results: MeasureResult(cost:[0.4557], error_no:0, all_cost:2.05, Tstamp:1679472858.62)
==================================================
Placeholder: A, B, C
matmul auto_unroll: 64
parallel i.0@j.0@i.1@j.1@ (0,2048)
  for k.0 (0,1024)
    for i.2 (0,512)
      matmul = ...
parallel i (0,1024)
  for j (0,1024)
    out = ...

Time elapsed for measurement: 11.32 s
----------------------------------------------------------------------
------------------------------  [ Done ]
----------------------------------------------------------------------

检查优化后的调度#

可以 lower 调度,看看自动调度后的 IR。自动调度器正确地进行了优化,包括多级平铺(tiling)、布局转换(layout transformation)、并行化(parallelization)、矢量化(vectorization)、解卷(unrolling)和运算符融合(operator fusion)。

mod = tvm.lower(sch, args, simple_mode=True)
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T


@I.ir_module
class Module:
    @T.prim_func
    def main(
        A: T.Buffer((1024, 1024), "float32"),
        B: T.Buffer((1024, 1024), "float32"),
        C: T.Buffer((1024, 1024), "float32"),
        out: T.Buffer((1024, 1024), "float32"),
    ):
        T.func_attr(
            {
                "from_legacy_te_schedule": True,
                "global_symbol": "main",
                "tir.noalias": True,
            }
        )
        auto_scheduler_layout_transform = T.allocate([1048576], "float32", "global")
        auto_scheduler_layout_transform_1 = T.Buffer(
            (1048576,), data=auto_scheduler_layout_transform
        )
        for ax0_ax1_fused_ax2_fused in T.parallel(256):
            for ax4, ax6, ax7 in T.grid(512, 2, 4):
                B_1 = T.Buffer((1048576,), data=B.data)
                auto_scheduler_layout_transform_1[
                    ax0_ax1_fused_ax2_fused * 4096 + ax4 * 8 + ax6 * 4 + ax7
                ] = B_1[ax4 * 2048 + ax6 * 1024 + ax0_ax1_fused_ax2_fused * 4 + ax7]
        for i_outer_outer_j_outer_outer_fused in T.parallel(2048):
            matmul = T.allocate([32], "float32", "global")
            for j_outer_inner in range(16):
                matmul_1 = T.Buffer((32,), data=matmul)
                matmul_1[0:4] = T.Broadcast(T.float32(0), 4)
                matmul_1[4:8] = T.Broadcast(T.float32(0), 4)
                matmul_1[8:12] = T.Broadcast(T.float32(0), 4)
                matmul_1[12:16] = T.Broadcast(T.float32(0), 4)
                matmul_1[16:20] = T.Broadcast(T.float32(0), 4)
                matmul_1[20:24] = T.Broadcast(T.float32(0), 4)
                matmul_1[24:28] = T.Broadcast(T.float32(0), 4)
                matmul_1[28:32] = T.Broadcast(T.float32(0), 4)
                for k_outer in range(512):
                    cse_var_3: T.int32 = (
                        i_outer_outer_j_outer_outer_fused // 16 * 8192 + k_outer * 2
                    )
                    cse_var_2: T.int32 = (
                        i_outer_outer_j_outer_outer_fused % 16 * 65536
                        + j_outer_inner * 4096
                        + k_outer * 8
                    )
                    cse_var_1: T.int32 = cse_var_2 + 4
                    A_1 = T.Buffer((1048576,), data=A.data)
                    matmul_1[0:4] = (
                        matmul_1[0:4]
                        + T.Broadcast(A_1[cse_var_3], 4)
                        * auto_scheduler_layout_transform_1[cse_var_2 : cse_var_2 + 4]
                    )
                    matmul_1[0:4] = (
                        matmul_1[0:4]
                        + T.Broadcast(A_1[cse_var_3 + 1], 4)
                        * auto_scheduler_layout_transform_1[cse_var_1 : cse_var_1 + 4]
                    )
                    matmul_1[4:8] = (
                        matmul_1[4:8]
                        + T.Broadcast(A_1[cse_var_3 + 1024], 4)
                        * auto_scheduler_layout_transform_1[cse_var_2 : cse_var_2 + 4]
                    )
                    matmul_1[4:8] = (
                        matmul_1[4:8]
                        + T.Broadcast(A_1[cse_var_3 + 1025], 4)
                        * auto_scheduler_layout_transform_1[cse_var_1 : cse_var_1 + 4]
                    )
                    matmul_1[8:12] = (
                        matmul_1[8:12]
                        + T.Broadcast(A_1[cse_var_3 + 2048], 4)
                        * auto_scheduler_layout_transform_1[cse_var_2 : cse_var_2 + 4]
                    )
                    matmul_1[8:12] = (
                        matmul_1[8:12]
                        + T.Broadcast(A_1[cse_var_3 + 2049], 4)
                        * auto_scheduler_layout_transform_1[cse_var_1 : cse_var_1 + 4]
                    )
                    matmul_1[12:16] = (
                        matmul_1[12:16]
                        + T.Broadcast(A_1[cse_var_3 + 3072], 4)
                        * auto_scheduler_layout_transform_1[cse_var_2 : cse_var_2 + 4]
                    )
                    matmul_1[12:16] = (
                        matmul_1[12:16]
                        + T.Broadcast(A_1[cse_var_3 + 3073], 4)
                        * auto_scheduler_layout_transform_1[cse_var_1 : cse_var_1 + 4]
                    )
                    matmul_1[16:20] = (
                        matmul_1[16:20]
                        + T.Broadcast(A_1[cse_var_3 + 4096], 4)
                        * auto_scheduler_layout_transform_1[cse_var_2 : cse_var_2 + 4]
                    )
                    matmul_1[16:20] = (
                        matmul_1[16:20]
                        + T.Broadcast(A_1[cse_var_3 + 4097], 4)
                        * auto_scheduler_layout_transform_1[cse_var_1 : cse_var_1 + 4]
                    )
                    matmul_1[20:24] = (
                        matmul_1[20:24]
                        + T.Broadcast(A_1[cse_var_3 + 5120], 4)
                        * auto_scheduler_layout_transform_1[cse_var_2 : cse_var_2 + 4]
                    )
                    matmul_1[20:24] = (
                        matmul_1[20:24]
                        + T.Broadcast(A_1[cse_var_3 + 5121], 4)
                        * auto_scheduler_layout_transform_1[cse_var_1 : cse_var_1 + 4]
                    )
                    matmul_1[24:28] = (
                        matmul_1[24:28]
                        + T.Broadcast(A_1[cse_var_3 + 6144], 4)
                        * auto_scheduler_layout_transform_1[cse_var_2 : cse_var_2 + 4]
                    )
                    matmul_1[24:28] = (
                        matmul_1[24:28]
                        + T.Broadcast(A_1[cse_var_3 + 6145], 4)
                        * auto_scheduler_layout_transform_1[cse_var_1 : cse_var_1 + 4]
                    )
                    matmul_1[28:32] = (
                        matmul_1[28:32]
                        + T.Broadcast(A_1[cse_var_3 + 7168], 4)
                        * auto_scheduler_layout_transform_1[cse_var_2 : cse_var_2 + 4]
                    )
                    matmul_1[28:32] = (
                        matmul_1[28:32]
                        + T.Broadcast(A_1[cse_var_3 + 7169], 4)
                        * auto_scheduler_layout_transform_1[cse_var_1 : cse_var_1 + 4]
                    )
                for i_inner, j_inner in T.grid(8, 4):
                    cse_var_4: T.int32 = (
                        i_outer_outer_j_outer_outer_fused // 16 * 8192
                        + i_inner * 1024
                        + i_outer_outer_j_outer_outer_fused % 16 * 64
                        + j_outer_inner * 4
                        + j_inner
                    )
                    out_1 = T.Buffer((1048576,), data=out.data)
                    C_1 = T.Buffer((1048576,), data=C.data)
                    out_1[cse_var_4] = matmul_1[i_inner * 4 + j_inner] + C_1[cse_var_4]

检查正确性并评估性能#

建立二进制文件,并检查其正确性(correctness)和性能(performance)。

func = tvm.build(sch, args, target)
a_np = np.random.uniform(size=(N, L)).astype(np.float32)
b_np = np.random.uniform(size=(L, M)).astype(np.float32)
c_np = np.random.uniform(size=(N, M)).astype(np.float32)
out_np = a_np.dot(b_np) + c_np

dev = tvm.cpu()
a_tvm = tvm.nd.array(a_np, device=dev)
b_tvm = tvm.nd.array(b_np, device=dev)
c_tvm = tvm.nd.array(c_np, device=dev)
out_tvm = tvm.nd.empty(out_np.shape, device=dev)
func(a_tvm, b_tvm, c_tvm, out_tvm)

# Check results
np.testing.assert_allclose(out_np, out_tvm.numpy(), rtol=1e-3)

# Evaluate execution time.
evaluator = func.time_evaluator(func.entry_name, dev, min_repeat_ms=500)
print(
    "Execution time of this operator: %.3f ms"
    % (np.median(evaluator(a_tvm, b_tvm, c_tvm, out_tvm).results) * 1000)
)
Execution time of this operator: 5.362 ms

使用纪录文件#

在搜索过程中,所有的测量记录都被 log 到记录文件 matmul.json。这些测量记录可以用来重新应用搜索结果,恢复搜索,并进行其他分析。

这里有一个例子,我们从一个文件中加载最佳调度,并打印出等效的 python 调度 API。这可以用于调试和学习自动调度的行为。

print("Equivalent python schedule:")
print(task.print_best(log_file))
Equivalent python schedule:
matmul_i, matmul_j, matmul_k = tuple(matmul.op.axis) + tuple(matmul.op.reduce_axis)
out_i, out_j = tuple(out.op.axis) + tuple(out.op.reduce_axis)
matmul_i_o_i, matmul_i_i = s[matmul].split(matmul_i, factor=1)
matmul_i_o_o_i, matmul_i_o_i = s[matmul].split(matmul_i_o_i, factor=8)
matmul_i_o_o_o, matmul_i_o_o_i = s[matmul].split(matmul_i_o_o_i, factor=1)
matmul_j_o_i, matmul_j_i = s[matmul].split(matmul_j, factor=4)
matmul_j_o_o_i, matmul_j_o_i = s[matmul].split(matmul_j_o_i, factor=1)
matmul_j_o_o_o, matmul_j_o_o_i = s[matmul].split(matmul_j_o_o_i, factor=16)
matmul_k_o, matmul_k_i = s[matmul].split(matmul_k, factor=2)
s[matmul].reorder(matmul_i_o_o_o, matmul_j_o_o_o, matmul_i_o_o_i, matmul_j_o_o_i, matmul_k_o, matmul_i_o_i, matmul_j_o_i, matmul_k_i, matmul_i_i, matmul_j_i)
out_i_o_i, out_i_i = s[out].split(out_i, factor=8)
out_i_o_o, out_i_o_i = s[out].split(out_i_o_i, factor=1)
out_j_o_i, out_j_i = s[out].split(out_j, factor=4)
out_j_o_o, out_j_o_i = s[out].split(out_j_o_i, factor=16)
s[out].reorder(out_i_o_o, out_j_o_o, out_i_o_i, out_j_o_i, out_i_i, out_j_i)
s[matmul].compute_at(s[out], out_j_o_i)
out_i_o_o_j_o_o_fused = s[out].fuse(out_i_o_o, out_j_o_o)
s[out].parallel(out_i_o_o_j_o_o_fused)
s[matmul].pragma(matmul_i_o_o_o, "auto_unroll_max_step", 512)
s[matmul].pragma(matmul_i_o_o_o, "unroll_explicit", True)
s[matmul].vectorize(matmul_j_i)

更复杂的例子是恢复搜索。在这种情况下,需要自己创建搜索策略和成本模型,并通过日志文件恢复搜索策略和成本模型(cost model)的状态。在下面的例子中,恢复了状态并做了更多的 5 次试验。

def resume_search(task, log_file):
    print("Resume search:")
    cost_model = auto_scheduler.XGBModel()
    cost_model.update_from_file(log_file)
    search_policy = auto_scheduler.SketchPolicy(
        task, cost_model, init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file)]
    )
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=5, measure_callbacks=[auto_scheduler.RecordToFile(log_file)]
    )
    task.tune(tune_option, search_policy=search_policy)

resume_search(task, log_file)
Resume search:
----------------------------------------------------------------------
------------------------------  [ Call init-search callbacks ]
----------------------------------------------------------------------
SearchPolicy: Loaded 25 measurement records from matmul.json for ["matmul_add", 1024, 1024, 1024, "float32"]
----------------------------------------------------------------------
------------------------------  [ Search ]
----------------------------------------------------------------------
Generate Sketches		#s: 3
Sample Initial Population	#s: 2013	fail_ct: 6	Time elapsed: 2.74
GA Iter: 0	Max score: 0.9995	Min score: 0.9315	#Pop: 128	#M+: 0	#M-: 0
GA Iter: 4	Max score: 0.9998	Min score: 0.9862	#Pop: 128	#M+: 1373	#M-: 69
EvolutionarySearch		#s: 128	Time elapsed: 12.14
----------------------------------------------------------------------
------------------------------  [ Measure ]
----------------------------------------------------------------------
Get 5 programs to measure:
.....*****
Time elapsed for measurement: 6.14 s
----------------------------------------------------------------------
------------------------------  [ Done ]
----------------------------------------------------------------------

最后说明和总结#

在本教程中,已经展示了如何使用 TVM 自动调度器来自动优化矩阵乘法,而不需要指定搜索模板。它结束了一系列从张量表达式(Tensor Expression,简称 TE)语言开始的例子,展示了 TVM 如何优化算子计算。