Graph 简介

Graph 简介#

Graph 的语义可以在 Graph 文档中找到完整的处理方法,但是在这里只介绍基础知识。Graph 是一个数据结构,表示 GraphModule 上的方法。这需要的信息是:

  • 此方法的输入是什么?

  • 此方法当中执行了哪些运算?

  • 此方法的输出是什么?

这三个概念都用 Node 实例表示。

用简短的例子来看看这是什么意思:

import torch
from torch import fx, nn
import warnings
warnings.filterwarnings("ignore", category=UserWarning) # 忽略用户警告

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.param = nn.Parameter(torch.rand(3, 4))
        self.linear = nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), 
                                    dim=-1), 3)

m = MyModule()
gm = fx.symbolic_trace(m)
gm.graph.print_tabular()
opcode         name           target                                                   args                kwargs
-------------  -------------  -------------------------------------------------------  ------------------  -----------
placeholder    x              x                                                        ()                  {}
get_attr       linear_weight  linear.weight                                            ()                  {}
call_function  add            <built-in function add>                                  (x, linear_weight)  {}
call_module    linear         linear                                                   (add,)              {}
call_method    relu           relu                                                     (linear,)           {}
call_function  sum_1          <built-in method sum of type object at 0x7f0f55399aa0>   (relu,)             {'dim': -1}
call_function  topk           <built-in method topk of type object at 0x7f0f55399aa0>  (sum_1, 3)          {}
output         output         output                                                   (topk,)             {}
print(gm.graph)
graph():
    %x : [num_users=1] = placeholder[target=x]
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %add : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {})
    %linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    %relu : [num_users=1] = call_method[target=relu](args = (%linear,), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu,), kwargs = {dim: -1})
    %topk : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {})
    return topk

这里定义了模块 MyModule,用于演示,实例化它,象征性地跟踪它,然后调用 print_tabular() 方法打印出一个表,显示这个图的节点。

可以使用这些信息来回答上面提出的问题。

上述表格足以回答我们的三个问题:

  1. 这个方法的输入是什么?在 FX 中, 方法输入被表示为 placeholder 节点。在我们的例子中,只有一个 placeholder,可以推断出来我们的 forward 的函数除了首参数 self 外只有一个额外的输入(即 x)。

  2. 这个方法当中执行了哪些运算?我们可以看到 get_attrcall_funcationcall_module 等节点表示了方法中的运算。

  3. 这个方法的输出是什么?我们使用特别的 output 来表示 Graph 的输出。

现在知道了方法是如何在 torch.fx 中被记录表示的, 下一步便是通过 Graph 修改它。

备注

Node 是表示 Graph 中各个运算的数据结构。在大多数情况下,Node 表示对各种实体的调用点,如算子、方法和模块(一些例外包括指定函数输入和输出的 Node)。每个 Node 都有一个由 op 属性指定的函数。