在 GPU 上自动调优卷积层#
原作者: Lianmin Zheng, Chengfan Jia
这是关于如何使用 GPU 自动调度器的教程。
与基于模板的 autotvm 依赖手动模板定义搜索空间不同,自动调度程序不需要任何调度模板。换句话说,自动调度器只使用 tvm/python/topi
中的 compute,而不使用现有的调度模板。
import os
import numpy as np
import tvm
from tvm import te, auto_scheduler, topi
from tvm.topi.testing import conv2d_nchw_python
定义计算#
定义卷积层的计算。函数应该返回输入/输出张量的列表。从这些张量中,自动调度程序可以得到整个计算图。
@auto_scheduler.register_workload
def conv2d_layer(N, H, W, CO, CI, KH, KW, stride, padding):
data = te.placeholder((N, CI, H, W), name="data")
kernel = te.placeholder((CO, CI, KH, KW), name="kernel")
bias = te.placeholder((1, CO, 1, 1), name="bias")
conv = topi.nn.conv2d_nchw(data, kernel, stride, padding, dilation=1, out_dtype="float32")
out = topi.nn.relu(conv + bias)
return [data, kernel, bias, out]
创建搜索任务#
然后为 resnet 中的最后一个卷积层创建搜索任务。
target = tvm.target.Target("cuda")
# 使用 ResNet-50 最后一层卷积
N, H, W, CO, CI, KH, KW, strides, padding = 1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)
task = auto_scheduler.SearchTask(
func=conv2d_layer, args=(N, H, W, CO, CI, KH, KW, strides, padding), target=target
)
# 检查计算图
print("Computational DAG:")
print(task.compute_dag)
接下来,为自动调度器设置参数。这些参数主要指定在搜索过程中如何进行测量。
measure_ctx
启动不同的测量进程以提供隔离。它可以在测量期间保护主进程不受 GPU 崩溃的影响,并避免其他运行时冲突。min_repeat_ms
定义每次测量中一次“重复”的最小持续时间。这可以预热 GPU,这对于获得准确的测量结果是必要的。通常,建议值 >= 300 ms。num_measure_trials
是在调优期间可以使用的度量试验的数量。在实践中,建议将它设置在 1000,这通常足以让搜索收敛。可以根据自己的时间预算调整该参数。此外,使用
RecordToFile
将测量记录转储到日志文件中,测量记录可以用于查询历史,恢复搜索,并在以后进行更多的分析。查阅
tvm.auto_scheduler.TuningOptions
、tvm.auto_scheduler.LocalRPCMeasureContext
获取更多参数。
log_file = "conv2d.json"
measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
tune_option = auto_scheduler.TuningOptions(
num_measure_trials=10, # change this to 1000 to achieve the best performance
runner=measure_ctx.runner,
measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
verbose=2,
)
运行搜索#
现在准备好所有输入。很简单,不是吗?
可以开始搜索,让自动调度程序发挥它的魔力。经过一些测试之后,可以从日志文件中加载最佳调度并应用它。
# Run auto-tuning (search)
task.tune(tune_option)
# Apply the best schedule
sch, args = task.apply_best(log_file)
# Kill the measurement process
del measure_ctx
可以 lower 调度来查看自动调度后的 IR。自动调度器正确地执行优化,包括多级 tiling、cooperative fetching、unrolling和算子融合。
print("Lowered TIR:")
print(tvm.lower(sch, args, simple_mode=True))
检测正确性并评估性能#
构建二进制文件并检查其正确性和性能。
func = tvm.build(sch, args, target)
# Check correctness
data_np = np.random.uniform(size=(N, CI, H, W)).astype(np.float32)
weight_np = np.random.uniform(size=(CO, CI, KH, KW)).astype(np.float32)
bias_np = np.random.uniform(size=(1, CO, 1, 1)).astype(np.float32)
conv_np = conv2d_nchw_python(data_np, weight_np, strides, padding)
out_np = np.maximum(conv_np + bias_np, 0.0)
dev = tvm.cuda()
data_tvm = tvm.nd.array(data_np, device=dev)
weight_tvm = tvm.nd.array(weight_np, device=dev)
bias_tvm = tvm.nd.array(bias_np, device=dev)
out_tvm = tvm.nd.empty(out_np.shape, device=dev)
func(data_tvm, weight_tvm, bias_tvm, out_tvm)
# Check results
np.testing.assert_allclose(out_np, out_tvm.numpy(), rtol=1e-3)
# Evaluate execution time
evaluator = func.time_evaluator(func.entry_name, dev, min_repeat_ms=500)
print(
"Execution time of this operator: %.3f ms"
% (np.median(evaluator(data_tvm, weight_tvm, bias_tvm, out_tvm).results) * 1000)
)
使用记录文件#
在搜索过程中,所有测量记录都被转储到记录文件“conv2d.json”中。测量记录可用于重新应用搜索结果、恢复搜索和执行其他分析。
下面的例子,从文件中加载最好的调度,打印等效的 python 调度 API 和 CUDA 源代码。它们可用于调试和学习自动调度器的行为。
print("Equivalent python schedule:")
print(task.print_best(log_file, print_mode="schedule"))
print("CUDA source code:")
print(task.print_best(log_file, print_mode="cuda"))
更复杂的示例是恢复搜索。在这种情况下,需要自己创建搜索策略和代价模型,并通过日志文件恢复搜索策略和代价模型的状态。
在下面的例子中,恢复状态并进行更多的 5 次试验。
def resume_search(task, log_file):
print("Resume search:")
cost_model = auto_scheduler.XGBModel()
cost_model.update_from_file(log_file)
search_policy = auto_scheduler.SketchPolicy(
task, cost_model, init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file)]
)
measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
tune_option = auto_scheduler.TuningOptions(
num_measure_trials=5,
runner=measure_ctx.runner,
measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
)
task.tune(tune_option, search_policy=search_policy)
# Kill the measurement process
del measure_ctx
resume_search(task, log_file)