动态包装计算图输出#
下面的代码演示了如何根据运行时指定的参数更改现有的 Graph
。我们将让用户从预定义的 Enum 列表中指定激活函数,然后对其进行符号跟踪。接下来,我们将从图中的最后一个运算创建 Proxy
。我们将使用这个代理调用跟踪的激活函数,并将调用中的 output
节点插入到我们的图中。(最后一步将自动内联整个跟踪函数。)
from enum import Enum, auto
import torch
from torch import fx, nn
class M(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, y):
y = torch.cat([x, y])
return y
# 符号追踪 `M` 实例
traced = fx.symbolic_trace(M())
选择激活函数:
class ActivationFunction(Enum):
RELU = auto()
LEAKY_RELU = auto()
PRELU = auto()
将激活函数名称映射到它们的实现:
activation_functions = {
ActivationFunction.RELU: nn.ReLU(),
ActivationFunction.LEAKY_RELU: nn.LeakyReLU(),
ActivationFunction.PRELU: nn.PReLU(),
}
def wrap_in_activation_function(m: fx.GraphModule,
fn: ActivationFunction) -> fx.GraphModule:
# 获取输出节点
output_node: fx.Node|None = None
for n in reversed(m.graph.nodes):
if n.op == "output":
output_node = n
break
assert output_node
# 获取实际输出(输出节点的 "input")。
# 我们想要包装在用户指定的激活函数中的节点
assert len(output_node.all_input_nodes) == 1
wrap_node = output_node.all_input_nodes[0]
# 在 Proxy 中包装实际的输出
wrap_proxy = fx.Proxy(wrap_node)
# 获取指定激活函数的实现并以符号方式跟踪它
fn_impl = activation_functions[fn]
fn_impl_traced = fx.symbolic_trace(fn_impl)
# 使用 `output_op` 的代理包装器调用指定的激活函数。
# 这个调用的结果是另一个 Proxy,我们可以将它钩到现有的 Graph 中。
with m.graph.inserting_after(wrap_node):
fn_impl_output_node = fn_impl_traced(wrap_proxy)
new_args = (fn_impl_output_node.node,)
output_node.args = new_args
m.recompile()
测试:
x, y = torch.randn(5, 3), torch.randn(5, 3)
orig_output = traced(x, y)
wrap_in_activation_function(traced,
ActivationFunction.LEAKY_RELU)
new_output = traced(x, y)
torch.testing.assert_close(new_output,
torch.nn.LeakyReLU()(orig_output))