自定义跟踪器#

ModulePathTracer#

将定义自定义的 Tracer 实例,对于每个记录的运算,也记下该运算起源于的模块的限定名。

from typing import Any, Callable
import torch
from torch import fx, nn
class ModulePathTracer(fx.Tracer):
    """
    ModulePathTracer 是 FX 跟踪器,对于每个运算,它还记录了运算起源于的模块的限定名。
    """
    
    # 正在跟踪的模块的当前限定名。
    # 顶级模块由空字符串表示。
    # 在进入 ``call_module`` 时更新,在退出 ``call_module`` 时恢复
    current_module_qualified_name: str = ''
    # 从 FX 节点到它起源模块的 qualname 的映射
    # 这在记录运算时由 `create_proxy` 记录
    node_to_originating_module: dict[fx.Node, str] = {}

    def call_module(self, m: nn.Module, 
                    forward: Callable[..., Any],
                    args: tuple[Any, ...], 
                    kwargs: dict[str, Any]) -> Any:
        """
        1. 存储调用者的限定名称以便稍后恢复
        2. 在 `current_module_qualified_name` 中安装(install)调用者的限定名,
           以供 `create_proxy` 检索。
        3. 委托到正常的 Tracer.call_module 方法
        4. 将调用者的限定名恢复到 current_module_qualified_name 中
        """
        old_qualname = self.current_module_qualified_name
        try:
            self.current_module_qualified_name = self.path_of_module(m)
            return super().call_module(m, forward, args, kwargs)
        finally:
            self.current_module_qualified_name = old_qualname

    def create_proxy(self, kind: str, 
                     target: fx.node.Target, 
                     args: tuple[Any, ...],
                     kwargs: dict[str, Any], 
                     name: str|None = None, 
                     type_expr: Any|None = None):
        """覆写 `Tracer.create_proxy`。
        该覆盖会截取每个运算的记录,
        并将当前跟踪模块的限定名存储在 `node_to_originating_module` 中。
        """
        proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr)
        self.node_to_originating_module[proxy.node] = self.current_module_qualified_name
        return proxy
# Testing: let's see how this works on a torchvision ResNet18 model
import torchvision.models as models

# Model under test
rn18 = models.resnet18()

# Instantiate our ModulePathTracer and use that to trace our ResNet18
tracer = ModulePathTracer()
traced_rn18 = tracer.trace(rn18)

# Print (node, module qualified name) for every node in the Graph
for node in traced_rn18.nodes:
    module_qualname = tracer.node_to_originating_module.get(node)
    print('Node', node, 'is from module', module_qualname)
Node x is from module 
Node conv1 is from module conv1
Node bn1 is from module bn1
Node relu is from module relu
Node maxpool is from module maxpool
Node layer1_0_conv1 is from module layer1.0.conv1
Node layer1_0_bn1 is from module layer1.0.bn1
Node layer1_0_relu is from module layer1.0.relu
Node layer1_0_conv2 is from module layer1.0.conv2
Node layer1_0_bn2 is from module layer1.0.bn2
Node add is from module layer1.0
Node layer1_0_relu_1 is from module layer1.0.relu
Node layer1_1_conv1 is from module layer1.1.conv1
Node layer1_1_bn1 is from module layer1.1.bn1
Node layer1_1_relu is from module layer1.1.relu
Node layer1_1_conv2 is from module layer1.1.conv2
Node layer1_1_bn2 is from module layer1.1.bn2
Node add_1 is from module layer1.1
Node layer1_1_relu_1 is from module layer1.1.relu
Node layer2_0_conv1 is from module layer2.0.conv1
Node layer2_0_bn1 is from module layer2.0.bn1
Node layer2_0_relu is from module layer2.0.relu
Node layer2_0_conv2 is from module layer2.0.conv2
Node layer2_0_bn2 is from module layer2.0.bn2
Node layer2_0_downsample_0 is from module layer2.0.downsample.0
Node layer2_0_downsample_1 is from module layer2.0.downsample.1
Node add_2 is from module layer2.0
Node layer2_0_relu_1 is from module layer2.0.relu
Node layer2_1_conv1 is from module layer2.1.conv1
Node layer2_1_bn1 is from module layer2.1.bn1
Node layer2_1_relu is from module layer2.1.relu
Node layer2_1_conv2 is from module layer2.1.conv2
Node layer2_1_bn2 is from module layer2.1.bn2
Node add_3 is from module layer2.1
Node layer2_1_relu_1 is from module layer2.1.relu
Node layer3_0_conv1 is from module layer3.0.conv1
Node layer3_0_bn1 is from module layer3.0.bn1
Node layer3_0_relu is from module layer3.0.relu
Node layer3_0_conv2 is from module layer3.0.conv2
Node layer3_0_bn2 is from module layer3.0.bn2
Node layer3_0_downsample_0 is from module layer3.0.downsample.0
Node layer3_0_downsample_1 is from module layer3.0.downsample.1
Node add_4 is from module layer3.0
Node layer3_0_relu_1 is from module layer3.0.relu
Node layer3_1_conv1 is from module layer3.1.conv1
Node layer3_1_bn1 is from module layer3.1.bn1
Node layer3_1_relu is from module layer3.1.relu
Node layer3_1_conv2 is from module layer3.1.conv2
Node layer3_1_bn2 is from module layer3.1.bn2
Node add_5 is from module layer3.1
Node layer3_1_relu_1 is from module layer3.1.relu
Node layer4_0_conv1 is from module layer4.0.conv1
Node layer4_0_bn1 is from module layer4.0.bn1
Node layer4_0_relu is from module layer4.0.relu
Node layer4_0_conv2 is from module layer4.0.conv2
Node layer4_0_bn2 is from module layer4.0.bn2
Node layer4_0_downsample_0 is from module layer4.0.downsample.0
Node layer4_0_downsample_1 is from module layer4.0.downsample.1
Node add_6 is from module layer4.0
Node layer4_0_relu_1 is from module layer4.0.relu
Node layer4_1_conv1 is from module layer4.1.conv1
Node layer4_1_bn1 is from module layer4.1.bn1
Node layer4_1_relu is from module layer4.1.relu
Node layer4_1_conv2 is from module layer4.1.conv2
Node layer4_1_bn2 is from module layer4.1.bn2
Node add_7 is from module layer4.1
Node layer4_1_relu_1 is from module layer4.1.relu
Node avgpool is from module avgpool
Node flatten is from module 
Node fc is from module fc
Node output is from module None
/media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/torch/overrides.py:110: UserWarning: 'has_cuda' is deprecated, please use 'torch.backends.cuda.is_built()'
  torch.has_cuda,
/media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/torch/overrides.py:111: UserWarning: 'has_cudnn' is deprecated, please use 'torch.backends.cudnn.is_available()'
  torch.has_cudnn,
/media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/torch/overrides.py:117: UserWarning: 'has_mps' is deprecated, please use 'torch.backends.mps.is_built()'
  torch.has_mps,
/media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/torch/overrides.py:118: UserWarning: 'has_mkldnn' is deprecated, please use 'torch.backends.mkldnn.is_available()'
  torch.has_mkldnn,

追踪全部的 ReLU 子模块#

在符号跟踪过程中,跟踪一些子模块并记录它们的组成运算;其他子模块在 IR 中显示为原子 "call_module" 节点。后一类中的模块称为“叶模块”。默认情况下,PyTorch 标准库(torch.nn)中的所有模块都是叶模块。可以通过创建自定义跟踪器并重写 is_leaf_module 来改变这一点。

import torch
from torch import nn, fx


class M1(nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(x)

default_traced: fx.GraphModule = fx.symbolic_trace(M1())
default_traced
M1(
  (relu): ReLU()
)
default_traced.graph.print_tabular()
opcode       name    target    args     kwargs
-----------  ------  --------  -------  --------
placeholder  x       x         ()       {}
call_module  relu    relu      (x,)     {}
output       output  output    (relu,)  {}

更改 ReLU 的默认行为:

class LowerReluTracer(fx.Tracer):
    def is_leaf_module(self, m: nn.Module, qualname: str):
        if isinstance(m, torch.nn.ReLU):
            return False
        return super().is_leaf_module(m, qualname)
lower_relu_tracer = LowerReluTracer()
custom_traced_graph: fx.Graph = lower_relu_tracer.trace(M1())
custom_traced_graph.print_tabular()
opcode         name    target                             args     kwargs
-------------  ------  ---------------------------------  -------  ------------------
placeholder    x       x                                  ()       {}
call_function  relu    <function relu at 0x7f326b7685e0>  (x,)     {'inplace': False}
output         output  output                             (relu,)  {}

为每个节点添加额外的属性#

在这里,将重写 create_node,以便在创建每个 Node 时向其添加新属性

class M2(nn.Module):
    def forward(self, a, b):
        return a + b

class TaggingTracer(fx.Tracer):
    def create_node(self, kind : str, target:  str | Callable,
                    args: tuple[Any], kwargs: dict[str, Any], name: str | None=None,
                    type_expr: Any | None=None) -> fx.Node:
        n = super().create_node(kind, target, args, kwargs, name)
        n.tag = "foo"
        return n

custom_traced_graph: fx.Graph = TaggingTracer().trace(M2())

def assert_all_nodes_have_tags(g: fx.Graph) -> bool:
    for n in g.nodes:
        if not hasattr(n, "tag") or not n.tag == "foo":
            return False
    return True

print(assert_all_nodes_have_tags(custom_traced_graph))
True

内联函数到现有的 Graph#

您可能希望内联函数的原因是避开 FX 的默认跟踪行为。例如,除非您已经定义了自定义跟踪器,否则 symbolic_trace 的开箱即用实现将导致引用 nn 模块实例的显式 call_module 调用,而不是被跟踪。假设这种行为几乎是你所需要的;唯一的问题是,您希望用函数的内联跟踪来替换单个模块调用。创建自定义跟踪器的工作量太大了。相反,您可以使用 代理 来完成此任务。

下面的代码演示了如何使用 Proxy 跟踪模块并将其内联到现有的 Graph 中。我们将跟踪 Graph,然后遍历它的节点,直到找到用内联跟踪替换 call_module 节点的正确位置。在这一点上,我们将从节点的 argskwargs 创建代理。最后,我们将调用要用那些代理替换的函数——从本质上讲,这将“跟踪”该函数。最后,我们将把调用的结果插入到我们的 Graph 中。(最后一步将自动内联函数。)

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(x) + 1.0

符号跟踪 M 实例。跟踪后, self.relu 被表示为 call_module 节点。

m = fx.symbolic_trace(M())

ReLU graph 中插入节点,取代原来的调用 self.relu.

创建指向原始 Graph 的图附加跟踪程序

tracer = fx.proxy.GraphAppendingTracer(m.graph)
for node in m.graph.nodes:
    # 查找 `m` 中 `call_module` 节点对应的 `self.relu`。
    # 这是我们想要替换为相同调用的内联版本的节点。
    if (node.op, node.target) == ("call_module", "relu"):
        with m.graph.inserting_before(node):
            # 从每个节点当前的 args/kwargs 中的创建代理
            proxy_args = fx.map_arg(node.args, lambda n: fx.Proxy(n, tracer))
            proxy_kwargs = fx.map_arg(node.kwargs, lambda n: fx.Proxy(n, tracer))
            # 使用 新创建的 Proxy 参数回调 `m.relu`
            # `m.relu` 函数的通用版本;
            # 通过从 `m` 中的节点创建的代理调用它,我们发出的节点引用 IR 中的现有值。
            # 这个调用的结果是另一个 Proxy,我们可以将它挂钩到现有的 Graph 中,以完成函数内联。
            proxy_output = m.relu(*proxy_args, **proxy_kwargs)
            # 用函数的内联版本替换 relu `call_module` 节点
            node.replace_all_uses_with(proxy_output.node)
            # 确保旧的 relu 节点被擦除
            m.graph.erase_node(node)

FX 计算 反函数#

import torch
from torch import fx

逆映射是接受函数 f(x) 并返回函数 g 使 f(g(x)) == x 的映射。例如,由于 log(exp(x)) == x,所以 explog 是逆映射。

invert_mapping = {}

def add_inverse(a, b):
    invert_mapping[a] = b
    invert_mapping[b] = a
    
inverses = [
    (torch.sin, torch.arcsin),
    (torch.cos, torch.arccos),
    (torch.tan, torch.arctan),
    (torch.exp, torch.log),
]
for a, b in inverses:
    add_inverse(a, b)

一般的策略是 backward walk graph,将每个节点变换为它的逆(inverse)节点。

为此,我们交换函数的输出和输入,然后在 invert_mapping 中查找它的逆函数。注意,此变换假设所有运算只接受一个输入并返回一个输出。

def invert(model: nn.Module) -> nn.Module:
    fx_model = fx.symbolic_trace(model)
    new_graph = fx.Graph()  # 建立新的 graph
    env = {}
    for node in reversed(fx_model.graph.nodes):
        if node.op == 'call_function':
            # 在新 graph 中创建具有逆函数的节点,
            # 并传递 `env[node.name]` (即之前的输出节点) 作为输入。
            new_node = new_graph.call_function(invert_mapping[node.target], 
                                               (env[node.name],))
            env[node.args[0].name] = new_node
        elif node.op == 'output':
            # 将 output 转换为输入 placeholder
            new_node = new_graph.placeholder(node.name)
            env[node.args[0].name] = new_node
        elif node.op == 'placeholder':
            # 将输入 placeholder 转换为 output
            new_graph.output(env[node.name])
        else:
            raise RuntimeError("Not implemented")

    new_graph.lint()
    return fx.GraphModule(fx_model, new_graph)


def f(x):
    return torch.exp(torch.tan(x))

res = invert(f)
print(res.code)
print(f(res((torch.arange(5) + 1))))
def forward(self, output):
    log = torch.log(output);  output = None
    arctan = torch.arctan(log);  log = None
    return arctan
    
tensor([1., 2., 3., 4., 5.])