THOP: PyTorch OpCounter

Reference: pytorch-OpCounter

Key concepts

  • FLOPS (Floating Point Operations Per Second): the number of floating-point operations executed per second, a measure of hardware speed.

  • FLOPs (Floating Point Operations): the total number of floating-point operations, a measure of a model's computational complexity, often used as an indirect proxy for a neural network's speed. FLOPS and FLOPs are frequently confused with each other.

  • MACs (Multiply–Accumulate Operations): the number of multiply–accumulate operations (a <- a + (b x c)). Often confused with FLOPs; in fact, 1 MAC consists of one multiplication and one addition, i.e. roughly 2 FLOPs, so MACs and FLOPs usually differ by a factor of 2 (see the sketch after this list).
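A quick numeric check of the 2x relationship (a minimal sketch; the helper names are illustrative, not THOP APIs). A dot product of two length-n vectors performs n multiply-accumulates:

def dot_macs(n: int) -> int:
    return n  # one a <- a + b[i] * c[i] per element

def dot_flops(n: int) -> int:
    return 2 * n  # each MAC = 1 multiply + 1 add ~= 2 FLOPs

print(dot_macs(1024), dot_flops(1024))  # 1024 2048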

However, real-world programs are much more complicated. Consider matrix multiplication as an example: A is a matrix of shape (m, n) and B a vector of shape (n, 1).

for i in range(m):
    for j in range(n):
        C[i] += A[i][j] * B[j] # one mul-add

This costs m·n MACs and 2·m·n FLOPs. But such an implementation is slow, and parallelization is necessary to run faster:

for i in range(m):
    parallel for j in range(n):
        d[j] = A[i][j] * B[j] # one mul
    C[i] = sum(d) # n adds

Then the number of MACs is no longer m·n: there are still m·n multiplications, but the additions are now performed separately by the sum rather than fused into multiply-accumulates, so the MAC count depends on the implementation.

When comparing MACs/FLOPs, we want a number that is implementation-agnostic and as general as possible. THOP therefore counts only the number of multiplications and ignores all other operations.

Note

The FLOPs count is then approximately twice the number of multiplications.
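Under this convention a layer's count can be derived by hand. A minimal sketch for a dense 2-D convolution (the helper name and the assumptions groups=1, bias ignored are mine, not THOP's):

def conv2d_macs(c_in, c_out, k, h_out, w_out):
    # One multiplication per kernel element per output element;
    # additions are ignored, per the multiplications-only convention.
    return c_in * k * k * c_out * h_out * w_out

# e.g. a 7x7 conv, 3 -> 64 channels, 112x112 output (ResNet-style stem):
print(conv2d_macs(3, 64, 7, 112, 112))  # 118013952, i.e. ~0.118 G MACs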

Basic usage

import torch
from torchvision.models import resnet50
from thop import profile

model = resnet50()
input = torch.randn(1, 3, 224, 224)
macs, params = profile(model, inputs=(input, ))
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
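The returned macs and params are raw counts and hard to read directly; for resnet50 at this input size they come out on the order of 4.13e9 and 2.56e7 respectively (cf. the formatted output further below):

print(macs, params)  # large unformatted floats, roughly 4.13e9 and 2.56e7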

Defining rules for third-party modules

from torch import nn

class YourModule(nn.Module):
    # your definition
    ...

def count_your_model(model, x, y):
    # your rule here
    ...

model = YourModule()
input = torch.randn(1, 3, 224, 224)
macs, params = profile(model, inputs=(input, ),
                       custom_ops={YourModule: count_your_model})
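For a concrete rule, THOP's built-in hooks accumulate a tensor into module.total_ops. Following that pattern, a minimal sketch for a hypothetical element-wise scaling module (Scale and count_scale are illustrative names, not part of THOP):

import torch
from torch import nn
from thop import profile

class Scale(nn.Module):
    # Hypothetical module: multiplies each channel by a learnable weight.
    def __init__(self, channels):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(channels))

    def forward(self, x):
        return x * self.weight.view(1, -1, 1, 1)

def count_scale(m, x, y):
    # One multiplication per output element; everything else is ignored,
    # matching THOP's multiplications-only convention.
    m.total_ops += torch.DoubleTensor([y.numel()])

net = nn.Sequential(nn.Conv2d(3, 8, 3), Scale(8))
macs, params = profile(net, inputs=(torch.randn(1, 3, 32, 32),),
                       custom_ops={Scale: count_scale})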

Improving output readability

Call thop.clever_format to produce more readable output:

from thop import clever_format

macs, params = clever_format([macs, params], "%.3f")
print("MACs: ", macs)
print("参数量: ", params)
MACs:  4.134G
参数量:  25.557M

Benchmark

from dataclasses import dataclass, asdict

@dataclass
class Info:
    params: float  # parameter count, in Mi (2 ** 20)
    macs: float    # MAC count, in Gi (2 ** 30)

import torch
from torchvision import models
from thop.profile import profile

model_names = sorted(
    name
    for name in models.__dict__
    if name.islower()
    and not name.startswith("__")  # and "inception" in name
    and callable(models.__dict__[name])
)

# print("%s | %s | %s" % ("Model", "Params(M)", "FLOPs(G)"))
# print("---|---|---")

device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

bunch = {}
for name in model_names:
    model = models.__dict__[name]().to(device)
    dsize = (1, 3, 224, 224)
    if "inception" in name:
        dsize = (1, 3, 299, 299)
    inputs = torch.randn(dsize).to(device)
    total_ops, total_params = profile(model, (inputs,), verbose=False)
    bunch[name] = asdict(Info(total_params / (2 ** 20), total_ops / (2 ** 30)))
    # print(
    #     "%s | %.2f | %.2f" % (name, total_params / (1000 ** 2), total_ops / (1000 ** 3))
    # )
/home/pc/.local/lib/python3.8/site-packages/torchvision/models/googlenet.py:77: FutureWarning: The default weight initialization of GoogleNet will be changed in future releases of torchvision. If you wish to keep the old behavior (which leads to long initialization times due to scipy/scipy#11299), please set init_weights=True.
  warnings.warn('The default weight initialization of GoogleNet will be changed in future releases of '
/home/pc/.local/lib/python3.8/site-packages/torchvision/models/inception.py:80: FutureWarning: The default weight initialization of inception_v3 will be changed in future releases of torchvision. If you wish to keep the old behavior (which leads to long initialization times due to scipy/scipy#11299), please set init_weights=True.
  warnings.warn('The default weight initialization of inception_v3 will be changed in future releases of '
import pandas as pd

df = pd.DataFrame(bunch).T
df.columns = ["Params(Mi)", "MACs(Gi)"]
df.round(2)
                    Params(Mi)  MACs(Gi)
alexnet                  58.27      0.67
densenet121               7.61      2.70
densenet161              27.35      7.31
densenet169              13.49      3.20
densenet201              19.09      4.09
googlenet                 6.32      1.41
inception_v3             22.73      5.35
mnasnet0_5                2.12      0.11
mnasnet0_75               3.02      0.22
mnasnet1_0                4.18      0.31
mnasnet1_3                5.99      0.52
mobilenet_v2              3.34      0.30
mobilenet_v3_large        5.23      0.22
mobilenet_v3_small        2.43      0.06
resnet101                42.49      7.33
resnet152                57.40     10.81
resnet18                 11.15      1.70
resnet34                 20.79      3.43
resnet50                 24.37      3.85
resnext101_32x8d         84.68     15.40
resnext50_32x4d          23.87      3.99
shufflenet_v2_x0_5        1.30      0.04
shufflenet_v2_x1_0        2.17      0.14
shufflenet_v2_x1_5        3.34      0.29
shufflenet_v2_x2_0        7.05      0.56
squeezenet1_0             1.19      0.76
squeezenet1_1             1.18      0.33
vgg11                   126.71      7.09
vgg11_bn                126.71      7.11
vgg13                   126.88     10.53
vgg13_bn                126.89     10.58
vgg16                   131.95     14.41
vgg16_bn                131.96     14.46
vgg19                   137.01     18.28
vgg19_bn                137.02     18.34
wide_resnet101_2        121.01     21.27
wide_resnet50_2          65.69     10.67
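To rank the models by cost, the frame can also be sorted; a small convenience step beyond the original benchmark:

df.sort_values("MACs(Gi)").round(2)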