Limitations of Symbolic Tracing#
FX uses a system of symbolic tracing (also known as symbolic execution) to capture the semantics of programs in a transformable/analyzable form. The system is tracing because it executes the program (really a Module or a function) to record operations. It is symbolic because the data flowing through the program during this execution is not real data, but rather symbols (Proxy in FX parlance).
Although symbolic tracing works for most neural-network code, it has some limitations.
Dynamic Control Flow#
The main limitation of symbolic tracing is that it does not currently support dynamic control flow, that is, loops or if statements whose condition may depend on the input values of the program.
For example:
import torch
from torch import fx
def func_to_trace(x):
    if x.sum() > 0:
        return torch.relu(x)
    else:
        return torch.neg(x)
traced = fx.symbolic_trace(func_to_trace)
"""
  <...>
  File "dyn.py", line 6, in func_to_trace
    if x.sum() > 0:
  File "pytorch/torch/fx/proxy.py", line 155, in __bool__
    return self.tracer.to_bool(self)
  File "pytorch/torch/fx/proxy.py", line 85, in to_bool
    raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
"""
The condition to the if statement relies on the value of x.sum(), which in turn relies on the value of the function input x. Since x can change (for example, if you pass a new input tensor to the traced function), this is dynamic control flow. The traceback walks back up through your code to show you where this situation happened.
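Many such branches can be rewritten so that both arms are computed eagerly and the selection happens inside a tensor op. As a minimal sketch (assuming it is acceptable to evaluate both branches), the function above becomes traceable with torch.where:
def func_to_trace_static(x):
    # No Python control flow depends on the data: both branches are
    # computed, and torch.where selects between them.
    return torch.where(x.sum() > 0, torch.relu(x), torch.neg(x))

traced = fx.symbolic_trace(func_to_trace_static)  # succeeds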
Static Control Flow#
On the other hand, so-called static control flow is supported. Static control flow is loops or if statements whose value cannot change across invocations. Typically, in PyTorch programs, this control flow arises from code that makes decisions about a model's architecture based on hyper-parameters. As a concrete example:
import torch
from torch import fx
class MyModule(torch.nn.Module):
    def __init__(self, do_activation : bool = False):
        super().__init__()
        self.do_activation = do_activation
        self.linear = torch.nn.Linear(512, 512)

    def forward(self, x):
        x = self.linear(x)
        # This if-statement is so-called static control flow.
        # Its condition does not depend on any input values.
        if self.do_activation:
            x = torch.relu(x)
        return x
without_activation = MyModule(do_activation=False)
with_activation = MyModule(do_activation=True)
traced_without_activation = fx.symbolic_trace(without_activation)
print(traced_without_activation.code)
def forward(self, x):
    linear = self.linear(x); x = None
    return linear
traced_with_activation = fx.symbolic_trace(with_activation)
print(traced_with_activation.code)
def forward(self, x):
    linear = self.linear(x); x = None
    relu = torch.relu(linear); linear = None
    return relu
The if-statement on self.do_activation does not depend on any function inputs, thus it is static. do_activation can be considered a hyper-parameter, and the traces of different instances of MyModule with different values for that parameter have different code. This is a valid pattern that symbolic tracing supports.
Many instances of dynamic control flow are semantically static control flow. These instances can be made to support symbolic tracing by removing the data dependencies on input values, for example by moving values to Module attributes, or by binding concrete values to arguments during symbolic tracing:
def f(x, flag):
    if flag: return x
    else: return x*2
fx.symbolic_trace(f) # Fails!
fx.symbolic_trace(f, concrete_args={'flag': True})
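In the successful call above, flag is bound to True, so the trace records only the branch taken under that value. A quick way to see the specialization (a small sketch; the exact generated code can vary across PyTorch versions):
traced = fx.symbolic_trace(f, concrete_args={'flag': True})
print(traced.code)  # reflects only the `return x` branch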
Non-torch Functions#
FX uses __torch_function__ as the mechanism by which it intercepts calls (see the technical overview for more information about this). Some functions, such as builtin Python functions or those in the math module, are not covered by __torch_function__, but we would still like to capture them in symbolic tracing. For example:
import torch
import torch.fx
from math import sqrt
def normalize(x):
    """
    Normalize `x` by the size of the batch dimension
    """
    return x / sqrt(len(x))
# It's valid Python code
normalize(torch.rand(3, 4))
traced = torch.fx.symbolic_trace(normalize)
"""
  <...>
  File "sqrt.py", line 9, in normalize
    return x / sqrt(len(x))
  File "pytorch/torch/fx/proxy.py", line 161, in __len__
    raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
"""
The error tells us that the builtin function len() is not supported. We can make functions like it be recorded in the trace as direct calls using the wrap() API:
fx.wrap('len')
fx.wrap('sqrt')
traced = fx.symbolic_trace(normalize)
print(traced.code)
"""
import math
def forward(self, x):
    len_1 = len(x)
    sqrt_1 = math.sqrt(len_1); len_1 = None
    truediv = x / sqrt_1; x = sqrt_1 = None
    return truediv
"""
Customizing Tracing with the Tracer Class#
The Tracer class is the class that underlies the implementation of symbolic_trace(). The behavior of tracing can be customized by subclassing Tracer, like so:
class MyCustomTracer(torch.fx.Tracer):
    """A custom tracer"""
    ...

# Use the custom tracer to trace the entire module
class MyModule(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + torch.ones(3, 4)

mod = MyModule()

# trace() returns a Graph
traced_graph = MyCustomTracer().trace(mod)
# Wrap it in a GraphModule to make it runnable
traced = fx.GraphModule(mod, traced_graph)
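As one illustrative customization (VerboseTracer is a made-up name, not an FX API), a subclass can override create_node() to log each node as it is recorded:
class VerboseTracer(torch.fx.Tracer):
    def create_node(self, kind, target, args, kwargs, name=None, type_expr=None):
        # Delegate to the base implementation, then report what was created
        node = super().create_node(kind, target, args, kwargs, name, type_expr)
        print(f"created node: {node.op} -> {node.target}")
        return node

verbose_graph = VerboseTracer().trace(mod)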
Leaf Modules#
Leaf modules are modules that appear as calls in the symbolic trace rather than being traced through. The default set of leaf modules is the set of standard torch.nn module instances. For example:
class MySpecialSubmodule(torch.nn.Module):
    def forward(self, x):
        return torch.neg(x)

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 4)
        self.submod = MySpecialSubmodule()

    def forward(self, x):
        return self.submod(self.linear(x))
traced = torch.fx.symbolic_trace(MyModule())
print(traced.code)
def forward(self, x):
    linear = self.linear(x); x = None
    neg = torch.neg(linear); linear = None
    return neg
linear is preserved as a call, but submod is traced through. That is because the default set of "leaf modules" includes all standard torch.nn modules. The set of leaf modules can be customized by overriding is_leaf_module().
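A minimal sketch of that override (LeafTracer is an illustrative name), treating MySpecialSubmodule above as a leaf so it is preserved as a call:
class LeafTracer(torch.fx.Tracer):
    def is_leaf_module(self, m, module_qualified_name):
        # Keep MySpecialSubmodule opaque; defer to the default rule otherwise
        if isinstance(m, MySpecialSubmodule):
            return True
        return super().is_leaf_module(m, module_qualified_name)

mod = MyModule()
traced = torch.fx.GraphModule(mod, LeafTracer().trace(mod))
print(traced.code)  # self.submod now appears as a call, just like self.linear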
Miscellanea#
Tensor constructors (torch.zeros(), torch.ones(), torch.rand(), torch.randn(), torch.sparse_coo_tensor()) are currently not traceable.

The deterministic constructors (zeros, ones) can be used, and the values they produce will be embedded in the trace as constants. This is only problematic if the arguments to these constructors refer to dynamic input sizes. In that case, ones_like() or zeros_like() may be a viable substitute; see the sketch after this code block.

The nondeterministic constructors (rand(), randn()) will have a single random value embedded in the trace. This is likely not the intended behavior. One workaround is to wrap the constructor with torch.fx.wrap():

@torch.fx.wrap
def torch_randn(x, shape):
    return torch.randn(shape)

def f(x):
    return x + torch_randn(x, 5)

fx.symbolic_trace(f)
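A minimal sketch of the ones_like() substitution mentioned above: because ones_like() takes the input tensor itself, the constructed value follows the runtime input instead of baking in a size observed at trace time.
def g(x):
    # ones_like flows through __torch_function__, so it is recorded as a call
    return x + torch.ones_like(x)

print(torch.fx.symbolic_trace(g).code)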
Type annotations

Python 3-style type annotations (e.g. func(x : torch.Tensor, y : int) -> torch.Tensor) are supported and will be preserved by symbolic tracing. Annotations on local names within a function are currently not supported.
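For instance (a small sketch), tracing a function with an annotated signature keeps those annotations in the generated code:
def annotated(x : torch.Tensor, y : int) -> torch.Tensor:
    return x + y

print(torch.fx.symbolic_trace(annotated).code)  # signature annotations preserved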
There are gotchas around the training flag and submodules

When using functions like torch.nn.functional.dropout(), it will often be the case that the training argument is passed in as self.training. During FX tracing, this will likely be baked in as a constant value:
import torch
import torch.fx
class DropoutRepro(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.dropout(x, training=self.training)
traced = torch.fx.symbolic_trace(DropoutRepro())
print(traced.code)
def forward(self, x):
    dropout = torch.nn.functional.dropout(x, p = 0.5, training = True, inplace = False); x = None
    return dropout
traced.eval()
x = torch.randn(5, 3)
torch.testing.assert_allclose(traced(x), x)
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
/media/pc/data/4tb/lxw/home/lxw/hub/torch-book/doc/tutorial/fx/LimitationsSymbolicTracing.ipynb Cell 18 in <cell line: 4>()
      1 traced.eval()
      3 x = torch.randn(5, 3)
----> 4 torch.testing.assert_allclose(traced(x), x)
File /media/pc/data/4tb/lxw/libs/anaconda3/envs/tvmx/lib/python3.10/site-packages/torch/testing/_deprecated.py:32, in warn_deprecated.<locals>.outer_wrapper.<locals>.inner_wrapper(*args, **kwargs)
30 @functools.wraps(fn)
31 def inner_wrapper(*args: Any, **kwargs: Any) -> Any:
---> 32 return_value = fn(*args, **kwargs)
33 tail = instructions(name, args, kwargs, return_value) if callable(instructions) else instructions
34 msg = (head + tail).strip()
File /media/pc/data/4tb/lxw/libs/anaconda3/envs/tvmx/lib/python3.10/site-packages/torch/testing/_deprecated.py:80, in assert_allclose(actual, expected, rtol, atol, equal_nan, msg)
77 if rtol is None and atol is None:
78 rtol, atol = _get_default_rtol_and_atol(actual, expected)
---> 80 torch.testing.assert_close(
81 actual,
82 expected,
83 rtol=rtol,
84 atol=atol,
85 equal_nan=equal_nan,
86 check_device=True,
87 check_dtype=False,
88 check_stride=False,
89 msg=msg or None,
90 )
[... skipping hidden 1 frame]
File /media/pc/data/4tb/lxw/libs/anaconda3/envs/tvmx/lib/python3.10/site-packages/torch/testing/_comparison.py:1095, in assert_equal(actual, expected, pair_types, sequence_types, mapping_types, msg, **options)
1092 return
1094 # TODO: compose all metas into one AssertionError
-> 1095 raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!
Mismatched elements: 15 / 15 (100.0%)
Greatest absolute difference: 1.709273338317871 at index (4, 2) (up to 1e-05 allowed)
Greatest relative difference: 1.0 at index (0, 0) (up to 0.0001 allowed)
However, when the standard Dropout submodule is used, the training flag is encapsulated (because the Module object model is preserved) and can be changed:
class DropoutRepro2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.drop = torch.nn.Dropout()

    def forward(self, x):
        return self.drop(x)
traced = torch.fx.symbolic_trace(DropoutRepro2())
print(traced.code)
traced.eval()
x = torch.randn(5, 3)
torch.testing.assert_close(traced(x), x)
def forward(self, x):
    drop = self.drop(x); x = None
    return drop
Because of this difference, consider marking modules that interact dynamically with the training flag as leaf modules, as sketched below.
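A minimal sketch of that recommendation (Wrapper and DropoutLeafTracer are illustrative names): keeping DropoutRepro opaque means its self.training is consulted at run time rather than baked in at trace time.
class DropoutLeafTracer(torch.fx.Tracer):
    def is_leaf_module(self, m, module_qualified_name):
        # Preserve DropoutRepro as a call instead of tracing through it
        if isinstance(m, DropoutRepro):
            return True
        return super().is_leaf_module(m, module_qualified_name)

class Wrapper(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = DropoutRepro()

    def forward(self, x):
        return self.inner(x)

mod = Wrapper()
traced = torch.fx.GraphModule(mod, DropoutLeafTracer().trace(mod))
traced.eval()  # flips self.inner.training to False at run time
x = torch.randn(5, 3)
torch.testing.assert_close(traced(x), x)  # now passes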