追踪 Graph
#
构建新 Graph
的一种方法是直接操控旧图。为了帮助实现这一点,可以简单地从符号跟踪中获取 Graph
并对其进行修改。例如,假设希望用 torch.mul()
调用替换 torch.add()
调用。
import torch
from torch import fx, nn
# 样例模块
class M(nn.Module):
def forward(self, x, y):
return torch.add(x, y)
查看节点信息:
m = M()
gm: fx.GraphModule = fx.symbolic_trace(m)
for node in gm.graph.nodes:
print(node, node.op, node.target)
x placeholder x
y placeholder y
add call_function <built-in method add of type object at 0x7fd8cd04a200>
output output output
tracer = fx.Tracer()
graph: fx.Graph = tracer.trace(m)
# FX 将其 Graph 表示为节点的有序列表,因此可以遍历它们。
for node in graph.nodes:
# 检查是否正在调用函数(例如:torch.add)
if node.op == 'call_function':
# target 属性是 call_function 调用的函数。
if node.target == torch.add:
node.target = torch.mul
graph.lint() # 做一些检查,以确保 Graph 是格式良好的。
gm = fx.GraphModule(m, graph)
或者可以这样:
m = M()
traced: fx.GraphModule = fx.symbolic_trace(m)
for node in traced.graph.nodes:
if node.op == 'call_function':
# target 属性是 call_function 调用的函数。
if node.target == torch.add:
node.target = torch.mul
traced.graph.lint() # 做一些检查,以确保 Graph 是格式良好的。
traced.recompile()
traced.graph.print_tabular()
opcode name target args kwargs
------------- ------ ------------------------------------------------------ ------ --------
placeholder x x () {}
placeholder y y () {}
call_function add <built-in method mul of type object at 0x7fd8cd04a200> (x, y) {}
output output output (add,) {}
简单的验证:
x = torch.tensor([2])
y = torch.tensor([3])
m(x, y)
tensor([5])
traced(x, y)
tensor([6])
还可以进行更复杂的 Graph
重写,比如删除或追加节点。为了帮助完成这些变换,FX 提供了变换 Graph
的实用函数。下面是使用这些 API 附加 relu()
调用的示例。
def inserting_after(node, new_node=torch.relu):
"""指定插入点,并在此范围内添加到 Graph 中的任何节点都将插入到 `node` 之后"""
with traced.graph.inserting_after(node):
# 插入新的 `call_function` 节点调用 `torch.relu``
new_node = traced.graph.call_function(new_node, args=(node,))
# 希望所有使用 `node` 值的节点后添加 `relu` 回调
# 使用 `replace_all_uses_with` API 来做到这一点。
node.replace_all_uses_with(new_node)