FX 算子替换#
遍历
GraphModule
的Graph
中的所有Node
。确定是否应该替换当前
Node
(建议:匹配节点的target
属性)。创建替换
Node
并将其添加到Graph
中。使用 FX 内置的
replace_all_uses_with()
替换当前Node
的所有使用。从
Graph
中删除旧Node
。在
GraphModule
上调用recompile
。这会更新生成的 Python 代码,以反射(reflect)新的 Graph 状态。
下面的代码演示了用按位 AND 替换任意加法实例的示例。
要检查 Graph
在运算替换期间的演变情况,可以在要检查的行之后添加语句 print(traced.graph)
。
或者,调用 traced.graph.print_tabular()
以查看表格格式的 IR。
import torch
from torch import fx
import operator
# module 样例
class M(torch.nn.Module):
def forward(self, x, y):
return x + y, torch.add(x, y), x.add(y)
以符号方式跟踪模块的实例:
traced = fx.symbolic_trace(M())
有几种不同的表示加法的方法:
patterns = set([operator.add, torch.add, "add"])
# 遍历 Graph 中全部节点
for n in traced.graph.nodes:
# 如果目标匹配其中一个模式
if any(n.target == pattern for pattern in patterns):
# 设置插入点,添加新节点,用新节点替换所有 `n` 的用法
with traced.graph.inserting_after(n):
new_node = traced.graph.call_function(torch.bitwise_and, n.args, n.kwargs)
n.replace_all_uses_with(new_node)
# 移除 graph 中旧的节点
traced.graph.erase_node(n)
# 不用忘记 recompile!
new_code = traced.recompile()
print(new_code.src)
def forward(self, x, y):
bitwise_and = torch.bitwise_and(x, y)
bitwise_and_1 = torch.bitwise_and(x, y)
bitwise_and_2 = torch.bitwise_and(x, y); x = y = None
return (bitwise_and, bitwise_and_1, bitwise_and_2)