追踪 Graph

追踪 Graph#

构建新 Graph 的一种方法是直接操控旧图。为了帮助实现这一点,可以简单地从符号跟踪中获取 Graph 并对其进行修改。例如,假设希望用 torch.mul() 调用替换 torch.add() 调用。

import torch
from torch import fx, nn

# 样例模块
class M(nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)

查看节点信息:

m = M()
gm: fx.GraphModule = fx.symbolic_trace(m)
for node in gm.graph.nodes:
    print(node, node.op, node.target)
x placeholder x
y placeholder y
add call_function <built-in method add of type object at 0x7fd8cd04a200>
output output output
tracer = fx.Tracer()
graph: fx.Graph = tracer.trace(m)
# FX 将其 Graph 表示为节点的有序列表,因此可以遍历它们。
for node in graph.nodes:
    # 检查是否正在调用函数(例如:torch.add)
    if node.op == 'call_function':
        # target 属性是 call_function 调用的函数。
        if node.target == torch.add:
            node.target = torch.mul
graph.lint() # 做一些检查,以确保 Graph 是格式良好的。
gm = fx.GraphModule(m, graph)

或者可以这样:

m = M()
traced: fx.GraphModule = fx.symbolic_trace(m)
for node in traced.graph.nodes:
    if node.op == 'call_function':
        # target 属性是 call_function 调用的函数。
        if node.target == torch.add:
            node.target = torch.mul
traced.graph.lint() # 做一些检查,以确保 Graph 是格式良好的。
traced.recompile()
traced.graph.print_tabular()
opcode         name    target                                                  args    kwargs
-------------  ------  ------------------------------------------------------  ------  --------
placeholder    x       x                                                       ()      {}
placeholder    y       y                                                       ()      {}
call_function  add     <built-in method mul of type object at 0x7fd8cd04a200>  (x, y)  {}
output         output  output                                                  (add,)  {}

简单的验证:

x = torch.tensor([2])
y = torch.tensor([3])
m(x, y)
tensor([5])
traced(x, y)
tensor([6])

还可以进行更复杂的 Graph 重写,比如删除或追加节点。为了帮助完成这些变换,FX 提供了变换 Graph 的实用函数。下面是使用这些 API 附加 relu() 调用的示例。

def inserting_after(node, new_node=torch.relu):
    """指定插入点,并在此范围内添加到 Graph 中的任何节点都将插入到 `node` 之后"""
    with traced.graph.inserting_after(node):
        # 插入新的 `call_function` 节点调用 `torch.relu``
        new_node = traced.graph.call_function(new_node, args=(node,))
         
        # 希望所有使用 `node` 值的节点后添加 `relu` 回调
        # 使用 `replace_all_uses_with` API 来做到这一点。
        node.replace_all_uses_with(new_node)