Proxy/Retracing#
另一种操作 Graph
的方法是重用符号跟踪中使用的 Proxy
机制。例如,假设想要编写一个变换,将 PyTorch 函数分解为更小的运算。它将把每个 F.relu(x)
调用变换为 (x > 0) * x
。一种可能是执行必要的 graph 重写,在 F.relu
之后插入比较和乘法,然后清理原来的 F.relu
。但是,可以通过使用 Proxy
对象自动地将运算记录到 Graph
中来自动化这个过程。
要使用此方法,将希望插入的运算编写为常规 PyTorch 代码,并使用 Proxy
对象作为参数调用该代码。这些代理对象将捕获对它们执行的运算,并将它们附加到 Graph
中。
import torch
from torch import fx, nn
from torch.nn import functional as F
# 注意,这个分解(decomposition)规则可以理解为普通的 Python
def relu_decomposition(x):
return (x > 0) * x
decomposition_rules = {F.relu: relu_decomposition}
def decompose(model: nn.Module,
tracer_class : type = fx.Tracer) -> nn.Module:
"""
将 `model` 分解为更小的复合运算。
目前,它只支持将 ReLU 分解为它的数学定义:(x > 0) * x
"""
graph : fx.Graph = tracer_class().trace(model)
new_graph = fx.Graph()
env = {}
tracer = fx.proxy.GraphAppendingTracer(graph)
for node in graph.nodes:
if node.op == 'call_function' and node.target in decomposition_rules:
# 通过使用代理包装参数,可以分派到适当的分解规则,
# 并通过符号跟踪隐式地将其添加到 Graph 中。
proxy_args = [fx.Proxy(env[x.name], tracer)
if isinstance(x, fx.Node) else x for x in node.args]
output_proxy = decomposition_rules[node.target](*proxy_args)
# 对 `Proxy` 的运算总是产生新的 `Proxy`,分解规则的返回值也不例外。
# 需要从 `Proxy` 中提取底层的 `Node`,以便在此变换的后续迭代中使用它。
new_node = output_proxy.node
env[node.name] = new_node
else:
# 默认情况:没有此节点的分解规则,所以只需要将它复制到新的 Graph 中。
new_node = new_graph.node_copy(node, lambda x: env[x.name])
env[node.name] = new_node
return fx.GraphModule(model, new_graph)
除了避免显式的 Graph
操作之外,使用 Proxy
还允许将重写规则指定为原生 Python 代码。对于需要大量重写规则的变换(如 vmap 或 grad),这通常可以提高规则的可读性和可维护性。注意,在调用 Proxy
时,还传递了指向底层变量 graph 的跟踪器。如果 graph 中的操作是 n-ary 的(例如 add 是二进制算子),那么调用 Proxy
不会创建 graph 跟踪器的多个实例,这会导致意外的运行时错误。推荐这种使用 Proxy
的方法,特别是当底层算子不能被安全地假定为一元的时候。
如何使用代理对象创建计算图#
可以直接在原始节点周围创建代理对象。这可用于创建独立于符号跟踪的 Graph
。
下面的代码演示了如何使用带有原始节点的代理将运算附加到新 Graph
。将创建两个参数( x
和 y
),对这些参数执行一些运算,然后将创建的所有内容添加到新的 Graph
中。然后将把这个 Graph
包装到 GraphModule
中。这样做会创建 Module
的可运行实例。
创建独立于符号跟踪的计算图
graph = fx.Graph()
tracer = fx.proxy.GraphAppendingTracer(graph)
创建输入节点:
raw1 = graph.placeholder('x')
raw2 = graph.placeholder('y')
使用原始节点和图的默认跟踪器初始化代理
x = fx.Proxy(raw1, tracer)
y = fx.Proxy(raw2, tracer)
生成其他运算:
a = torch.cat([x, y])
b = torch.tanh(a)
c = torch.neg(b)
z = torch.add(b, c)
创建新的输出节点并将其添加到图中。通过这样做,图将包含刚刚创建的所有节点(因为它们都链接到输出节点).
graph.output(c.node)
output
将创建的图包装到 GraphModule
中,以获得最终的、可运行的 Module
的实例
mod = fx.GraphModule(nn.Module(), graph)
mod.graph.print_tabular()
opcode name target args kwargs
------------- ------ ------------------------------------------------------- ----------- --------
placeholder x x () {}
placeholder y y () {}
call_function cat <built-in method cat of type object at 0x7f8e82242200> ([x, y],) {}
call_function tanh <built-in method tanh of type object at 0x7f8e82242200> (cat,) {}
call_function neg <built-in method neg of type object at 0x7f8e82242200> (tanh,) {}
call_function add <built-in method add of type object at 0x7f8e82242200> (tanh, neg) {}
output output output (neg,) {}