PyTorch编译器示例教程#
本教程旨在涵盖 PyTorch 编译器的以下几个方面:
基本概念(即时(Just-In-Time)编译器,提前(Ahead-of-time)编译器)
Dynamo(图捕获,将用户的代码分为纯 Python 代码和纯 PyTorch 相关代码)
AOTAutograd(从正向计算图中生成反向计算图)
Inductor/其他后端(给定计算图,如何在不同的设备上更快地运行它)
这些组件将根据不同的后端选项被调用:
当只使用 Dynamo 时,使用
torch.compile(backend="eager")
。当使用 Dynamo 和 AOTAutograd 时,使用
torch.compile(backend="aot_eager")
。默认情况下,使用
torch.compile(backend="inductor")
,这意味着同时使用 Dynamo、AOTAutograd 以及 PyTorch 内置的图优化后端 Inductor。
PyTorch 编译器是即时编译器#
首先需要了解的概念是,PyTorch 编译器是一种即时编译器(Just-In-Time)。那么,即时编译器是什么意思呢?来看例子:
import torch
class A(torch.nn.Module):
def forward(self, x):
return torch.exp(2 * x)
class B(torch.nn.Module):
def forward(self, x):
return torch.exp(-x)
def f(x, mod):
y = mod(x)
z = torch.log(y)
return z
# users might use
# mod = A()
# x = torch.randn(5, 5, 5)
# output = f(x, mod)
编写了函数 f
,它包含模块调用,该调用将执行 mod.forward
,以及 torch.log
调用。由于众所周知的代数简化恒等式 \(\log(\exp(a\times x))=a\times x\),迫不及待地想要优化代码如下:
def f(x, mod):
if isinstance(mod, A):
return 2 * x
elif isinstance(mod, B):
return -x
可以将其称为我们的第一个编译器,尽管它是由我们的大脑而不是自动化程序编译的。
如果希望更加严谨,那么编译器示例应该更新如下:
def f(x, mod):
if isinstance(x, torch.Tensor) and isinstance(mod, A):
return 2 * x
elif isinstance(x, torch.Tensor) and isinstance(mod, B):
return -x
else:
y = mod(x)
z = torch.log(y)
return z
需要检查每个参数,以确保优化条件是合理的,如果未能优化代码,还需要回退到原始代码。
这引出了即时编译器中的两个基本概念:守卫和转换代码。守卫是函数可以被优化的条件,而 转换代码 则是在满足守卫条件下的函数优化版本。在上面简单的编译器示例中,isinstance(mod, A)
就是守卫,而 return 2 * x
则是相应的转换代码,它在守卫条件下与原始代码等效,但执行速度要快得多。
上述例子是提前编译的编译器:检查所有可用的源代码,并在运行任何函数(即提前)之前,根据所有可能的守卫和转换代码编写优化后的函数。
另一类编译器是即时编译器:就在函数执行之前,它会分析是否可以对执行进行优化,以及在什么条件下可以对函数执行进行优化。希望这个条件足够通用,以适应新的输入,从而使即时编译的好处大于成本。如果所有条件都失败,它将尝试在新的条件下优化代码。
即时编译器的基本工作流程应该如下所示:
def f(x, mod):
for guard, transformed_code in f.compiled_entries:
if guard(x, mod):
return transformed_code(x, mod)
try:
guard, transformed_code = compile_and_optimize(x, mod)
f.compiled_entries.append([guard, transformed_code])
return transformed_code(x, mod)
except FailToCompileError:
y = mod(x)
z = torch.log(y)
return z
即时编译器(Just-In-Time Compiler)仅针对其已经观察到的情况进行优化。每当它遇到新的输入,而这个输入不满足任何现有的保护条件时,它就会为这个新输入编译出新的保护条件和转换后的代码。
逐步解释编译器的状态(就保护条件和转换后的代码而言):
import torch
class A(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return torch.exp(2 * x)
class B(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return torch.exp(-x)
@just_in_time_compile # an imaginary compiler function
def f(x, mod):
y = mod(x)
z = torch.log(y)
return z
a = A()
b = B()
x = torch.randn((5, 5, 5))
# before executing f(x, a), f.compiled_entries == [] is empty.
f(x, a)
# after executing f(x, a), f.compiled_entries == [Guard("isinstance(x, torch.Tensor) and isinstance(mod, A)"), TransformedCode("return 2 * x")]
# the second call of f(x, a) hit a condition, so we can just execute the transformed code
f(x, a)
# f(x, b) will trigger compilation and add a new compiled entry
# before executing f(x, b), f.compiled_entries == [Guard("isinstance(x, torch.Tensor) and isinstance(mod, A)"), TransformedCode("return 2 * x")]
f(x, b)
# after executing f(x, b), f.compiled_entries == [Guard("isinstance(x, torch.Tensor) and isinstance(mod, A)"), TransformedCode("return 2 * x"), Guard("isinstance(x, torch.Tensor) and isinstance(mod, B)"), TransformedCode("return -x")]
# the second call of f(x, b) hit a condition, so we can just execute the transformed code
f(x, b)
在这个示例中,我们对类类型进行防护检查,例如使用 isinstance(mod, A)
语句,而且转换后的代码仍然是 Python 代码;对于 torch.compile 来说,它需要对更多的条件进行防护,比如设备(CPU/GPU)、数据类型(int32, float32)、形状([10], [8]),而它的转换代码则是 Python 字节码。我们可以从函数中提取这些编译条目,更多细节请参阅 PyTorch 文档。尽管在防护和转换代码方面有所不同,但 torch.compile
的基本工作流程与本例相同,即它充当即时编译器。
超越代数简化的优化#
上述例子是关于代数简化的。然而,这样的优化在实践中相当罕见。让我们来看更实际的例子,并了解 PyTorch 编译器是如何对以下代码进行优化的:
import torch
@torch.compile
def function(inputs):
x = inputs["x"]
y = inputs["y"]
x = x.cos().cos()
if x.mean() > 0.5:
x = x / 1.1
return x * y
shape_10_inputs = {"x": torch.randn(10, requires_grad=True), "y": torch.randn(10, requires_grad=True)}
shape_8_inputs = {"x": torch.randn(8, requires_grad=True), "y": torch.randn(8, requires_grad=True)}
# warmup
for i in range(100):
output = function(shape_10_inputs)
output = function(shape_8_inputs)
# execution of compiled functions
output = function(shape_10_inputs)
/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/onnxscript/converter.py:820: FutureWarning: 'onnxscript.values.Op.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
param_schemas = callee.param_schemas()
/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/onnxscript/converter.py:820: FutureWarning: 'onnxscript.values.OnnxFunction.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
param_schemas = callee.param_schemas()
代码尝试实现 \(\text{cos}(\text{cos}(x))\) 激活函数,并根据其激活值调整输出的大小,然后将输出与另一个张量 y
相乘。
Dynamo是如何转换和修改函数功能的?#
当理解 torch.compile()
作为即时编译器的整体图景后,可以更深入地探究其工作原理。与 gcc
或 llvm
这样的通用编译器不同,torch.compile()
是特定领域的编译器:它只专注于 PyTorch 相关的计算图。因此,需要工具来将用户的代码分为两部分:纯 Python 代码和计算图代码。
Dynamo 就位于 torch._dynamo
模块内,是完成此任务的工具。通常不直接与这个模块交互。它是在 torch.compile()
函数内部被调用的。
从概念上讲,Dynamo 执行以下操作:
找到第一个无法在计算图中表示但需要计算图中计算值的算子(例如,打印张量的值,使用张量的值来决定 Python 中的
if
语句控制流)。将前面的算子分成两部分:一个是纯粹关于张量计算的计算图,另一个是一些关于操纵 Python 对象的 Python 代码。
将剩余的算子保留为一两个新函数(称为
resume
函数),并再次触发上述分析。
为了能够对函数进行这种细粒度的操作,Dynamo 在低于 Python 源代码级别的 Python 字节码层面运作。
以下过程描述了 Dynamo 对函数所做的处理。
Dynamo
的显著特性是它能够分析函数内部调用的所有函数。如果函数可以完全用计算图表示,那么这个函数的调用将被内联,从而消除该函数调用。
Dynamo 的使命是以安全稳妥的方式从 Python 代码中提取计算图。一旦获得了计算图,就可以进入计算图优化的世界。
备注
上述工作流程包含许多难以理解的字节码。对于那些无法阅读 Python 字节码的人来说,depyf
可以提供帮助!
动态形状支持来自 Dynamo
#
深度学习编译器通常倾向于静态形状输入。这就是为什么上述保护条件包括形状保护的原因。第一次函数调用使用形状 [10]
的输入,但第二次函数调用使用的是形状 [8]
的输入。这将无法通过形状保护,因此触发新的代码转换。
默认情况下,Dynamo
支持动态形状。当形状保护失败时,它会分析和比较形状,并尝试将形状泛化。在这种情况下,看到形状为 [8]
的输入后,它将尝试泛化为任意一维形状 [s0]
,这被称为动态形状或符号形状。
AOTAutograd
:从前向图生成反向计算图#
上述代码仅处理前向计算图。重要的缺失部分是如何获取反向计算图来计算梯度。
在纯 PyTorch 代码中,反向计算是通过对某个标量损失值调用 backward
函数来触发的。每个 PyTorch 函数在前向计算期间存储了反向所需的信息。
为了解释急切模式下反向期间发生了什么,有下面的实现,它模拟了 torch.cos()
函数的内置行为(需要一些关于如何在 PyTorch 中编写带有自动梯度支持的自定义函数的背景知识):
import torch
class Cosine(torch.autograd.Function):
@staticmethod
def forward(x0):
x1 = torch.cos(x0)
return x1, x0
@staticmethod
def setup_context(ctx, inputs, output):
x1, x0 = output
print(f"saving tensor of size {x0.shape}")
ctx.save_for_backward(x0)
@staticmethod
def backward(ctx, grad_output):
x0, = ctx.saved_tensors
result = (-torch.sin(x0)) * grad_output
return result
# Wrap Cosine in a function so that it is clearer what the output is
def cosine(x):
# `apply` will call `forward` and `setup_context`
y, x= Cosine.apply(x)
return y
def naive_two_cosine(x0):
x1 = cosine(x0)
x2 = cosine(x1)
return x2
在执行上述函数时,如果输入需要计算梯度,可以观察到有两个张量被保存下来:
input = torch.randn((5, 5, 5), requires_grad=True)
output = naive_two_cosine(input)
saving tensor of size torch.Size([5, 5, 5])
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[7], line 2
1 input = torch.randn((5, 5, 5), requires_grad=True)
----> 2 output = naive_two_cosine(input)
Cell In[6], line 27, in naive_two_cosine(x0)
26 def naive_two_cosine(x0):
---> 27 x1 = cosine(x0)
28 x2 = cosine(x1)
29 return x2
Cell In[6], line 23, in cosine(x)
21 def cosine(x):
22 # `apply` will call `forward` and `setup_context`
---> 23 y, x= Cosine.apply(x)
24 return y
File /media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/torch/autograd/function.py:575, in Function.apply(cls, *args, **kwargs)
572 if not torch._C._are_functorch_transforms_active():
573 # See NOTE: [functorch vjp and autograd interaction]
574 args = _functorch.utils.unwrap_dead_wrappers(args)
--> 575 return super().apply(*args, **kwargs) # type: ignore[misc]
577 if not is_setup_ctx_defined:
578 raise RuntimeError(
579 "In order to use an autograd.Function with functorch transforms "
580 "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
581 "staticmethod. For more details, please see "
582 "https://pytorch.org/docs/main/notes/extending.func.html"
583 )
RuntimeError: A input that has been returned as-is as output is being saved for backward. This is not supported if you override setup_context. You should return and save a view of the input instead, e.g. with x.view_as(x) or setup ctx inside the forward function itself.