Custom Tracers#
ModulePathTracer#
We will define a custom Tracer instance that, for every operation it records, also notes the qualified name of the module the operation originated from.
from typing import Any, Callable

import torch
from torch import fx, nn


class ModulePathTracer(fx.Tracer):
    """
    ModulePathTracer is an FX tracer that, for each operation, also records
    the qualified name of the module the operation originated from.
    """

    # The qualified name of the module currently being traced.
    # The top-level module is signified by the empty string.
    # Updated when entering ``call_module``, restored when exiting it.
    current_module_qualified_name: str = ''
    # A map from FX Node to the qualname of the module it originated from.
    # Recorded by `create_proxy` when an operation is recorded.
    node_to_originating_module: dict[fx.Node, str] = {}

    def call_module(self, m: nn.Module,
                    forward: Callable[..., Any],
                    args: tuple[Any, ...],
                    kwargs: dict[str, Any]) -> Any:
        """
        1. Store the caller's qualified name so it can be restored later
        2. Install the qualified name of the module being called into
           `current_module_qualified_name`, for retrieval by `create_proxy`
        3. Delegate to the normal Tracer.call_module method
        4. Restore the caller's qualified name into current_module_qualified_name
        """
        old_qualname = self.current_module_qualified_name
        try:
            self.current_module_qualified_name = self.path_of_module(m)
            return super().call_module(m, forward, args, kwargs)
        finally:
            self.current_module_qualified_name = old_qualname

    def create_proxy(self, kind: str,
                     target: fx.node.Target,
                     args: tuple[Any, ...],
                     kwargs: dict[str, Any],
                     name: str | None = None,
                     type_expr: Any | None = None):
        """Override of `Tracer.create_proxy`.

        This override intercepts the recording of every operation and stores
        the qualified name of the currently traced module in
        `node_to_originating_module`.
        """
        proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr)
        self.node_to_originating_module[proxy.node] = self.current_module_qualified_name
        return proxy
# Testing: let's see how this works on a torchvision ResNet18 model
import torchvision.models as models
# Model under test
rn18 = models.resnet18()
# Instantiate our ModulePathTracer and use that to trace our ResNet18
tracer = ModulePathTracer()
traced_rn18 = tracer.trace(rn18)
# Print (node, module qualified name) for every node in the Graph
for node in traced_rn18.nodes:
    module_qualname = tracer.node_to_originating_module.get(node)
    print('Node', node, 'is from module', module_qualname)
Node x is from module
Node conv1 is from module conv1
Node bn1 is from module bn1
Node relu is from module relu
Node maxpool is from module maxpool
Node layer1_0_conv1 is from module layer1.0.conv1
Node layer1_0_bn1 is from module layer1.0.bn1
Node layer1_0_relu is from module layer1.0.relu
Node layer1_0_conv2 is from module layer1.0.conv2
Node layer1_0_bn2 is from module layer1.0.bn2
Node add is from module layer1.0
Node layer1_0_relu_1 is from module layer1.0.relu
Node layer1_1_conv1 is from module layer1.1.conv1
Node layer1_1_bn1 is from module layer1.1.bn1
Node layer1_1_relu is from module layer1.1.relu
Node layer1_1_conv2 is from module layer1.1.conv2
Node layer1_1_bn2 is from module layer1.1.bn2
Node add_1 is from module layer1.1
Node layer1_1_relu_1 is from module layer1.1.relu
Node layer2_0_conv1 is from module layer2.0.conv1
Node layer2_0_bn1 is from module layer2.0.bn1
Node layer2_0_relu is from module layer2.0.relu
Node layer2_0_conv2 is from module layer2.0.conv2
Node layer2_0_bn2 is from module layer2.0.bn2
Node layer2_0_downsample_0 is from module layer2.0.downsample.0
Node layer2_0_downsample_1 is from module layer2.0.downsample.1
Node add_2 is from module layer2.0
Node layer2_0_relu_1 is from module layer2.0.relu
Node layer2_1_conv1 is from module layer2.1.conv1
Node layer2_1_bn1 is from module layer2.1.bn1
Node layer2_1_relu is from module layer2.1.relu
Node layer2_1_conv2 is from module layer2.1.conv2
Node layer2_1_bn2 is from module layer2.1.bn2
Node add_3 is from module layer2.1
Node layer2_1_relu_1 is from module layer2.1.relu
Node layer3_0_conv1 is from module layer3.0.conv1
Node layer3_0_bn1 is from module layer3.0.bn1
Node layer3_0_relu is from module layer3.0.relu
Node layer3_0_conv2 is from module layer3.0.conv2
Node layer3_0_bn2 is from module layer3.0.bn2
Node layer3_0_downsample_0 is from module layer3.0.downsample.0
Node layer3_0_downsample_1 is from module layer3.0.downsample.1
Node add_4 is from module layer3.0
Node layer3_0_relu_1 is from module layer3.0.relu
Node layer3_1_conv1 is from module layer3.1.conv1
Node layer3_1_bn1 is from module layer3.1.bn1
Node layer3_1_relu is from module layer3.1.relu
Node layer3_1_conv2 is from module layer3.1.conv2
Node layer3_1_bn2 is from module layer3.1.bn2
Node add_5 is from module layer3.1
Node layer3_1_relu_1 is from module layer3.1.relu
Node layer4_0_conv1 is from module layer4.0.conv1
Node layer4_0_bn1 is from module layer4.0.bn1
Node layer4_0_relu is from module layer4.0.relu
Node layer4_0_conv2 is from module layer4.0.conv2
Node layer4_0_bn2 is from module layer4.0.bn2
Node layer4_0_downsample_0 is from module layer4.0.downsample.0
Node layer4_0_downsample_1 is from module layer4.0.downsample.1
Node add_6 is from module layer4.0
Node layer4_0_relu_1 is from module layer4.0.relu
Node layer4_1_conv1 is from module layer4.1.conv1
Node layer4_1_bn1 is from module layer4.1.bn1
Node layer4_1_relu is from module layer4.1.relu
Node layer4_1_conv2 is from module layer4.1.conv2
Node layer4_1_bn2 is from module layer4.1.bn2
Node add_7 is from module layer4.1
Node layer4_1_relu_1 is from module layer4.1.relu
Node avgpool is from module avgpool
Node flatten is from module
Node fc is from module fc
Node output is from module None
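To see how this mapping might be consumed downstream, here is a minimal sketch (reusing the tracer populated above) that counts how many nodes each module contributed:
from collections import Counter

# Aggregate the qualnames recorded in `tracer.node_to_originating_module`
# to see how many operations each module contributed.
op_counts = Counter(tracer.node_to_originating_module.values())
for qualname, count in sorted(op_counts.items()):
    print(f"{qualname or '<top level>'}: {count} node(s)")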
Tracing through all ReLU submodules#
During symbolic tracing, some submodules are traced through and their constituent operations are recorded; other submodules appear in the IR as atomic "call_module" nodes. Modules in the latter category are called "leaf modules". By default, all modules in the PyTorch standard library (torch.nn) are leaf modules. This can be changed by creating a custom Tracer and overriding is_leaf_module.
import torch
from torch import nn, fx
class M1(nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(x)
default_traced: fx.GraphModule = fx.symbolic_trace(M1())
default_traced
M1(
  (relu): ReLU()
)
default_traced.graph.print_tabular()
opcode       name    target    args     kwargs
-----------  ------  --------  -------  --------
placeholder  x       x         ()       {}
call_module  relu    relu      (x,)     {}
output       output  output    (relu,)  {}
Changing the default behavior for ReLU:
class LowerReluTracer(fx.Tracer):
    def is_leaf_module(self, m: nn.Module, qualname: str):
        if isinstance(m, torch.nn.ReLU):
            return False
        return super().is_leaf_module(m, qualname)

lower_relu_tracer = LowerReluTracer()
custom_traced_graph: fx.Graph = lower_relu_tracer.trace(M1())
custom_traced_graph.print_tabular()
opcode         name    target                             args     kwargs
-------------  ------  ---------------------------------  -------  ------------------
placeholder    x       x                                  ()       {}
call_function  relu    <function relu at 0x7f326b7685e0>  (x,)     {'inplace': False}
output         output  output                             (relu,)  {}
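The same hook also works in the opposite direction: returning True from is_leaf_module keeps a module opaque that would otherwise be traced through. A minimal sketch, using a hypothetical MySpecialOp module purely for illustration:
class MySpecialOp(nn.Module):
    # Hypothetical module used only to illustrate the idea.
    def forward(self, x):
        return x * 2 + 1

class KeepSpecialOpTracer(fx.Tracer):
    def is_leaf_module(self, m: nn.Module, qualname: str):
        # Treat MySpecialOp as an atomic leaf so it shows up as a single
        # `call_module` node instead of its constituent operations.
        if isinstance(m, MySpecialOp):
            return True
        return super().is_leaf_module(m, qualname)

class M3(nn.Module):
    def __init__(self):
        super().__init__()
        self.op = MySpecialOp()

    def forward(self, x):
        return self.op(x)

KeepSpecialOpTracer().trace(M3()).print_tabular()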
Adding extra attributes to each node#
Here we override create_node so that a new attribute is added to each Node as it is created.
class M2(nn.Module):
    def forward(self, a, b):
        return a + b

class TaggingTracer(fx.Tracer):
    def create_node(self, kind: str, target: str | Callable,
                    args: tuple[Any, ...], kwargs: dict[str, Any],
                    name: str | None = None,
                    type_expr: Any | None = None) -> fx.Node:
        n = super().create_node(kind, target, args, kwargs, name)
        n.tag = "foo"
        return n

custom_traced_graph: fx.Graph = TaggingTracer().trace(M2())

def assert_all_nodes_have_tags(g: fx.Graph) -> bool:
    for n in g.nodes:
        if not hasattr(n, "tag") or not n.tag == "foo":
            return False
    return True
print(assert_all_nodes_have_tags(custom_traced_graph))
True
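A variant sketch: rather than attaching a brand-new attribute to the Node object, the same tag can be stored in the node's meta dictionary, which FX Nodes expose for per-node metadata:
class MetaTaggingTracer(fx.Tracer):
    def create_node(self, kind: str, target: str | Callable,
                    args: tuple[Any, ...], kwargs: dict[str, Any],
                    name: str | None = None,
                    type_expr: Any | None = None) -> fx.Node:
        n = super().create_node(kind, target, args, kwargs, name)
        # Store the tag in the node's `meta` dict instead of a new attribute.
        n.meta["tag"] = "foo"
        return n

meta_traced: fx.Graph = MetaTaggingTracer().trace(M2())
print(all(n.meta.get("tag") == "foo" for n in meta_traced.nodes))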
Inlining a function into an existing Graph#
One reason you might want to inline a function is to get around FX's default tracing behavior. For example, unless you have defined a custom Tracer, the out-of-the-box implementation of symbolic_trace causes references to nn module instances to appear as explicit call_module calls rather than being traced through. Suppose this behavior is almost what you need; the only problem is that there is a single module call you would like to replace with an inlined trace of the function. Creating a custom Tracer would be too much work. Instead, you can accomplish this using Proxies.
The code below demonstrates how to trace a module and inline it into an existing Graph using Proxy. We will trace a Graph, then iterate over its nodes until we find the right place to swap out the call_module node for the inlined trace. At that point, we will create Proxies from the node's args and kwargs. Finally, we will call the function we want to replace with those Proxies, which, in essence, "traces" that function. We then insert the result of that call into our Graph. (This last step automatically inlines the function.)
class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(x) + 1.0
Symbolically trace an instance of M. After tracing, self.relu is represented as a call_module node.
m = fx.symbolic_trace(M())
Insert nodes from the ReLU graph in place of the original call to self.relu.
Create a graph-appending tracer pointing to the original Graph:
tracer = fx.proxy.GraphAppendingTracer(m.graph)
for node in m.graph.nodes:
    # Find the `call_module` node in `m` that corresponds to `self.relu`.
    # This is the node we want to swap out for an inlined version of the
    # same call.
    if (node.op, node.target) == ("call_module", "relu"):
        with m.graph.inserting_before(node):
            # Create a Proxy from each Node in the current Node's args/kwargs
            proxy_args = fx.map_arg(node.args, lambda n: fx.Proxy(n, tracer))
            proxy_kwargs = fx.map_arg(node.kwargs, lambda n: fx.Proxy(n, tracer))
            # Call `m.relu` with the newly-created Proxy arguments.
            # `m.relu` is the generic version of the function; by calling it
            # with Proxies created from Nodes in `m`, we emit Nodes that
            # reference existing values in the IR. The result of this call is
            # another Proxy, which we can hook into our existing Graph to
            # complete the function inlining.
            proxy_output = m.relu(*proxy_args, **proxy_kwargs)
            # Replace the relu `call_module` node with the inlined version of
            # the function
            node.replace_all_uses_with(proxy_output.node)
            # Make sure the old relu node is erased
            m.graph.erase_node(node)
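After this kind of in-place graph surgery, it is worth linting the modified Graph and recompiling the GraphModule so its generated forward reflects the new nodes. A minimal sketch of that follow-up, plus a quick numerical check:
# Re-check the modified Graph and regenerate the GraphModule's Python code.
m.graph.lint()
m.recompile()
print(m.code)

# Sanity check: the inlined version should still compute relu(x) + 1.0.
x = torch.randn(4)
print(torch.allclose(m(x), torch.relu(x) + 1.0))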
Computing inverse functions with FX#
import torch
from torch import fx, nn
An inverse mapping takes a function f(x) and returns a function g such that f(g(x)) == x. For example, since log(exp(x)) == x, exp and log are inverses of each other.
invert_mapping = {}

def add_inverse(a, b):
    invert_mapping[a] = b
    invert_mapping[b] = a

inverses = [
    (torch.sin, torch.arcsin),
    (torch.cos, torch.arccos),
    (torch.tan, torch.arctan),
    (torch.exp, torch.log),
]

for a, b in inverses:
    add_inverse(a, b)
The general strategy is to walk the graph backwards, transforming each node into its inverse. To do this, we swap the outputs and inputs of each function and then look up its inverse in invert_mapping. Note that this transform assumes every operation takes exactly one input and returns exactly one output.
def invert(model: nn.Module) -> nn.Module:
    fx_model = fx.symbolic_trace(model)
    new_graph = fx.Graph()  # Build a new graph
    env = {}
    for node in reversed(fx_model.graph.nodes):
        if node.op == 'call_function':
            # Create a node in the new graph with the inverse function,
            # passing `env[node.name]` (i.e. the previous output node)
            # as its input.
            new_node = new_graph.call_function(invert_mapping[node.target],
                                               (env[node.name],))
            env[node.args[0].name] = new_node
        elif node.op == 'output':
            # Turn the output into an input placeholder
            new_node = new_graph.placeholder(node.name)
            env[node.args[0].name] = new_node
        elif node.op == 'placeholder':
            # Turn the input placeholder into an output
            new_graph.output(env[node.name])
        else:
            raise RuntimeError("Not implemented")

    new_graph.lint()
    return fx.GraphModule(fx_model, new_graph)
def f(x):
    return torch.exp(torch.tan(x))

res = invert(f)
print(res.code)
print(f(res(torch.arange(5) + 1)))
def forward(self, output):
    log = torch.log(output); output = None
    arctan = torch.arctan(log); log = None
    return arctan
tensor([1., 2., 3., 4., 5.])
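As one more sanity check, here is a minimal sketch that reuses the invert helper on a different composition, with inputs kept small so every intermediate value stays inside the principal branch of arcsin:
def g(x):
    return torch.sin(torch.exp(x))

g_inv = invert(g)
x = torch.rand(5) * 0.4   # keeps exp(x) below pi/2, so arcsin(sin(.)) is exact
print(torch.allclose(g_inv(g(x)), x, atol=1e-6))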