FX 简介

目录

FX 简介#

FX 是供开发人员用来变换 Module 实例的工具包。FX 由三个主要组件组成:符号跟踪器(symbolic tracer)、中间表示(intermediate representation,简写 IR)和 Python 代码生成(Python code generation)。

import torch
from torch import nn, fx


# 用于演示的简单模块
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.param = nn.Parameter(torch.rand(3, 4))
        self.linear = nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)

module = MyModule()

符号跟踪前端(Symbolic tracing frontend)

备注

符号跟踪器 执行 Python 代码的“符号执行”。它通过代码提供虚假的值,称为 代理。记录对这些代理的运算。有关符号跟踪的更多信息可以在 symbolic_trace()Tracer 文档中找到。

捕获模块的语义:

symbolic_traced: fx.GraphModule = fx.symbolic_trace(module)

高级中间表示(intermediate representationIR)

备注

中间表示 是符号跟踪期间记录的运算的容器。它由一组 node 组成,这些 node 表示函数输入、调用站点(callsites,即函数、方法或 Module 实例)和返回值。关于 IR 的更多信息可以在 Graph 的文档中找到。IR 是应用变换(transformations)的格式。

计算图(graph)表示:

print(symbolic_traced.graph)
graph():
    %x : [#users=1] = placeholder[target=x]
    %param : [#users=1] = get_attr[target=param]
    %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    return clamp

代码生成(Code generation)

备注

Python 代码生成 使 FX 成为 Python 到 Python (或 Module-to-Module)的变换工具包。对于每个 Graph IR,可以创建与 Graph 语义匹配的有效 Python 代码。该功能封装在 GraphModule 中,它是 Module 实例,包含 Graph 以及从 Graph 生成的 forward() 方法。

有效的 Python 代码:

print(symbolic_traced.code)
def forward(self, x):
    param = self.param
    add = x + param;  x = param = None
    linear = self.linear(add);  add = None
    clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None
    return clamp
    

总的来说,这个组件管道(symbolic tracing -> intermediate representation -> transforms -> Python code generation)构成了 FX 的 Python-to-Python 变换管道(pipeline)。此外,这些组件可以单独使用。例如,可以单独使用符号跟踪来捕获代码的形式,以便进行分析(而不是变换)。代码生成可以用于以编程方式生成模型,例如从配置文件生成模型。

编写变换#

什么是 FX 变换?本质上,它是这样的函数:

import torch
from torch import nn, fx

def transform(m: nn.Module,
              tracer_class: type = fx.Tracer) -> nn.Module:
    # 步骤 1:获取表示 `m` 代码的计算图表示

    # NOTE: fx.symbolic_trace 是对 fx.Tracer.trace 调用和构造 GraphModule 的包装器。
    # 将在变换中分离它,以允许调用者自定义 tracing 行为。
    graph: fx.Graph = tracer_class().trace(m)

    # 步骤 2: 修改此 Graph 或创建新的 Graph
    graph = ...

    # 步骤 3:返回构造的 Module
    return fx.GraphModule(m, graph)

transformation 函数需要 Module 作为输入, 然后从该 Module 获得 Graph (即 IR)对其进行修改, 然后返回新的 Module。你应该把返回的 Module 想成和正常的 Module 一样:你可以把它传递给另一个 FX 变换,你可以把它传递给 TorchScript,或者你可以运行它。确保 FX 变换的输入和输出是 Module 将允许可组合性。

备注

也可以修改现有的 GraphModule,而不是创建新的 GraphModule,如下所示:


def transform(m : nn.Module) -> nn.Module:
    gm : fx.GraphModule = fx.symbolic_trace(m)

    # 修改 gm.graph
    # <...>

    # 从 `gm` 的 Graph 中重新编译 forward() 方法
    gm.recompile()
    return gm

小技巧

注意,你必须调用 recompile() 来将 GraphModule 上生成的 forward() 方法与修改后的 Graph 同步。

假设您已经传入了被跟踪到 Graph 中的 Module,那么现在您可以采用两种主要方法来构建新的 Graph