PyTorch编译器示例教程#

本教程旨在涵盖 PyTorch 编译器的以下几个方面:

  • 基本概念(即时(Just-In-Time)编译器,提前(Ahead-of-time)编译器)

  • Dynamo(图捕获,将用户的代码分为纯 Python 代码和纯 PyTorch 相关代码)

  • AOTAutograd(从正向计算图中生成反向计算图)

  • Inductor/其他后端(给定计算图,如何在不同的设备上更快地运行它)

这些组件将根据不同的后端选项被调用:

  • 当只使用 Dynamo 时,使用 torch.compile(backend="eager")

  • 当使用 Dynamo 和 AOTAutograd 时,使用 torch.compile(backend="aot_eager")

  • 默认情况下,使用 torch.compile(backend="inductor"),这意味着同时使用 Dynamo、AOTAutograd 以及 PyTorch 内置的图优化后端 Inductor。

PyTorch 编译器是即时编译器#

首先需要了解的概念是,PyTorch 编译器是一种即时编译器(Just-In-Time)。那么,即时编译器是什么意思呢?来看例子:

import torch

class A(torch.nn.Module):
    def forward(self, x):
        return torch.exp(2 * x)

class B(torch.nn.Module):
    def forward(self, x):
        return torch.exp(-x)

def f(x, mod):
    y = mod(x)
    z = torch.log(y)
    return z

# users might use
# mod = A()
# x = torch.randn(5, 5, 5)
# output = f(x, mod)

编写了函数 f,它包含模块调用,该调用将执行 mod.forward,以及 torch.log 调用。由于众所周知的代数简化恒等式 \(\log(\exp(a\times x))=a\times x\),迫不及待地想要优化代码如下:

def f(x, mod):
    if isinstance(mod, A):
        return 2 * x
    elif isinstance(mod, B):
        return -x

可以将其称为我们的第一个编译器,尽管它是由我们的大脑而不是自动化程序编译的。

如果希望更加严谨,那么编译器示例应该更新如下:

def f(x, mod):
    if isinstance(x, torch.Tensor) and isinstance(mod, A):
        return 2 * x
    elif isinstance(x, torch.Tensor) and isinstance(mod, B):
        return -x
    else:
        y = mod(x)
        z = torch.log(y)
        return z

需要检查每个参数,以确保优化条件是合理的,如果未能优化代码,还需要回退到原始代码。

这引出了即时编译器中的两个基本概念:守卫和转换代码。守卫是函数可以被优化的条件,而 转换代码 则是在满足守卫条件下的函数优化版本。在上面简单的编译器示例中,isinstance(mod, A) 就是守卫,而 return 2 * x 则是相应的转换代码,它在守卫条件下与原始代码等效,但执行速度要快得多。

上述例子是提前编译的编译器:检查所有可用的源代码,并在运行任何函数(即提前)之前,根据所有可能的守卫和转换代码编写优化后的函数。

另一类编译器是即时编译器:就在函数执行之前,它会分析是否可以对执行进行优化,以及在什么条件下可以对函数执行进行优化。希望这个条件足够通用,以适应新的输入,从而使即时编译的好处大于成本。如果所有条件都失败,它将尝试在新的条件下优化代码。

即时编译器的基本工作流程应该如下所示:

def f(x, mod):
    for guard, transformed_code in f.compiled_entries:
        if guard(x, mod):
            return transformed_code(x, mod)
    try:
        guard, transformed_code = compile_and_optimize(x, mod)
        f.compiled_entries.append([guard, transformed_code])
        return transformed_code(x, mod)
    except FailToCompileError:
        y = mod(x)
        z = torch.log(y)
        return z

即时编译器(Just-In-Time Compiler)仅针对其已经观察到的情况进行优化。每当它遇到新的输入,而这个输入不满足任何现有的保护条件时,它就会为这个新输入编译出新的保护条件和转换后的代码。

逐步解释编译器的状态(就保护条件和转换后的代码而言):

import torch

class A(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.exp(2 * x)

class B(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.exp(-x)

@just_in_time_compile # an imaginary compiler function
def f(x, mod):
    y = mod(x)
    z = torch.log(y)
    return z

a = A()
b = B()
x = torch.randn((5, 5, 5))

# before executing f(x, a), f.compiled_entries == [] is empty.
f(x, a)
# after executing f(x, a), f.compiled_entries == [Guard("isinstance(x, torch.Tensor) and isinstance(mod, A)"), TransformedCode("return 2 * x")]

# the second call of f(x, a) hit a condition, so we can just execute the transformed code
f(x, a)

# f(x, b) will trigger compilation and add a new compiled entry
# before executing f(x, b), f.compiled_entries == [Guard("isinstance(x, torch.Tensor) and isinstance(mod, A)"), TransformedCode("return 2 * x")]
f(x, b)
# after executing f(x, b), f.compiled_entries == [Guard("isinstance(x, torch.Tensor) and isinstance(mod, A)"), TransformedCode("return 2 * x"), Guard("isinstance(x, torch.Tensor) and isinstance(mod, B)"), TransformedCode("return -x")]

# the second call of f(x, b) hit a condition, so we can just execute the transformed code
f(x, b)

在这个示例中,我们对类类型进行防护检查,例如使用 isinstance(mod, A) 语句,而且转换后的代码仍然是 Python 代码;对于 torch.compile 来说,它需要对更多的条件进行防护,比如设备(CPU/GPU)、数据类型(int32, float32)、形状([10], [8]),而它的转换代码则是 Python 字节码。我们可以从函数中提取这些编译条目,更多细节请参阅 PyTorch 文档。尽管在防护和转换代码方面有所不同,但 torch.compile 的基本工作流程与本例相同,即它充当即时编译器。

超越代数简化的优化#

上述例子是关于代数简化的。然而,这样的优化在实践中相当罕见。让我们来看更实际的例子,并了解 PyTorch 编译器是如何对以下代码进行优化的:

import torch

@torch.compile
def function(inputs):
    x = inputs["x"]
    y = inputs["y"]
    x = x.cos().cos()
    if x.mean() > 0.5:
        x = x / 1.1
    return x * y

shape_10_inputs = {"x": torch.randn(10, requires_grad=True), "y": torch.randn(10, requires_grad=True)}
shape_8_inputs = {"x": torch.randn(8, requires_grad=True), "y": torch.randn(8, requires_grad=True)}
# warmup
for i in range(100):
    output = function(shape_10_inputs)
    output = function(shape_8_inputs)

# execution of compiled functions
output = function(shape_10_inputs)
/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/onnxscript/converter.py:820: FutureWarning: 'onnxscript.values.Op.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
  param_schemas = callee.param_schemas()
/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/onnxscript/converter.py:820: FutureWarning: 'onnxscript.values.OnnxFunction.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
  param_schemas = callee.param_schemas()

代码尝试实现 \(\text{cos}(\text{cos}(x))\) 激活函数,并根据其激活值调整输出的大小,然后将输出与另一个张量 y 相乘。

Dynamo是如何转换和修改函数功能的?#

当理解 torch.compile() 作为即时编译器的整体图景后,可以更深入地探究其工作原理。与 gccllvm 这样的通用编译器不同,torch.compile() 是特定领域的编译器:它只专注于 PyTorch 相关的计算图。因此,需要工具来将用户的代码分为两部分:纯 Python 代码和计算图代码。

Dynamo 就位于 torch._dynamo 模块内,是完成此任务的工具。通常不直接与这个模块交互。它是在 torch.compile() 函数内部被调用的。

从概念上讲,Dynamo 执行以下操作:

  • 找到第一个无法在计算图中表示但需要计算图中计算值的算子(例如,打印张量的值,使用张量的值来决定 Python 中的 if 语句控制流)。

  • 将前面的算子分成两部分:一个是纯粹关于张量计算的计算图,另一个是一些关于操纵 Python 对象的 Python 代码。

  • 将剩余的算子保留为一两个新函数(称为 resume 函数),并再次触发上述分析。

为了能够对函数进行这种细粒度的操作,Dynamo 在低于 Python 源代码级别的 Python 字节码层面运作。

以下过程描述了 Dynamo 对函数所做的处理。

Dynamo 的显著特性是它能够分析函数内部调用的所有函数。如果函数可以完全用计算图表示,那么这个函数的调用将被内联,从而消除该函数调用。

Dynamo 的使命是以安全稳妥的方式从 Python 代码中提取计算图。一旦获得了计算图,就可以进入计算图优化的世界。

备注

上述工作流程包含许多难以理解的字节码。对于那些无法阅读 Python 字节码的人来说,depyf 可以提供帮助!

动态形状支持来自 Dynamo#

深度学习编译器通常倾向于静态形状输入。这就是为什么上述保护条件包括形状保护的原因。第一次函数调用使用形状 [10] 的输入,但第二次函数调用使用的是形状 [8] 的输入。这将无法通过形状保护,因此触发新的代码转换。

默认情况下,Dynamo 支持动态形状。当形状保护失败时,它会分析和比较形状,并尝试将形状泛化。在这种情况下,看到形状为 [8] 的输入后,它将尝试泛化为任意一维形状 [s0],这被称为动态形状或符号形状。

AOTAutograd:从前向图生成反向计算图#

上述代码仅处理前向计算图。重要的缺失部分是如何获取反向计算图来计算梯度。

在纯 PyTorch 代码中,反向计算是通过对某个标量损失值调用 backward 函数来触发的。每个 PyTorch 函数在前向计算期间存储了反向所需的信息。

为了解释急切模式下反向期间发生了什么,有下面的实现,它模拟了 torch.cos() 函数的内置行为(需要一些关于如何在 PyTorch 中编写带有自动梯度支持的自定义函数的背景知识):

import torch
class Cosine(torch.autograd.Function):
    @staticmethod
    def forward(x0):
        x1 = torch.cos(x0)
        return x1, x0

    @staticmethod
    def setup_context(ctx, inputs, output):
        x1, x0 = output
        print(f"saving tensor of size {x0.shape}")
        ctx.save_for_backward(x0)

    @staticmethod
    def backward(ctx, grad_output):
        x0, = ctx.saved_tensors
        result = (-torch.sin(x0)) * grad_output
        return result

# Wrap Cosine in a function so that it is clearer what the output is
def cosine(x):
    # `apply` will call `forward` and `setup_context`
    y, x= Cosine.apply(x)
    return y

def naive_two_cosine(x0):
    x1 = cosine(x0)
    x2 = cosine(x1)
    return x2

在执行上述函数时,如果输入需要计算梯度,可以观察到有两个张量被保存下来:

input = torch.randn((5, 5, 5), requires_grad=True)
output = naive_two_cosine(input)
saving tensor of size torch.Size([5, 5, 5])
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[7], line 2
      1 input = torch.randn((5, 5, 5), requires_grad=True)
----> 2 output = naive_two_cosine(input)

Cell In[6], line 27, in naive_two_cosine(x0)
     26 def naive_two_cosine(x0):
---> 27     x1 = cosine(x0)
     28     x2 = cosine(x1)
     29     return x2

Cell In[6], line 23, in cosine(x)
     21 def cosine(x):
     22     # `apply` will call `forward` and `setup_context`
---> 23     y, x= Cosine.apply(x)
     24     return y

File /media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/torch/autograd/function.py:575, in Function.apply(cls, *args, **kwargs)
    572 if not torch._C._are_functorch_transforms_active():
    573     # See NOTE: [functorch vjp and autograd interaction]
    574     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 575     return super().apply(*args, **kwargs)  # type: ignore[misc]
    577 if not is_setup_ctx_defined:
    578     raise RuntimeError(
    579         "In order to use an autograd.Function with functorch transforms "
    580         "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
    581         "staticmethod. For more details, please see "
    582         "https://pytorch.org/docs/main/notes/extending.func.html"
    583     )

RuntimeError: A input that has been returned as-is as output is being saved for backward. This is not supported if you override setup_context. You should return and save a view of the input instead, e.g. with x.view_as(x) or setup ctx inside the forward function itself.