调试#
通常在创作变换的过程中,我们的代码并不完全正确。在这种情况下,可能需要进行一些调试。关键是 backwards 工作:首先,检查调用生成的 module 的结果,以证明或否定正确性。然后,检查和调试生成的代码。然后,调试导致生成代码的变换过程。
变换创作中的常见陷阱#
不确定的 set
迭代顺序。在 Python 中,设置的数据类型是无序的。例如,使用 set
来包含节点等对象的集合可能会导致意外的不确定性。一个例子是迭代一组节点,将它们插入到图中。因为设置的数据类型是无序的,输出程序中运算的顺序将是不确定的,并且可以在程序调用之间更改。推荐的替代方法是使用 dict
数据类型,这是 Python 3.7(以及 cPython 3.6)开始按照插入顺序排序。通过将要重复数据删除的值存储在 dict
的键中,dict
可以等价地用于 set
。
检查 module 的正确性#
因为大多数深度学习 module 的输出都是由浮点 torch.Tensor
实例组成,检查两个 torch.nn.Module
结果之间的等价性不像做简单的相等性检查那样直接。为了激发这个想法,举个例子(RuntimeError:有多个值的张量的布尔值不明确):
import torch
import torch.fx
import torchvision.models as models
def transform(m : torch.nn.Module) -> torch.nn.Module:
gm = torch.fx.symbolic_trace(m)
# Imagine we're doing some transforms here
# <...>
gm.recompile()
return gm
resnet18 = models.resnet18()
transformed_resnet18 = transform(resnet18)
input_image = torch.randn(5, 3, 224, 224)
assert resnet18(input_image) == transformed_resnet18(input_image)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
/media/pc/data/4tb/lxw/home/lxw/hub/torch-book/doc/tutorial/fx/Debugging.ipynb Cell 2 in <cell line: 21>()
<a href='vscode-notebook-cell://ssh-remote%2Bxin/media/pc/data/4tb/lxw/home/lxw/hub/torch-book/doc/tutorial/fx/Debugging.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=16'>17</a> transformed_resnet18 = transform(resnet18)
<a href='vscode-notebook-cell://ssh-remote%2Bxin/media/pc/data/4tb/lxw/home/lxw/hub/torch-book/doc/tutorial/fx/Debugging.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=18'>19</a> input_image = torch.randn(5, 3, 224, 224)
---> <a href='vscode-notebook-cell://ssh-remote%2Bxin/media/pc/data/4tb/lxw/home/lxw/hub/torch-book/doc/tutorial/fx/Debugging.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=20'>21</a> assert resnet18(input_image) == transformed_resnet18(input_image)
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
在这里,尝试用 ==
运算符检查两个深度学习模型的值是否相等。然而,由于运算符返回的是张量而不是 bool
值的问题,而且由于浮点值的比较应该使用误差边界(或 epsilon)来解释浮点运算的非交换性,这两个问题都没有很好地定义。可以使用 torch.allclose()
,它会考虑到相对和绝对公差阈值的近似比较:
assert torch.allclose(resnet18(input_image), transformed_resnet18(input_image))
与参考实现相比,这是工具箱中检查变换模块行为是否如期望的那样的第一个工具。
调试生成的代码#
因为 FX 在 torch.fx.GraphModule
上生成 forward()
函数,所以使用传统的调试技术(如 print
语句或 pdb
)就不那么直接了。幸运的是,有几种技术可以用来调试生成的代码。
使用 pdb
#
调用 pdb
进入正在运行的程序。尽管表示 torch.fx.Graph
的代码不在任何源文件中,但是当调用 forward
传递时,仍然可以使用 pdb
手动进入它。
import torch
from torch import fx
import torchvision.models as models
def my_pass(inp: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
graph = tracer_class().trace(inp)
# Transformation logic here
# <...>
# Return new Module
return fx.GraphModule(inp, graph)
my_module = models.resnet18()
my_module_transformed = my_pass(my_module)
input_value = torch.randn(5, 3, 224, 224)
# When this line is executed at runtime, we will be dropped into an
# interactive `pdb` prompt. We can use the `step` or `s` command to
# step into the execution of the next line
import pdb; pdb.set_trace()
my_module_transformed(input_value)
--Return--
None
> /tmp/ipykernel_2297333/4158250709.py(21)<cell line: 21>()
19 # interactive `pdb` prompt. We can use the `step` or `s` command to
20 # step into the execution of the next line
---> 21 import pdb; pdb.set_trace()
22
23 my_module_transformed(input_value)
打印生成代码#
如果您想要多次运行相同的代码,那么使用 pdb
逐步找到正确的代码可能有点乏味。在这种情况下,一种方法是简单地将生成的 forward
传递复制粘贴到代码中,并从那里检查它。
# Assume that `traced` is a GraphModule that has undergone some
# number of transforms
# Copy this code for later
print(traced)
# Print the code generated from symbolic tracing. This outputs:
"""
def forward(self, y):
x = self.x
add_1 = x + y; x = y = None
return add_1
"""
# Subclass the original Module
class SubclassM(M):
def __init__(self):
super().__init__()
# Paste the generated `forward` function (the one we printed and
# copied above) here
def forward(self, y):
x = self.x
add_1 = x + y; x = y = None
return add_1
# Create an instance of the original, untraced Module. Then, create an
# instance of the Module with the copied `forward` function. We can
# now compare the output of both the original and the traced version.
pre_trace = M()
post_trace = SubclassM()
使用 to_folder()
函数#
to_folder()
是 GraphModule
中的方法,它允许你将生成的 FX 代码转储到文件夹中。尽管像打印生成的代码那样,将 forward
传递复制到代码中通常就足够了,但是使用 to_folder()
检查模块和参数可能更容易。
m = symbolic_trace(M())
m.to_folder("foo", "Bar")
from foo import Bar
y = Bar()
在运行上面的示例之后,可以查看 foo/module.py
中的代码,并根据需要修改它(例如添加 print
语句或使用 pdb
),以调试生成的代码。
调试变换#
既然已经确定了变换正在创建不正确的代码,现在是调试变换本身的时候了。
# Sample Module
class M(torch.nn.Module):
def forward(self, x, y):
return x + y
# Create an instance of `M`
m = M()
# Symbolically trace an instance of `M` (returns a GraphModule). In
# this example, we'll only be discussing how to inspect a
# GraphModule, so we aren't showing any sample transforms for the
# sake of brevity.
traced = symbolic_trace(m)
# Print the code produced by tracing the module.
print(traced)
# The generated `forward` function is:
"""
def forward(self, x, y):
add = x + y; x = y = None
return add
"""
# Print the internal Graph.
print(traced.graph)
# This print-out returns:
"""
graph():
%x : [#users=1] = placeholder[target=x]
%y : [#users=1] = placeholder[target=y]
%add : [#users=1] = call_function[target=operator.add](args = (%x, %y), kwargs = {})
return add
"""
# Print a tabular representation of the internal Graph.
traced.graph.print_tabular()
# This gives us:
"""
opcode name target args kwargs
------------- ------ ----------------------- ------ --------
placeholder x x () {}
placeholder y y () {}
call_function add <built-in function add> (x, y) {}
output output output (add,) {}
"""
使用上面的实用函数,可以在应用变换之前和之后比较跟踪的 torch.nn.Module
。
抛开上面的例子,考虑下面的代码:
# Sample user-defined function
def transform_graph(module: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
# Get the Graph from our traced Module
g = tracer_class().trace(module)
"""
Transformations on `g` go here
"""
return fx.GraphModule(module, g)
# Transform the Graph
transformed = transform_graph(traced)
# Print the new code after our transforms. Check to see if it was
# what we expected
print(transformed)
使用上面的例子,假设对 print(tracing)
的调用告诉我们变换中有一个错误。希望使用调试器找到哪里出了问题。可以通过中断 `transform_graph(已跟踪),然后按s“进入”对transform_graph(已跟踪)的调用来查看转换过程中发生了什么。