解释器模式

解释器模式#

FX 中一个有用的代码组织模式是循环遍历 Graph 中的所有 Node 并执行它们。这可以用于一些事情,包括对流经 Graph 的值的运行时分析,或者通过使用 Proxy 进行重跟踪的代码变换。

实例#

假设想要交换 torch.sigmoid()torch.neg() 运算顺序(包括它们的 Tensor 方法等量物)。可以像这样子类化 Interpreter

import warnings
warnings.filterwarnings("ignore", category=UserWarning) # 忽略用户警告
from typing import Any
import torch
from torch import nn, fx


class NegSigmSwapInterpreter(fx.Interpreter):
    def call_function(self, target: fx.node.Target,
                      args: tuple, kwargs: dict) -> Any:
        if target == torch.sigmoid:
            return torch.neg(*args, **kwargs)
        return super().call_function(target, args, kwargs)

    def call_method(self, target: fx.node.Target,
                    args: tuple, kwargs: dict) -> Any:
        if target == 'neg':
            call_self, *args_tail = args
            return call_self.sigmoid(*args_tail, **kwargs)
        return super().call_function(target, args, kwargs)

def fn(x):
    return torch.sigmoid(x).neg()

gm = fx.symbolic_trace(fn)
inputs = torch.randn(3, 4)
result = NegSigmSwapInterpreter(gm).run(inputs)
torch.testing.assert_close(result, 
                           torch.neg(inputs).sigmoid())

除了执行运算之外,还可以通过解释器提供 Proxy 值来生成新的 Graph

FX Transformer#

类似地,提供 Transformer 类(一种特殊类型的 Interpreter)来包含此模式。Transformer 的行为类似于 Interpreter,但不是调用 run 方法从模块中获取具体的输出值,而是调用 transform() 方法来返回新的 GraphModule,它服从于作为覆盖方法安装的任何变换规则。

class NegSigmSwapXformer(fx.Transformer):
    def call_function(self, target: 'Target', 
                      args: tuple[fx.node.Argument, ...], 
                      kwargs: dict[str, Any]) -> Any:
        if target == torch.sigmoid:
            return torch.neg(*args, **kwargs)
        return super().call_function(n)

    def call_method(self, target: 'Target', 
                    args: tuple[fx.node.Argument, ...], 
                    kwargs: dict[str, Any]) -> Any:
        if target == 'neg':
            call_self, *args_tail = args
            return call_self.sigmoid(*args_tail, **kwargs)
        return super().call_method(n)

def fn(x):
    return torch.sigmoid(x).neg()

gm = fx.symbolic_trace(fn)

transformed: nn.Module = NegSigmSwapXformer(gm).transform()
inputs = torch.randn(3, 4)
torch.testing.assert_close(transformed(inputs), 
                           torch.neg(inputs).sigmoid())

Shape 传播#

例如,假设想要运行 GraphModule 并记录 Tensor shape 和节点上的 dtype 属性,就像我们在运行时看到的那样。

Shape 传播。这个类接受 GraphModule。然后,使用给定的参数逐个节点地执行 GraphModulepropagate 方法。当每个运算执行时,ShapeProp 类存储每个运算的输出值 Node 的属性 shapedtype

正如您所看到的,完整的 FX 解释器(interpreter)并不复杂,但它可能非常有用。为了方便使用这种模式,提供了 Interpreter 类,它以一种可以通过方法重写来重写解释器执行的某些方面的方式包含了上述逻辑。

from torch.fx.passes.shape_prop import ShapeProp
class TwoLayerNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super().__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)
    def forward(self, x):
        h_relu = self.linear1(x).clamp(min=0)
        y_pred = self.linear2(h_relu)
        return y_pred
N, D_in, H, D_out = 64, 1000, 100, 10
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)
model = TwoLayerNet(D_in, H, D_out)
gm = torch.fx.symbolic_trace(model)
sample_input = torch.randn(50, D_in)
ShapeProp(gm).propagate(sample_input)

for node in gm.graph.nodes:
    print(node.name, node.meta['tensor_meta'].dtype,
        node.meta['tensor_meta'].shape)
x torch.float32 torch.Size([50, 1000])
linear1 torch.float32 torch.Size([50, 100])
clamp torch.float32 torch.Size([50, 100])
linear2 torch.float32 torch.Size([50, 10])
output torch.float32 torch.Size([50, 10])