Proxy/Retracing#

另一种操作 Graph 的方法是重用符号跟踪中使用的 Proxy 机制。例如,假设想要编写一个变换,将 PyTorch 函数分解为更小的运算。它将把每个 F.relu(x) 调用变换为 (x > 0) * x。一种可能是执行必要的 graph 重写,在 F.relu 之后插入比较和乘法,然后清理原来的 F.relu。但是,可以通过使用 Proxy 对象自动地将运算记录到 Graph 中来自动化这个过程。

要使用此方法,将希望插入的运算编写为常规 PyTorch 代码,并使用 Proxy 对象作为参数调用该代码。这些代理对象将捕获对它们执行的运算,并将它们附加到 Graph 中。

import torch
from torch import fx, nn
from torch.nn import functional as F

# 注意,这个分解(decomposition)规则可以理解为普通的 Python
def relu_decomposition(x):
    return (x > 0) * x

decomposition_rules = {F.relu: relu_decomposition}

def decompose(model: nn.Module,
              tracer_class : type = fx.Tracer) -> nn.Module:
    """
    将 `model` 分解为更小的复合运算。
    目前,它只支持将 ReLU 分解为它的数学定义:(x > 0) * x
    """
    graph : fx.Graph = tracer_class().trace(model)
    new_graph = fx.Graph()
    env = {}
    tracer = fx.proxy.GraphAppendingTracer(graph)
    for node in graph.nodes:
        if node.op == 'call_function' and node.target in decomposition_rules:
            # 通过使用代理包装参数,可以分派到适当的分解规则,
            # 并通过符号跟踪隐式地将其添加到 Graph 中。
            proxy_args = [fx.Proxy(env[x.name], tracer) 
                          if isinstance(x, fx.Node) else x for x in node.args]
            output_proxy = decomposition_rules[node.target](*proxy_args)
            
            # 对 `Proxy` 的运算总是产生新的 `Proxy`,分解规则的返回值也不例外。
            # 需要从 `Proxy` 中提取底层的 `Node`,以便在此变换的后续迭代中使用它。
            new_node = output_proxy.node
            env[node.name] = new_node
        else:
            # 默认情况:没有此节点的分解规则,所以只需要将它复制到新的 Graph 中。
            new_node = new_graph.node_copy(node, lambda x: env[x.name])
            env[node.name] = new_node
    return fx.GraphModule(model, new_graph)

除了避免显式的 Graph 操作之外,使用 Proxy 还允许将重写规则指定为原生 Python 代码。对于需要大量重写规则的变换(如 vmap 或 grad),这通常可以提高规则的可读性和可维护性。注意,在调用 Proxy 时,还传递了指向底层变量 graph 的跟踪器。如果 graph 中的操作是 n-ary 的(例如 add 是二进制算子),那么调用 Proxy 不会创建 graph 跟踪器的多个实例,这会导致意外的运行时错误。推荐这种使用 Proxy 的方法,特别是当底层算子不能被安全地假定为一元的时候。

如何使用代理对象创建计算图#

可以直接在原始节点周围创建代理对象。这可用于创建独立于符号跟踪的 Graph

下面的代码演示了如何使用带有原始节点的代理将运算附加到新 Graph。将创建两个参数( xy ),对这些参数执行一些运算,然后将创建的所有内容添加到新的 Graph中。然后将把这个 Graph 包装到 GraphModule 中。这样做会创建 Module 的可运行实例。

创建独立于符号跟踪的计算图

graph = fx.Graph()
tracer = fx.proxy.GraphAppendingTracer(graph)

创建输入节点:

raw1 = graph.placeholder('x')
raw2 = graph.placeholder('y')

使用原始节点和图的默认跟踪器初始化代理

x = fx.Proxy(raw1, tracer)
y = fx.Proxy(raw2, tracer)

生成其他运算:

a = torch.cat([x, y])
b = torch.tanh(a)
c = torch.neg(b)
z = torch.add(b, c)

创建新的输出节点并将其添加到图中。通过这样做,图将包含刚刚创建的所有节点(因为它们都链接到输出节点).

graph.output(c.node)
output

将创建的图包装到 GraphModule 中,以获得最终的、可运行的 Module 的实例

mod = fx.GraphModule(nn.Module(), graph)
mod.graph.print_tabular()
opcode         name    target                                                   args         kwargs
-------------  ------  -------------------------------------------------------  -----------  --------
placeholder    x       x                                                        ()           {}
placeholder    y       y                                                        ()           {}
call_function  cat     <built-in method cat of type object at 0x7f8e82242200>   ([x, y],)    {}
call_function  tanh    <built-in method tanh of type object at 0x7f8e82242200>  (cat,)       {}
call_function  neg     <built-in method neg of type object at 0x7f8e82242200>   (tanh,)      {}
call_function  add     <built-in method add of type object at 0x7f8e82242200>   (tanh, neg)  {}
output         output  output                                                   (neg,)       {}