在 GPU 上自动调优卷积层#

原作者: Lianmin Zheng, Chengfan Jia

这是关于如何使用 GPU 自动调度器的教程。

与基于模板的 autotvm 依赖手动模板定义搜索空间不同,自动调度程序不需要任何调度模板。换句话说,自动调度器只使用 tvm/python/topi 中的 compute,而不使用现有的调度模板。

import os

import numpy as np
import tvm
from tvm import te, auto_scheduler, topi
from tvm.topi.testing import conv2d_nchw_python

定义计算#

定义卷积层的计算。函数应该返回输入/输出张量的列表。从这些张量中,自动调度程序可以得到整个计算图。

@auto_scheduler.register_workload
def conv2d_layer(N, H, W, CO, CI, KH, KW, stride, padding):
    data = te.placeholder((N, CI, H, W), name="data")
    kernel = te.placeholder((CO, CI, KH, KW), name="kernel")
    bias = te.placeholder((1, CO, 1, 1), name="bias")
    conv = topi.nn.conv2d_nchw(data, kernel, stride, padding, dilation=1, out_dtype="float32")
    out = topi.nn.relu(conv + bias)
    return [data, kernel, bias, out]

创建搜索任务#

然后为 resnet 中的最后一个卷积层创建搜索任务。

target = tvm.target.Target("cuda")

# 使用 ResNet-50 最后一层卷积
N, H, W, CO, CI, KH, KW, strides, padding = 1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)
task = auto_scheduler.SearchTask(
    func=conv2d_layer, args=(N, H, W, CO, CI, KH, KW, strides, padding), target=target
)

# 检查计算图
print("Computational DAG:")
print(task.compute_dag)

接下来,为自动调度器设置参数。这些参数主要指定在搜索过程中如何进行测量。

  • measure_ctx 启动不同的测量进程以提供隔离。它可以在测量期间保护主进程不受 GPU 崩溃的影响,并避免其他运行时冲突。

  • min_repeat_ms 定义每次测量中一次“重复”的最小持续时间。这可以预热 GPU,这对于获得准确的测量结果是必要的。通常,建议值 >= 300 ms。

  • num_measure_trials 是在调优期间可以使用的度量试验的数量。在实践中,建议将它设置在 1000,这通常足以让搜索收敛。可以根据自己的时间预算调整该参数。

  • 此外,使用 RecordToFile 将测量记录转储到日志文件中,测量记录可以用于查询历史,恢复搜索,并在以后进行更多的分析。

  • 查阅 tvm.auto_scheduler.TuningOptionstvm.auto_scheduler.LocalRPCMeasureContext 获取更多参数。

log_file = "conv2d.json"
measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=10,  # change this to 1000 to achieve the best performance
    runner=measure_ctx.runner,
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    verbose=2,
)

运行搜索#

现在准备好所有输入。很简单,不是吗?

可以开始搜索,让自动调度程序发挥它的魔力。经过一些测试之后,可以从日志文件中加载最佳调度并应用它。

# Run auto-tuning (search)
task.tune(tune_option)
# Apply the best schedule
sch, args = task.apply_best(log_file)

# Kill the measurement process
del measure_ctx

可以 lower 调度来查看自动调度后的 IR。自动调度器正确地执行优化,包括多级 tiling、cooperative fetching、unrolling和算子融合。

print("Lowered TIR:")
print(tvm.lower(sch, args, simple_mode=True))

检测正确性并评估性能#

构建二进制文件并检查其正确性和性能。

func = tvm.build(sch, args, target)

# Check correctness
data_np = np.random.uniform(size=(N, CI, H, W)).astype(np.float32)
weight_np = np.random.uniform(size=(CO, CI, KH, KW)).astype(np.float32)
bias_np = np.random.uniform(size=(1, CO, 1, 1)).astype(np.float32)
conv_np = conv2d_nchw_python(data_np, weight_np, strides, padding)
out_np = np.maximum(conv_np + bias_np, 0.0)

dev = tvm.cuda()
data_tvm = tvm.nd.array(data_np, device=dev)
weight_tvm = tvm.nd.array(weight_np, device=dev)
bias_tvm = tvm.nd.array(bias_np, device=dev)
out_tvm = tvm.nd.empty(out_np.shape, device=dev)
func(data_tvm, weight_tvm, bias_tvm, out_tvm)

# Check results
np.testing.assert_allclose(out_np, out_tvm.numpy(), rtol=1e-3)

# Evaluate execution time
evaluator = func.time_evaluator(func.entry_name, dev, min_repeat_ms=500)
print(
    "Execution time of this operator: %.3f ms"
    % (np.median(evaluator(data_tvm, weight_tvm, bias_tvm, out_tvm).results) * 1000)
)

使用记录文件#

在搜索过程中,所有测量记录都被转储到记录文件“conv2d.json”中。测量记录可用于重新应用搜索结果、恢复搜索和执行其他分析。

下面的例子,从文件中加载最好的调度,打印等效的 python 调度 API 和 CUDA 源代码。它们可用于调试和学习自动调度器的行为。

print("Equivalent python schedule:")
print(task.print_best(log_file, print_mode="schedule"))

print("CUDA source code:")
print(task.print_best(log_file, print_mode="cuda"))

更复杂的示例是恢复搜索。在这种情况下,需要自己创建搜索策略和代价模型,并通过日志文件恢复搜索策略和代价模型的状态。

在下面的例子中,恢复状态并进行更多的 5 次试验。

def resume_search(task, log_file):
    print("Resume search:")
    cost_model = auto_scheduler.XGBModel()
    cost_model.update_from_file(log_file)
    search_policy = auto_scheduler.SketchPolicy(
        task, cost_model, init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file)]
    )
    measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=5,
        runner=measure_ctx.runner,
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    )
    task.tune(tune_option, search_policy=search_policy)

    # Kill the measurement process
    del measure_ctx


resume_search(task, log_file)