动态包装计算图输出

动态包装计算图输出#

下面的代码演示了如何根据运行时指定的参数更改现有的 Graph。我们将让用户从预定义的 Enum 列表中指定激活函数,然后对其进行符号跟踪。接下来,我们将从图中的最后一个运算创建 Proxy。我们将使用这个代理调用跟踪的激活函数,并将调用中的 output 节点插入到我们的图中。(最后一步将自动内联整个跟踪函数。)

from enum import Enum, auto
import torch
from torch import fx, nn

class M(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        y = torch.cat([x, y])
        return y

# 符号追踪 `M` 实例
traced = fx.symbolic_trace(M())

选择激活函数:

class ActivationFunction(Enum):
    RELU = auto()
    LEAKY_RELU = auto()
    PRELU = auto()

将激活函数名称映射到它们的实现:

activation_functions = {
    ActivationFunction.RELU: nn.ReLU(),
    ActivationFunction.LEAKY_RELU: nn.LeakyReLU(),
    ActivationFunction.PRELU: nn.PReLU(),
}

def wrap_in_activation_function(m: fx.GraphModule, 
                                fn: ActivationFunction) -> fx.GraphModule:
    # 获取输出节点
    output_node: fx.Node|None = None
    for n in reversed(m.graph.nodes):
        if n.op == "output":
            output_node = n
            break
    assert output_node
    
    # 获取实际输出(输出节点的 "input")。
    # 我们想要包装在用户指定的激活函数中的节点
    assert len(output_node.all_input_nodes) == 1
    wrap_node = output_node.all_input_nodes[0]

    # 在 Proxy 中包装实际的输出
    wrap_proxy = fx.Proxy(wrap_node)
    
    # 获取指定激活函数的实现并以符号方式跟踪它
    fn_impl = activation_functions[fn]
    fn_impl_traced = fx.symbolic_trace(fn_impl)
    
    # 使用 `output_op` 的代理包装器调用指定的激活函数。
    # 这个调用的结果是另一个 Proxy,我们可以将它钩到现有的 Graph 中。
    with m.graph.inserting_after(wrap_node):
        fn_impl_output_node = fn_impl_traced(wrap_proxy)
        new_args = (fn_impl_output_node.node,)
        output_node.args = new_args

    m.recompile()

测试:

x, y = torch.randn(5, 3), torch.randn(5, 3)
orig_output = traced(x, y)

wrap_in_activation_function(traced, 
                            ActivationFunction.LEAKY_RELU)
new_output = traced(x, y)
torch.testing.assert_close(new_output, 
                           torch.nn.LeakyReLU()(orig_output))