FX Flow Control
First, define an abstract base class whose forward contains control flow:
from abc import ABC, abstractmethod

import torch
from torch import nn, fx


class MyModuleBase(nn.Module, ABC):
    def forward(self, x):
        matrix = self.get_mul_matrix()
        if self.no_relu():
            return torch.mm(x, matrix)
        else:
            return torch.relu(torch.mm(x, matrix))

    def get_mul_matrix(self):
        return self.param

    @abstractmethod
    def no_relu(self):
        ...
Define a simple condition on the parameter's shape:
class MyModuleParamShape(MyModuleBase):
    def __init__(self, in_channels):
        super().__init__()
        self.param = nn.Parameter(torch.randn(in_channels, 3))

    def no_relu(self):
        return self.param.shape[0] < 10
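Control flow that depends on a parameter's shape is exactly what the default tracer rejects: inside no_relu(), self.param.shape[0] < 10 would be evaluated on a Proxy, and using a Proxy as a branch condition raises a TraceError. A minimal sketch of the failure mode (the message in the comment is paraphrased):

try:
    fx.symbolic_trace(MyModuleParamShape(in_channels=5))
except fx.proxy.TraceError as e:
    # "symbolically traced variables cannot be used as inputs to control flow"
    print(e)

Passing param_shapes_constant=True to fx.Tracer tells FX to evaluate shape-like attribute accesses on parameters eagerly, so the branch is resolved at trace time.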
Instantiate the module so that each instance takes a different branch:
mm_only_mod = MyModuleParamShape(in_channels=5)
relu_mod = MyModuleParamShape(in_channels=15)
We want to verify that one module performs only the mm op, while the other performs mm and relu in a cascade. First, check that mm_only_mod computes just the matrix multiplication:
x = torch.randn(10, 5)
torch.testing.assert_close(mm_only_mod(x),
                           torch.mm(x, mm_only_mod.get_mul_matrix()))
Trace with param_shapes_constant=True and verify that the resulting GraphModule computes the same result:
tracer = fx.Tracer(param_shapes_constant=True)
traced_graph = tracer.trace(mm_only_mod)
graph_mod_mm = fx.GraphModule(mm_only_mod, traced_graph)
torch.testing.assert_close(graph_mod_mm(x),
                           torch.mm(x, mm_only_mod.get_mul_matrix()))
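Because no_relu() was evaluated as a constant at trace time, the branch is baked into the generated code. A quick way to confirm this is to print the module's code:

# The generated forward should contain a single torch.mm call and no relu,
# since the shape test was folded into a constant during tracing.
print(graph_mod_mm.code)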
Now use the module with the other parameter shape, which follows the other code path:
x = torch.randn(10, 15)
torch.testing.assert_close(relu_mod(x),
                           torch.relu(torch.mm(x, relu_mod.get_mul_matrix())))
tracer2 = fx.Tracer(param_shapes_constant=True)
traced_graph2 = tracer2.trace(relu_mod)
# Verify the GraphModule computes the same result
graph_mod_relu = fx.GraphModule(relu_mod, traced_graph2)
torch.testing.assert_close(graph_mod_relu(x),
                           torch.relu(torch.mm(x, relu_mod.get_mul_matrix())))
The second graph has an additional call_function node targeting relu:
graph1_node_targets = [n.target for n in traced_graph.nodes]
graph2_node_targets = [n.target for n in traced_graph2.nodes]
assert torch.mm in graph1_node_targets and torch.mm in graph2_node_targets
assert torch.relu not in graph1_node_targets and torch.relu in graph2_node_targets
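To see the difference directly, you can also walk the nodes of each traced graph; every node records its op kind and target:

for node in traced_graph2.nodes:
    # Expect a placeholder for x, a get_attr for param, call_function
    # nodes for torch.mm and torch.relu, and finally the output node.
    print(node.op, node.target)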
Put the verification steps above into a function so they can be reused:
def verify_mm_relu_mods(mm_only_mod, relu_mod):
    """
    Verify that one module performs only the `mm` op,
    while the other performs `mm` and `relu` in a cascade.
    """
    x = torch.randn(10, 5)
    torch.testing.assert_close(mm_only_mod(x),
                               torch.mm(x, mm_only_mod.get_mul_matrix()))
    tracer = fx.Tracer(param_shapes_constant=True)
    traced_graph = tracer.trace(mm_only_mod)
    # Verify the GraphModule computes the same result
    graph_mod_mm = fx.GraphModule(mm_only_mod, traced_graph)
    torch.testing.assert_close(graph_mod_mm(x),
                               torch.mm(x, mm_only_mod.get_mul_matrix()))

    # The module with a different parameter shape follows the other code path
    x = torch.randn(10, 15)
    torch.testing.assert_close(relu_mod(x),
                               torch.relu(torch.mm(x, relu_mod.get_mul_matrix())))
    tracer2 = fx.Tracer(param_shapes_constant=True)
    traced_graph2 = tracer2.trace(relu_mod)
    # Verify the GraphModule computes the same result
    graph_mod_relu = fx.GraphModule(relu_mod, traced_graph2)
    torch.testing.assert_close(graph_mod_relu(x),
                               torch.relu(torch.mm(x, relu_mod.get_mul_matrix())))

    graph1_node_targets = [n.target for n in traced_graph.nodes]
    graph2_node_targets = [n.target for n in traced_graph2.nodes]
    # The second graph has an additional `relu` call_function node
    assert torch.mm in graph1_node_targets and torch.mm in graph2_node_targets
    assert torch.relu not in graph1_node_targets and torch.relu in graph2_node_targets
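With the helper in place, the earlier pair of modules can be re-checked in one call:

verify_mm_relu_mods(mm_only_mod, relu_mod)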
param_shapes_constant=True folds not just .shape but the other shape-query accessors into constants as well: size(), dim(), ndim, numel(), and nelement(). The variants below drive the same branch through each of them:

class MyModuleParamSize(MyModuleBase):
    def __init__(self, in_channels):
        super().__init__()
        self.param = nn.Parameter(torch.randn(in_channels, 3))

    def no_relu(self):
        return self.param.size()[0] < 10


class MyModuleParamDim(MyModuleBase):
    def __init__(self, param):
        super().__init__()
        self.param = param

    def get_mul_matrix(self):
        # A 3-D parameter holds a batch of matrices; multiply by the first one
        return self.param[0] if (self.param.dim() == 3) else self.param

    def no_relu(self):
        return self.param.dim() == 3


class MyModuleParamNDim(MyModuleBase):
    def __init__(self, param):
        super().__init__()
        self.param = param

    def get_mul_matrix(self):
        return self.param[0] if (self.param.ndim == 3) else self.param

    def no_relu(self):
        return self.param.ndim == 3


class MyModuleParamNumEl(MyModuleBase):
    def __init__(self, in_channels):
        super().__init__()
        self.param = nn.Parameter(torch.randn(in_channels, 3))

    def no_relu(self):
        return self.param.numel() < 10 * 3


class MyModuleParamNElement(MyModuleBase):
    def __init__(self, in_channels):
        super().__init__()
        self.param = nn.Parameter(torch.randn(in_channels, 3))

    def no_relu(self):
        return self.param.nelement() < 10 * 3
Finally, exercise each variant with the helper:

def test_param_size_const():
    mymod = MyModuleParamSize(in_channels=5)
    mymod2 = MyModuleParamSize(in_channels=15)
    verify_mm_relu_mods(mymod, mymod2)


def test_param_dim_const():
    mymod = MyModuleParamDim(nn.Parameter(torch.randn(2, 5, 3)))
    mymod2 = MyModuleParamDim(nn.Parameter(torch.randn(15, 3)))
    verify_mm_relu_mods(mymod, mymod2)


def test_param_ndim_const():
    mymod = MyModuleParamNDim(nn.Parameter(torch.randn(2, 5, 3)))
    mymod2 = MyModuleParamNDim(nn.Parameter(torch.randn(15, 3)))
    verify_mm_relu_mods(mymod, mymod2)


def test_param_numel_const():
    mymod = MyModuleParamNumEl(in_channels=5)
    mymod2 = MyModuleParamNumEl(in_channels=15)
    verify_mm_relu_mods(mymod, mymod2)


def test_param_nelement_const():
    mymod = MyModuleParamNElement(in_channels=5)
    mymod2 = MyModuleParamNElement(in_channels=15)
    verify_mm_relu_mods(mymod, mymod2)
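A minimal runner, assuming all of the snippets above live in a single script:

if __name__ == "__main__":
    test_param_size_const()
    test_param_dim_const()
    test_param_ndim_const()
    test_param_numel_const()
    test_param_nelement_const()
    print("All control-flow tracing checks passed.")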