Graph
简介#
Graph
的语义可以在 Graph
文档中找到完整的处理方法,但是在这里只介绍基础知识。Graph
是一个数据结构,表示 GraphModule
上的方法。这需要的信息是:
此方法的输入是什么?
此方法当中执行了哪些运算?
此方法的输出是什么?
这三个概念都用 Node
实例表示。
用简短的例子来看看这是什么意思:
import torch
from torch import fx, nn
import warnings
warnings.filterwarnings("ignore", category=UserWarning) # 忽略用户警告
class MyModule(nn.Module):
def __init__(self):
super().__init__()
self.param = nn.Parameter(torch.rand(3, 4))
self.linear = nn.Linear(4, 5)
def forward(self, x):
return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(),
dim=-1), 3)
m = MyModule()
gm = fx.symbolic_trace(m)
gm.graph.print_tabular()
opcode name target args kwargs
------------- ------------- ------------------------------------------------------- ------------------ -----------
placeholder x x () {}
get_attr linear_weight linear.weight () {}
call_function add <built-in function add> (x, linear_weight) {}
call_module linear linear (add,) {}
call_method relu relu (linear,) {}
call_function sum_1 <built-in method sum of type object at 0x7f0f55399aa0> (relu,) {'dim': -1}
call_function topk <built-in method topk of type object at 0x7f0f55399aa0> (sum_1, 3) {}
output output output (topk,) {}
print(gm.graph)
graph():
%x : [num_users=1] = placeholder[target=x]
%linear_weight : [num_users=1] = get_attr[target=linear.weight]
%add : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {})
%linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {})
%relu : [num_users=1] = call_method[target=relu](args = (%linear,), kwargs = {})
%sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu,), kwargs = {dim: -1})
%topk : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {})
return topk
这里定义了模块 MyModule
,用于演示,实例化它,象征性地跟踪它,然后调用 print_tabular()
方法打印出一个表,显示这个图的节点。
可以使用这些信息来回答上面提出的问题。
上述表格足以回答我们的三个问题:
这个方法的输入是什么?在 FX 中, 方法输入被表示为
placeholder
节点。在我们的例子中,只有一个placeholder
,可以推断出来我们的forward
的函数除了首参数self
外只有一个额外的输入(即x
)。这个方法当中执行了哪些运算?我们可以看到
get_attr
、call_funcation
、call_module
等节点表示了方法中的运算。这个方法的输出是什么?我们使用特别的
output
来表示Graph
的输出。