THOP: PyTorch OpCounter
Key Concepts
- FLOPS (Floating Point Operations Per Second): the number of floating-point operations per second, a measure of hardware speed.
- FLOPs (Floating Point Operations): the total number of floating-point operations, a measure of a model's computational complexity, often used as an indirect proxy for a network's speed. FLOPS and FLOPs are frequently confused with each other.
- MACs (Multiply-Accumulate Operations): the number of multiply-accumulate operations (a <- a + b x c). MACs are also often confused with FLOPs; in fact, one MAC contains one multiplication and one addition, roughly 2 FLOPs, so MACs and FLOPs usually differ by a factor of 2 (see the short sketch after this list).
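To make the bookkeeping concrete, here is a minimal sketch that counts operations for a plain dot product; the counters are illustrative and not part of THOP:

n = 1024
a = [1.0] * n
b = [2.0] * n

acc = 0.0
macs = flops = 0
for ai, bi in zip(a, b):
    acc += ai * bi  # one MAC: one multiply plus one add
    macs += 1
    flops += 2      # the same MAC counted as two FLOPs

print(macs, flops)  # n MACs, 2n FLOPs -> 1024, 2048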
Real-world workloads are more complicated than this, however. Consider matrix multiplication as an example: A is a matrix of shape (m, n) and B is a vector of shape (n, 1).
for i in range(m):
    for j in range(n):
        C[i] += A[i][j] * B[j]  # one multiply-add
This performs mn MACs and 2mn FLOPs. But such an implementation is slow, and parallelization is necessary for it to run faster:
for i in range(m):
    parallel for j in range(n):
        d[j] = A[i][j] * B[j]  # one multiply
    C[i] = sum(d)  # n additions
The number of MACs is then no longer mn. When comparing MACs/FLOPs, we want the number to be implementation-independent and as general as possible. Therefore, THOP counts only the number of multiplications and ignores all other operations.
Note
FLOPs is approximately equal to twice the number of multiplications.
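This convention is easy to check on a single fully connected layer. The sketch below assumes THOP's built-in rule for nn.Linear, which in current releases charges in_features multiplications per output element and ignores the bias additions:

import torch
from torch import nn
from thop import profile

fc = nn.Linear(128, 64)
x = torch.randn(1, 128)
macs, params = profile(fc, inputs=(x,), verbose=False)

print(int(macs))      # 128 * 64 = 8192 multiplications
print(int(2 * macs))  # ~16384 FLOPs, per the note above
print(int(params))    # 128 * 64 + 64 = 8256 parameters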
Basic Usage
import torch
from torchvision.models import resnet50
from thop import profile

model = resnet50()
input = torch.randn(1, 3, 224, 224)  # dummy NCHW input at ImageNet resolution
macs, params = profile(model, inputs=(input, ))
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
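profile returns the raw counts as plain numbers, so you can format them directly before reaching for clever_format; the figures match the formatted output shown in a later section:

print(f"{macs / 1e9:.3f} GMACs, {params / 1e6:.3f} M params")
# 4.134 GMACs, 25.557 M params for the resnet50 profiled above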
Defining Rules for Third-Party Modules
from torch import nn

class YourModule(nn.Module):
    ...  # your definition

def count_your_model(model, x, y):
    ...  # your rule here

model = YourModule()
input = torch.randn(1, 3, 224, 224)
macs, params = profile(model, inputs=(input, ),
                       custom_ops={YourModule: count_your_model})
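For concreteness, here is a minimal sketch of that template. SwishScale and count_swish_scale are made-up names; the rule follows THOP's forward-hook convention of (module, inputs, output) and accumulates into module.total_ops. It counts only the scaling multiplication and treats the activation as zero ops, much as THOP itself registers zero_ops for ReLU:

import torch
from torch import nn
from thop import profile

class SwishScale(nn.Module):
    """Hypothetical module: swish activation followed by a per-channel scale."""
    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim, 1, 1))

    def forward(self, x):
        return x * torch.sigmoid(x) * self.weight

def count_swish_scale(module, inputs, output):
    # one multiplication per output element for the channel scale; the
    # sigmoid gating is treated as zero ops, like THOP's other activations
    module.total_ops += torch.DoubleTensor([output.numel()])

model = nn.Sequential(nn.Conv2d(3, 8, 3), SwishScale(8))
x = torch.randn(1, 3, 224, 224)
macs, params = profile(model, inputs=(x,),
                       custom_ops={SwishScale: count_swish_scale})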
Improving Output Readability
Call thop.clever_format to produce a more readable output format.
from thop import clever_format
macs, params = clever_format([macs, params], "%.3f")
print("MACs: ", macs)
print("参数量: ", params)
MACs: 4.134G
参数量: 25.557M
Benchmark
from dataclasses import dataclass, asdict

@dataclass
class Info:
    params: int  # parameter count
    macs: int    # multiply-accumulate count
import torch
from torchvision import models
from thop.profile import profile

model_names = sorted(
    name
    for name in models.__dict__
    if name.islower()
    and not name.startswith("__")
    and callable(models.__dict__[name])
)
device = "cuda" if torch.cuda.is_available() else "cpu"

bunch = {}
for name in model_names:
    model = models.__dict__[name]().to(device)
    dsize = (1, 3, 224, 224)
    if "inception" in name:
        dsize = (1, 3, 299, 299)  # Inception models expect 299x299 inputs
    inputs = torch.randn(dsize).to(device)
    total_ops, total_params = profile(model, (inputs,), verbose=False)
    # note: 2**20 and 2**30 are binary units (Mi/Gi), not 10**6 and 10**9
    bunch[name] = asdict(Info(total_params / (2 ** 20), total_ops / (2 ** 30)))
import pandas as pd
df = pd.DataFrame(bunch).T
df.columns = ["Params(M)", "MACs(G)"]
df.round(2)
Model | Params(M) | MACs(G)
---|---|---
alexnet | 58.27 | 0.67 |
densenet121 | 7.61 | 2.70 |
densenet161 | 27.35 | 7.31 |
densenet169 | 13.49 | 3.20 |
densenet201 | 19.09 | 4.09 |
googlenet | 6.32 | 1.41 |
inception_v3 | 22.73 | 5.35 |
mnasnet0_5 | 2.12 | 0.11 |
mnasnet0_75 | 3.02 | 0.22 |
mnasnet1_0 | 4.18 | 0.31 |
mnasnet1_3 | 5.99 | 0.52 |
mobilenet_v2 | 3.34 | 0.30 |
mobilenet_v3_large | 5.23 | 0.22 |
mobilenet_v3_small | 2.43 | 0.06 |
resnet101 | 42.49 | 7.33 |
resnet152 | 57.40 | 10.81 |
resnet18 | 11.15 | 1.70 |
resnet34 | 20.79 | 3.43 |
resnet50 | 24.37 | 3.85 |
resnext101_32x8d | 84.68 | 15.40 |
resnext50_32x4d | 23.87 | 3.99 |
shufflenet_v2_x0_5 | 1.30 | 0.04 |
shufflenet_v2_x1_0 | 2.17 | 0.14 |
shufflenet_v2_x1_5 | 3.34 | 0.29 |
shufflenet_v2_x2_0 | 7.05 | 0.56 |
squeezenet1_0 | 1.19 | 0.76 |
squeezenet1_1 | 1.18 | 0.33 |
vgg11 | 126.71 | 7.09 |
vgg11_bn | 126.71 | 7.11 |
vgg13 | 126.88 | 10.53 |
vgg13_bn | 126.89 | 10.58 |
vgg16 | 131.95 | 14.41 |
vgg16_bn | 131.96 | 14.46 |
vgg19 | 137.01 | 18.28 |
vgg19_bn | 137.02 | 18.34 |
wide_resnet101_2 | 121.01 | 21.27 |
wide_resnet50_2 | 65.69 | 10.67 |
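Since the results end up in an ordinary pandas DataFrame, the usual tooling applies. A small follow-up sketch, assuming the df built above (to_markdown additionally needs the tabulate package):

# five cheapest models by compute
print(df.sort_values("MACs(G)").head().round(2))

# regenerate the table above as Markdown
print(df.round(2).to_markdown())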