解释器模式#
FX 中一个有用的代码组织模式是循环遍历 Graph
中的所有 Node
并执行它们。这可以用于一些事情,包括对流经 Graph
的值的运行时分析,或者通过使用 Proxy
进行重跟踪的代码变换。
实例#
假设想要交换 torch.sigmoid()
,torch.neg()
运算顺序(包括它们的 Tensor
方法等量物)。可以像这样子类化 Interpreter
:
import warnings
warnings.filterwarnings("ignore", category=UserWarning) # 忽略用户警告
from typing import Any
import torch
from torch import nn, fx
class NegSigmSwapInterpreter(fx.Interpreter):
def call_function(self, target: fx.node.Target,
args: tuple, kwargs: dict) -> Any:
if target == torch.sigmoid:
return torch.neg(*args, **kwargs)
return super().call_function(target, args, kwargs)
def call_method(self, target: fx.node.Target,
args: tuple, kwargs: dict) -> Any:
if target == 'neg':
call_self, *args_tail = args
return call_self.sigmoid(*args_tail, **kwargs)
return super().call_function(target, args, kwargs)
def fn(x):
return torch.sigmoid(x).neg()
gm = fx.symbolic_trace(fn)
inputs = torch.randn(3, 4)
result = NegSigmSwapInterpreter(gm).run(inputs)
torch.testing.assert_close(result,
torch.neg(inputs).sigmoid())
除了执行运算之外,还可以通过解释器提供 Proxy
值来生成新的 Graph
。
FX Transformer
#
类似地,提供 Transformer
类(一种特殊类型的 Interpreter
)来包含此模式。Transformer
的行为类似于 Interpreter
,但不是调用 run
方法从模块中获取具体的输出值,而是调用 transform()
方法来返回新的 GraphModule
,它服从于作为覆盖方法安装的任何变换规则。
class NegSigmSwapXformer(fx.Transformer):
def call_function(self, target: 'Target',
args: tuple[fx.node.Argument, ...],
kwargs: dict[str, Any]) -> Any:
if target == torch.sigmoid:
return torch.neg(*args, **kwargs)
return super().call_function(n)
def call_method(self, target: 'Target',
args: tuple[fx.node.Argument, ...],
kwargs: dict[str, Any]) -> Any:
if target == 'neg':
call_self, *args_tail = args
return call_self.sigmoid(*args_tail, **kwargs)
return super().call_method(n)
def fn(x):
return torch.sigmoid(x).neg()
gm = fx.symbolic_trace(fn)
transformed: nn.Module = NegSigmSwapXformer(gm).transform()
inputs = torch.randn(3, 4)
torch.testing.assert_close(transformed(inputs),
torch.neg(inputs).sigmoid())
Shape 传播#
例如,假设想要运行 GraphModule
并记录 Tensor
shape 和节点上的 dtype 属性,就像我们在运行时看到的那样。
Shape 传播。这个类接受 GraphModule
。然后,使用给定的参数逐个节点地执行 GraphModule
的 propagate
方法。当每个运算执行时,ShapeProp
类存储每个运算的输出值 Node
的属性 shape
和 dtype
。
正如您所看到的,完整的 FX 解释器(interpreter)并不复杂,但它可能非常有用。为了方便使用这种模式,提供了 Interpreter
类,它以一种可以通过方法重写来重写解释器执行的某些方面的方式包含了上述逻辑。
from torch.fx.passes.shape_prop import ShapeProp
class TwoLayerNet(torch.nn.Module):
def __init__(self, D_in, H, D_out):
super().__init__()
self.linear1 = torch.nn.Linear(D_in, H)
self.linear2 = torch.nn.Linear(H, D_out)
def forward(self, x):
h_relu = self.linear1(x).clamp(min=0)
y_pred = self.linear2(h_relu)
return y_pred
N, D_in, H, D_out = 64, 1000, 100, 10
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)
model = TwoLayerNet(D_in, H, D_out)
gm = torch.fx.symbolic_trace(model)
sample_input = torch.randn(50, D_in)
ShapeProp(gm).propagate(sample_input)
for node in gm.graph.nodes:
print(node.name, node.meta['tensor_meta'].dtype,
node.meta['tensor_meta'].shape)
x torch.float32 torch.Size([50, 1000])
linear1 torch.float32 torch.Size([50, 100])
clamp torch.float32 torch.Size([50, 100])
linear2 torch.float32 torch.Size([50, 10])
output torch.float32 torch.Size([50, 10])