Building a Simple CPU Performance Profiler with FX#
In this tutorial, we are going to use FX to do the following:
1. Capture PyTorch Python code in a way that lets us inspect and gather statistics about the structure and execution of the code.
2. Build out a small class that will serve as a simple performance "profiler", collecting runtime statistics about each part of the model from actual runs.
For this tutorial, we are going to use the torchvision ResNet18 model for demonstration purposes.
import statistics, tabulate, time
from typing import Any, Dict, List
import torch
from torch import fx
from torchvision import models
rn18 = models.resnet18()
rn18.eval()
ResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer2): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer3): Sequential(
(0): BasicBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
(fc): Linear(in_features=512, out_features=1000, bias=True)
)
We want to inspect its performance more deeply. That is to say, for the following invocation, which parts of the model take the longest?
input = torch.randn(5, 3, 224, 224)
output = rn18(input)
A common way of answering that question is to go through the program source, add code that collects timestamps at various points in the program, and compare the differences between those timestamps to see how long the regions between them take.
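To make that concrete, a hand-instrumented version of the first two submodules might look like the following (a purely illustrative sketch; conv1 and bn1 are the submodule names from the printout above):
t0 = time.time()
x = rn18.conv1(input)   # time the first convolution by hand
t1 = time.time()
x = rn18.bn1(x)         # then the first batch norm
t2 = time.time()
print(f'conv1: {t1 - t0:.6f}s, bn1: {t2 - t1:.6f}s')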
That technique is certainly applicable to PyTorch code, but it would be nicer if we did not have to copy over the model code and edit it, especially code we have not written (like this torchvision model). Instead, we are going to use FX to automate this "instrumentation" process without needing to modify any source code.
Note
tabulate is an external library and is not a dependency of PyTorch. We use it to more easily visualize the performance data. Please make sure you have installed it from your favorite Python package source (e.g. pip install tabulate).
Capturing the Model with Symbolic Tracing#
Next, we are going to use FX's symbolic tracing mechanism to capture the definition of our model in a data structure that we can manipulate and examine.
traced_rn18 = torch.fx.symbolic_trace(rn18)
print(traced_rn18.graph)
graph():
%x : torch.Tensor [#users=1] = placeholder[target=x]
%conv1 : [#users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
%bn1 : [#users=1] = call_module[target=bn1](args = (%conv1,), kwargs = {})
%relu : [#users=1] = call_module[target=relu](args = (%bn1,), kwargs = {})
%maxpool : [#users=2] = call_module[target=maxpool](args = (%relu,), kwargs = {})
%layer1_0_conv1 : [#users=1] = call_module[target=layer1.0.conv1](args = (%maxpool,), kwargs = {})
%layer1_0_bn1 : [#users=1] = call_module[target=layer1.0.bn1](args = (%layer1_0_conv1,), kwargs = {})
%layer1_0_relu : [#users=1] = call_module[target=layer1.0.relu](args = (%layer1_0_bn1,), kwargs = {})
%layer1_0_conv2 : [#users=1] = call_module[target=layer1.0.conv2](args = (%layer1_0_relu,), kwargs = {})
%layer1_0_bn2 : [#users=1] = call_module[target=layer1.0.bn2](args = (%layer1_0_conv2,), kwargs = {})
%add : [#users=1] = call_function[target=operator.add](args = (%layer1_0_bn2, %maxpool), kwargs = {})
%layer1_0_relu_1 : [#users=2] = call_module[target=layer1.0.relu](args = (%add,), kwargs = {})
%layer1_1_conv1 : [#users=1] = call_module[target=layer1.1.conv1](args = (%layer1_0_relu_1,), kwargs = {})
%layer1_1_bn1 : [#users=1] = call_module[target=layer1.1.bn1](args = (%layer1_1_conv1,), kwargs = {})
%layer1_1_relu : [#users=1] = call_module[target=layer1.1.relu](args = (%layer1_1_bn1,), kwargs = {})
%layer1_1_conv2 : [#users=1] = call_module[target=layer1.1.conv2](args = (%layer1_1_relu,), kwargs = {})
%layer1_1_bn2 : [#users=1] = call_module[target=layer1.1.bn2](args = (%layer1_1_conv2,), kwargs = {})
%add_1 : [#users=1] = call_function[target=operator.add](args = (%layer1_1_bn2, %layer1_0_relu_1), kwargs = {})
%layer1_1_relu_1 : [#users=2] = call_module[target=layer1.1.relu](args = (%add_1,), kwargs = {})
%layer2_0_conv1 : [#users=1] = call_module[target=layer2.0.conv1](args = (%layer1_1_relu_1,), kwargs = {})
%layer2_0_bn1 : [#users=1] = call_module[target=layer2.0.bn1](args = (%layer2_0_conv1,), kwargs = {})
%layer2_0_relu : [#users=1] = call_module[target=layer2.0.relu](args = (%layer2_0_bn1,), kwargs = {})
%layer2_0_conv2 : [#users=1] = call_module[target=layer2.0.conv2](args = (%layer2_0_relu,), kwargs = {})
%layer2_0_bn2 : [#users=1] = call_module[target=layer2.0.bn2](args = (%layer2_0_conv2,), kwargs = {})
%layer2_0_downsample_0 : [#users=1] = call_module[target=layer2.0.downsample.0](args = (%layer1_1_relu_1,), kwargs = {})
%layer2_0_downsample_1 : [#users=1] = call_module[target=layer2.0.downsample.1](args = (%layer2_0_downsample_0,), kwargs = {})
%add_2 : [#users=1] = call_function[target=operator.add](args = (%layer2_0_bn2, %layer2_0_downsample_1), kwargs = {})
%layer2_0_relu_1 : [#users=2] = call_module[target=layer2.0.relu](args = (%add_2,), kwargs = {})
%layer2_1_conv1 : [#users=1] = call_module[target=layer2.1.conv1](args = (%layer2_0_relu_1,), kwargs = {})
%layer2_1_bn1 : [#users=1] = call_module[target=layer2.1.bn1](args = (%layer2_1_conv1,), kwargs = {})
%layer2_1_relu : [#users=1] = call_module[target=layer2.1.relu](args = (%layer2_1_bn1,), kwargs = {})
%layer2_1_conv2 : [#users=1] = call_module[target=layer2.1.conv2](args = (%layer2_1_relu,), kwargs = {})
%layer2_1_bn2 : [#users=1] = call_module[target=layer2.1.bn2](args = (%layer2_1_conv2,), kwargs = {})
%add_3 : [#users=1] = call_function[target=operator.add](args = (%layer2_1_bn2, %layer2_0_relu_1), kwargs = {})
%layer2_1_relu_1 : [#users=2] = call_module[target=layer2.1.relu](args = (%add_3,), kwargs = {})
%layer3_0_conv1 : [#users=1] = call_module[target=layer3.0.conv1](args = (%layer2_1_relu_1,), kwargs = {})
%layer3_0_bn1 : [#users=1] = call_module[target=layer3.0.bn1](args = (%layer3_0_conv1,), kwargs = {})
%layer3_0_relu : [#users=1] = call_module[target=layer3.0.relu](args = (%layer3_0_bn1,), kwargs = {})
%layer3_0_conv2 : [#users=1] = call_module[target=layer3.0.conv2](args = (%layer3_0_relu,), kwargs = {})
%layer3_0_bn2 : [#users=1] = call_module[target=layer3.0.bn2](args = (%layer3_0_conv2,), kwargs = {})
%layer3_0_downsample_0 : [#users=1] = call_module[target=layer3.0.downsample.0](args = (%layer2_1_relu_1,), kwargs = {})
%layer3_0_downsample_1 : [#users=1] = call_module[target=layer3.0.downsample.1](args = (%layer3_0_downsample_0,), kwargs = {})
%add_4 : [#users=1] = call_function[target=operator.add](args = (%layer3_0_bn2, %layer3_0_downsample_1), kwargs = {})
%layer3_0_relu_1 : [#users=2] = call_module[target=layer3.0.relu](args = (%add_4,), kwargs = {})
%layer3_1_conv1 : [#users=1] = call_module[target=layer3.1.conv1](args = (%layer3_0_relu_1,), kwargs = {})
%layer3_1_bn1 : [#users=1] = call_module[target=layer3.1.bn1](args = (%layer3_1_conv1,), kwargs = {})
%layer3_1_relu : [#users=1] = call_module[target=layer3.1.relu](args = (%layer3_1_bn1,), kwargs = {})
%layer3_1_conv2 : [#users=1] = call_module[target=layer3.1.conv2](args = (%layer3_1_relu,), kwargs = {})
%layer3_1_bn2 : [#users=1] = call_module[target=layer3.1.bn2](args = (%layer3_1_conv2,), kwargs = {})
%add_5 : [#users=1] = call_function[target=operator.add](args = (%layer3_1_bn2, %layer3_0_relu_1), kwargs = {})
%layer3_1_relu_1 : [#users=2] = call_module[target=layer3.1.relu](args = (%add_5,), kwargs = {})
%layer4_0_conv1 : [#users=1] = call_module[target=layer4.0.conv1](args = (%layer3_1_relu_1,), kwargs = {})
%layer4_0_bn1 : [#users=1] = call_module[target=layer4.0.bn1](args = (%layer4_0_conv1,), kwargs = {})
%layer4_0_relu : [#users=1] = call_module[target=layer4.0.relu](args = (%layer4_0_bn1,), kwargs = {})
%layer4_0_conv2 : [#users=1] = call_module[target=layer4.0.conv2](args = (%layer4_0_relu,), kwargs = {})
%layer4_0_bn2 : [#users=1] = call_module[target=layer4.0.bn2](args = (%layer4_0_conv2,), kwargs = {})
%layer4_0_downsample_0 : [#users=1] = call_module[target=layer4.0.downsample.0](args = (%layer3_1_relu_1,), kwargs = {})
%layer4_0_downsample_1 : [#users=1] = call_module[target=layer4.0.downsample.1](args = (%layer4_0_downsample_0,), kwargs = {})
%add_6 : [#users=1] = call_function[target=operator.add](args = (%layer4_0_bn2, %layer4_0_downsample_1), kwargs = {})
%layer4_0_relu_1 : [#users=2] = call_module[target=layer4.0.relu](args = (%add_6,), kwargs = {})
%layer4_1_conv1 : [#users=1] = call_module[target=layer4.1.conv1](args = (%layer4_0_relu_1,), kwargs = {})
%layer4_1_bn1 : [#users=1] = call_module[target=layer4.1.bn1](args = (%layer4_1_conv1,), kwargs = {})
%layer4_1_relu : [#users=1] = call_module[target=layer4.1.relu](args = (%layer4_1_bn1,), kwargs = {})
%layer4_1_conv2 : [#users=1] = call_module[target=layer4.1.conv2](args = (%layer4_1_relu,), kwargs = {})
%layer4_1_bn2 : [#users=1] = call_module[target=layer4.1.bn2](args = (%layer4_1_conv2,), kwargs = {})
%add_7 : [#users=1] = call_function[target=operator.add](args = (%layer4_1_bn2, %layer4_0_relu_1), kwargs = {})
%layer4_1_relu_1 : [#users=1] = call_module[target=layer4.1.relu](args = (%add_7,), kwargs = {})
%avgpool : [#users=1] = call_module[target=avgpool](args = (%layer4_1_relu_1,), kwargs = {})
%flatten : [#users=1] = call_function[target=torch.flatten](args = (%avgpool, 1), kwargs = {})
%fc : [#users=1] = call_module[target=fc](args = (%flatten,), kwargs = {})
return fc
This gives us a Graph representation of the ResNet18 model. A Graph consists of a series of interconnected Nodes. Each Node represents a call site in the Python code (whether to a function, a module, or a method), and the edges (represented on each Node as args and kwargs) represent the values passed between these call sites. More information about the Graph representation and the rest of the FX APIs can be found in the FX documentation.
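As a quick illustration of this structure, we can walk the traced graph's nodes directly; each Node exposes its op kind, its target, and its input edges:
# Print the first few nodes: each carries .op, .target, and .args
for node in list(traced_rn18.graph.nodes)[:5]:
    print(node.op, node.target, node.args)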
Creating a Profiling Interpreter#
Next, we are going to create a class that inherits from Interpreter. Though the GraphModule that symbolic_trace produces compiles Python code that is run when the GraphModule is called, an alternative way to run a GraphModule is to execute each Node in the torch.fx.Graph one by one. That is the functionality Interpreter provides: it interprets the graph node-by-node.
By inheriting from Interpreter, we can override various pieces of functionality and install the profiling behavior we want. The goal is to have an object to which we can pass a model, invoke the model one or more times, and then get statistics about how long the model and each of its parts took during those runs.
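As a quick sanity check before we add any profiling (a sketch using the traced_rn18 and input defined earlier), the stock Interpreter run node-by-node reproduces the module's output:
# Running the graph node-by-node gives the same result as calling the module
base_interp = fx.Interpreter(traced_rn18)
assert torch.allclose(base_interp.run(input), rn18(input))
With that confirmed, here is the profiling interpreter: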
class ProfilingInterpreter(fx.Interpreter):
    def __init__(self, mod: torch.nn.Module):
        # Rather than have the user symbolically trace their model,
        # we do it in the constructor. As a result, the user can pass
        # in any ``Module`` without having to worry about the symbolic
        # tracing APIs.
        gm = fx.symbolic_trace(mod)
        super().__init__(gm)

        # We are going to store away two things here:
        #
        # 1. A list of total runtimes for ``mod``. In other words, we
        #    store the time each call to ``mod(...)`` through this
        #    interpreter took.
        self.total_runtime_sec: List[float] = []
        # 2. A map from ``Node`` to a list of runtimes (in seconds) for
        #    that node. This is similar to (1), but for specific
        #    sub-parts of the model.
        self.runtimes_sec: Dict[torch.fx.Node, List[float]] = {}

    ######################################################################
    # Next, let's override ``run()``. ``Interpreter``'s ``run`` method is
    # the top-level entry point for execution of the model. We intercept
    # it so that we can record the model's total runtime.

    def run(self, *args) -> Any:
        # Record the time we started running the model
        t_start = time.time()
        # Run the model by delegating back into Interpreter.run()
        return_val = super().run(*args)
        # Record the time we finished running the model
        t_end = time.time()
        # Store the total elapsed time this model execution took
        self.total_runtime_sec.append(t_end - t_start)
        return return_val

    ######################################################################
    # Now, let's override ``run_node``. ``Interpreter`` calls ``run_node``
    # each time it executes a single node. We intercept it so that we can
    # measure and record the time each individual call in the model takes.

    def run_node(self, n: torch.fx.Node) -> Any:
        # Record the time we started running the op
        t_start = time.time()
        # Run the op by delegating back into Interpreter.run_node()
        return_val = super().run_node(n)
        # Record the time we finished running the op
        t_end = time.time()
        # If this node is not yet in the runtimes_sec data structure,
        # add an entry with an empty list value.
        self.runtimes_sec.setdefault(n, [])
        # Record the total elapsed time for this single invocation
        # in the runtimes_sec data structure
        self.runtimes_sec[n].append(t_end - t_start)
        return return_val

    ######################################################################
    # Finally, we define a method (one which does not override any
    # ``Interpreter`` method) that gives us a nice, organized view of the
    # data we have collected.

    def summary(self, should_sort: bool = False) -> str:
        # Build up a list of summary information for each node
        node_summaries: List[List[Any]] = []
        # Calculate the mean runtime for the whole network. Because the
        # network may have been called multiple times during profiling,
        # we need to summarize the runtimes. We choose the arithmetic
        # mean for this.
        mean_total_runtime = statistics.mean(self.total_runtime_sec)

        # For each node, record summary statistics
        for node, runtimes in self.runtimes_sec.items():
            # Similarly, compute the mean runtime for ``node``
            mean_runtime = statistics.mean(runtimes)
            # For easier understanding, we also compute the percentage
            # of time each node took with respect to the whole network.
            pct_total = mean_runtime / mean_total_runtime * 100
            # Record the node's op type, name, mean runtime, and
            # percent runtime.
            node_summaries.append([node.op, str(node),
                                   mean_runtime, pct_total])

        # One of the most important questions to answer when doing
        # performance profiling is "Which op(s) took the longest?". We
        # make this easy to see by providing sorting functionality in
        # our summary view.
        if should_sort:
            node_summaries.sort(key=lambda s: s[2], reverse=True)

        # Use the ``tabulate`` library to create a well-formatted table
        # presenting our summary information
        headers: List[str] = ['Op type', 'Op',
                              'Average runtime (s)',
                              'Pct total runtime']
        return tabulate.tabulate(node_summaries, headers=headers)
Note
We use Python's time.time function to pull wall clock timestamps and compare them. This is not the most accurate way to measure performance, and only gives us a first-order approximation. We use this simple technique for demonstration purposes in this tutorial.
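If you need somewhat better wall-clock measurements on your own, time.perf_counter is a monotonic, higher-resolution clock that is generally preferable for timing intervals; a minimal sketch:
t0 = time.perf_counter()
rn18(input)              # the region being timed
t1 = time.perf_counter()
print(f'elapsed: {t1 - t0:.6f}s')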
Investigating the Performance of ResNet18#
We can now use ProfilingInterpreter to inspect the performance characteristics of our ResNet18 model:
interp = ProfilingInterpreter(rn18)
interp.run(input)
print(interp.summary(True))
Op type Op Average runtime (s) Pct total runtime
------------- --------------------- --------------------- -------------------
call_module conv1 0.00634789 10.7596
call_module maxpool 0.00393319 6.66672
call_module layer4_0_conv2 0.00305033 5.17027
call_module layer4_1_conv2 0.00285316 4.83607
call_module layer4_1_conv1 0.00277543 4.70433
call_module layer1_0_conv2 0.00272179 4.6134
call_module layer1_1_conv2 0.00269055 4.56046
call_module bn1 0.00231075 3.9167
call_module layer3_0_conv2 0.00221777 3.7591
call_module layer3_1_conv2 0.00215101 3.64594
call_module layer3_1_conv1 0.00209332 3.54815
call_module layer1_0_conv1 0.00208449 3.5332
call_module layer2_1_conv2 0.00204778 3.47096
call_module layer1_1_conv1 0.00199175 3.37599
call_module layer2_1_conv1 0.00197744 3.35175
call_module layer2_0_conv2 0.00197124 3.34124
call_module layer4_0_conv1 0.00188589 3.19657
call_module layer2_0_conv1 0.00169086 2.866
call_module layer3_0_conv1 0.00168705 2.85953
call_module layer2_0_downsample_0 0.000978947 1.6593
call_module layer3_0_downsample_0 0.000718355 1.2176
call_function add 0.000567913 0.962607
call_module layer1_1_bn2 0.000564814 0.957354
call_module layer4_0_downsample_0 0.000509262 0.863194
call_function add_1 0.000474691 0.804597
call_module relu 0.000412941 0.699931
call_module layer2_1_bn2 0.000388384 0.658307
call_function add_3 0.000330925 0.560915
call_module layer1_0_bn2 0.000254631 0.431597
call_module layer1_0_bn1 0.000241041 0.408562
call_module layer1_1_bn1 0.000214815 0.36411
call_module fc 0.000196218 0.332588
call_module layer4_1_bn2 0.000154018 0.26106
call_module layer4_0_bn2 0.00015378 0.260656
call_module avgpool 0.000152349 0.258231
call_module layer3_0_bn2 0.00014019 0.237621
call_module layer4_1_bn1 0.00013876 0.235196
call_function add_2 0.000137806 0.23358
call_module layer2_0_bn1 0.000136614 0.231559
call_module layer2_0_bn2 0.000135422 0.229539
call_module layer2_0_downsample_1 0.000133991 0.227114
call_module layer4_0_downsample_1 0.000130177 0.220648
call_module layer3_1_bn1 0.000128984 0.218627
call_module layer3_0_downsample_1 0.000124693 0.211353
call_function add_5 0.000124454 0.210949
call_module layer3_0_bn1 0.000123978 0.210141
call_module layer2_1_bn1 0.000123739 0.209737
call_module layer3_1_bn2 0.00012207 0.206908
call_module layer1_0_relu 0.000121832 0.206504
call_module layer4_0_bn1 0.000119448 0.202463
call_module layer1_1_relu 0.000116825 0.198017
call_module layer2_0_relu 0.000105381 0.17862
call_module layer2_1_relu 0.000101328 0.17175
call_function add_4 9.9659e-05 0.168921
call_module layer3_1_relu 9.63211e-05 0.163263
call_function add_6 9.63211e-05 0.163263
call_module layer3_0_relu 9.58443e-05 0.162455
call_module layer4_0_relu 9.53674e-05 0.161647
call_function add_7 9.27448e-05 0.157202
call_module layer1_0_relu_1 9.15527e-05 0.155181
call_module layer2_1_relu_1 9.05991e-05 0.153565
call_module layer4_1_relu 8.98838e-05 0.152352
call_module layer1_1_relu_1 8.82149e-05 0.149523
call_module layer2_0_relu_1 8.10623e-05 0.1374
call_module layer4_0_relu_1 7.51019e-05 0.127297
call_module layer3_0_relu_1 7.4625e-05 0.126489
call_module layer4_1_relu_1 7.43866e-05 0.126085
call_module layer3_1_relu_1 7.20024e-05 0.122043
call_function flatten 4.41074e-05 0.0747617
placeholder x 2.47955e-05 0.0420282
output output 1.74046e-05 0.0295006
Tip
torch.nn.MaxPool2d takes up a surprisingly large share of the runtime (a known issue). torch.nn.BatchNorm2d also takes up significant time; Conv-BN fusion can be applied to improve performance.
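As a follow-up sketch, the experimental torch.fx.experimental.optimization.fuse pass (an assumption here, not used elsewhere in this tutorial) folds BatchNorm parameters into the preceding convolution for models in eval mode, so we could re-profile a fused copy:
from torch.fx.experimental.optimization import fuse

fused_rn18 = fuse(rn18)              # fold each BatchNorm into the preceding Conv
interp_fused = ProfilingInterpreter(fused_rn18)
interp_fused.run(input)
print(interp_fused.summary(True))    # the BatchNorm2d rows should be gone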
record_function()#
Under normal execution, the foo scope is recorded correctly:
import torch
import torch.fx
from torch.autograd import profiler
# Setup: a module with `record_function`
class Foo(torch.nn.Module):
def forward(self, x):
with profiler.record_function('foo'):
return torch.relu(x)
f = Foo()
x = torch.randn(5, 3, 2)
with profiler.profile() as prof:
f(x)
print(prof)
------------------- ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
------------------- ------------ ------------ ------------ ------------ ------------ ------------
aten::zeros 3.43% 17.000us 4.85% 24.000us 24.000us 1
aten::empty 1.41% 7.000us 1.41% 7.000us 7.000us 1
aten::zero_ 0.00% 0.000us 0.00% 0.000us 0.000us 1
foo 87.27% 432.000us 95.15% 471.000us 471.000us 1
aten::empty 0.61% 3.000us 0.61% 3.000us 3.000us 1
aten::relu 4.04% 20.000us 7.27% 36.000us 36.000us 1
aten::clamp_min 3.23% 16.000us 3.23% 16.000us 16.000us 1
------------------- ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 495.000us
FX tracing does not record the foo scope:
traced = fx.symbolic_trace(f)
with profiler.profile() as prof:
traced(x)
print(prof)
------------------- ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
------------------- ------------ ------------ ------------ ------------ ------------ ------------
aten::relu 36.84% 7.000us 100.00% 19.000us 19.000us 1
aten::clamp_min 63.16% 12.000us 63.16% 12.000us 12.000us 1
------------------- ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 19.000us
This happens because symbolic tracing does not capture the record_function context manager as an operation in the graph. We can work around it with a custom tracer that creates proxies for the profiler's enter and exit calls:
class ProfilerTracer(fx.Tracer):
def trace(self, root, concrete_args=None):
orig_record_function_enter = profiler.record_function.__enter__
orig_record_function_exit = profiler.record_function.__exit__
def fake_profiler_enter(_self):
nonlocal self
handle_proxy = self.create_proxy(
kind='call_function',
target=torch.ops.profiler._record_function_enter,
args=(_self.name,),
kwargs={})
assert getattr(_self, '_fx_profiler_ctx', None) is None
setattr(_self, '_fx_profiler_ctx', handle_proxy)
return handle_proxy
def fake_profiler_exit(_self, exc_type, exc_value, traceback):
assert hasattr(_self, '_fx_profiler_ctx')
handle_proxy = _self._fx_profiler_ctx
torch.ops.profiler._record_function_exit(handle_proxy)
setattr(_self, '_fx_profiler_ctx', None)
profiler.record_function.__enter__ = fake_profiler_enter
profiler.record_function.__exit__ = fake_profiler_exit
try:
return super().trace(root, concrete_args)
finally:
profiler.record_function.__enter__ = orig_record_function_enter
profiler.record_function.__exit__ = orig_record_function_exit
pt = ProfilerTracer()
graph_with_profiler = pt.trace(f)
traced_with_profiler = fx.GraphModule(pt.root, graph_with_profiler)
with profiler.profile() as prof:
traced_with_profiler(x)
print(prof)
------------------- ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
------------------- ------------ ------------ ------------ ------------ ------------ ------------
foo 95.05% 307.000us 100.00% 323.000us 323.000us 1
aten::empty 1.24% 4.000us 1.24% 4.000us 4.000us 1
aten::relu 1.24% 4.000us 3.72% 12.000us 12.000us 1
aten::clamp_min 2.48% 8.000us 2.48% 8.000us 8.000us 1
------------------- ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 323.000us