FX 简介#
FX 是供开发人员用来变换 Module
实例的工具包。FX 由三个主要组件组成:符号跟踪器(symbolic tracer)、中间表示(intermediate representation,简写 IR)和 Python 代码生成(Python code generation)。
import torch
from torch import nn, fx
# 用于演示的简单模块
class MyModule(nn.Module):
def __init__(self):
super().__init__()
self.param = nn.Parameter(torch.rand(3, 4))
self.linear = nn.Linear(4, 5)
def forward(self, x):
return self.linear(x + self.param).clamp(min=0.0, max=1.0)
module = MyModule()
符号跟踪前端(Symbolic tracing frontend)
备注
符号跟踪器 执行 Python 代码的“符号执行”。它通过代码提供虚假的值,称为 代理。记录对这些代理的运算。有关符号跟踪的更多信息可以在 symbolic_trace()
和 Tracer
文档中找到。
捕获模块的语义:
symbolic_traced: fx.GraphModule = fx.symbolic_trace(module)
高级中间表示(intermediate representationIR)
备注
中间表示 是符号跟踪期间记录的运算的容器。它由一组 node 组成,这些 node 表示函数输入、调用站点(callsites,即函数、方法或 Module
实例)和返回值。关于 IR 的更多信息可以在 Graph
的文档中找到。IR 是应用变换(transformations)的格式。
计算图(graph)表示:
print(symbolic_traced.graph)
graph():
%x : [#users=1] = placeholder[target=x]
%param : [#users=1] = get_attr[target=param]
%add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
%linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
%clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
return clamp
代码生成(Code generation)
备注
Python 代码生成 使 FX 成为 Python 到 Python (或 Module-to-Module)的变换工具包。对于每个 Graph IR,可以创建与 Graph 语义匹配的有效 Python 代码。该功能封装在 GraphModule
中,它是 Module
实例,包含 Graph 以及从 Graph
生成的 forward()
方法。
有效的 Python 代码:
print(symbolic_traced.code)
def forward(self, x):
param = self.param
add = x + param; x = param = None
linear = self.linear(add); add = None
clamp = linear.clamp(min = 0.0, max = 1.0); linear = None
return clamp
总的来说,这个组件管道(symbolic tracing -> intermediate representation -> transforms -> Python code generation)构成了 FX 的 Python-to-Python 变换管道(pipeline)。此外,这些组件可以单独使用。例如,可以单独使用符号跟踪来捕获代码的形式,以便进行分析(而不是变换)。代码生成可以用于以编程方式生成模型,例如从配置文件生成模型。
编写变换#
什么是 FX 变换?本质上,它是这样的函数:
import torch
from torch import nn, fx
def transform(m: nn.Module,
tracer_class: type = fx.Tracer) -> nn.Module:
# 步骤 1:获取表示 `m` 代码的计算图表示
# NOTE: fx.symbolic_trace 是对 fx.Tracer.trace 调用和构造 GraphModule 的包装器。
# 将在变换中分离它,以允许调用者自定义 tracing 行为。
graph: fx.Graph = tracer_class().trace(m)
# 步骤 2: 修改此 Graph 或创建新的 Graph
graph = ...
# 步骤 3:返回构造的 Module
return fx.GraphModule(m, graph)
transformation
函数需要 Module
作为输入, 然后从该 Module
获得 Graph
(即 IR)对其进行修改, 然后返回新的 Module
。你应该把返回的 Module
想成和正常的 Module
一样:你可以把它传递给另一个 FX 变换,你可以把它传递给 TorchScript,或者你可以运行它。确保 FX 变换的输入和输出是 Module
将允许可组合性。
备注
也可以修改现有的 GraphModule
,而不是创建新的 GraphModule
,如下所示:
def transform(m : nn.Module) -> nn.Module:
gm : fx.GraphModule = fx.symbolic_trace(m)
# 修改 gm.graph
# <...>
# 从 `gm` 的 Graph 中重新编译 forward() 方法
gm.recompile()
return gm
小技巧
注意,你必须调用 recompile()
来将 GraphModule
上生成的 forward()
方法与修改后的 Graph
同步。