{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# THOP: PyTorch OpCounter\n", "\n", "Reference: [pytorch-OpCounter](https://github.com/Lyken17/pytorch-OpCounter)\n", "\n", "## Key concepts\n", "\n", "- FLOPS (Floating Point Operations Per Second): the number of floating-point operations per second, a measure of hardware speed.\n", "- FLOPs (Floating Point Operations): the number of floating-point operations, a measure of a model's computational complexity that is often used as an indirect proxy for neural network speed. FLOPS and FLOPs are frequently confused.\n", "- MACs (Multiply–Accumulate Operations): the number of multiply-accumulate operations (`a <- a + (b x c)`), also often confused with FLOPs. One MAC consists of one multiplication and one addition, i.e. roughly 2 FLOPs, so the FLOPs count is usually about twice the MACs count.\n", "\n", "Real-world applications, however, are much more complicated. Consider matrix multiplication as an example: `A` is a matrix of shape `(m, n)` and `B` is a vector of shape `(n, 1)`.\n", "\n", "```python\n", "for i in range(m):\n", "    for j in range(n):\n", "        C[i] += A[i][j] * B[j]  # one mul-add\n", "```\n", "\n", "This takes `mn` MACs and `2mn` FLOPs. Such an implementation is slow, though, and parallelization is necessary to make it run faster:\n", "\n", "```python\n", "for i in range(m):\n", "    # the n multiplications are independent and can run in parallel\n", "    d = [A[i][j] * B[j] for j in range(n)]  # one mul each\n", "    C[i] = sum(d)  # n adds\n", "```\n", "\n", "The number of MACs is then no longer `mn`, because the multiplications and additions are no longer fused into multiply-accumulate steps.\n", "\n", "When comparing MACs/FLOPs, we want the number to be independent of the implementation and as general as possible. THOP therefore counts only the multiplications and ignores all other operations.\n", "\n", "```{note}\n", "FLOPs are approximately twice the number of multiplications.\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Basic usage" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[INFO] Register count_convNd() for .\n", "[INFO] Register count_normalization() for .\n", "[INFO] Register zero_ops() for .\n", "[INFO] Register zero_ops() for .\n", "[INFO] Register zero_ops() for .\n", "[INFO] Register count_adap_avgpool() for .\n", "[INFO] Register count_linear() for .\n" ] } ], "source": [ "import torch\n", "from torchvision.models import resnet50\n", "from thop import profile\n", "\n", "model = resnet50()\n", "input = torch.randn(1, 3, 224, 224)\n", "macs, params = profile(model, inputs=(input, ))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define rules for third-party modules\n", "\n", "\n", "```python\n", "import torch\n", "import torch.nn as nn\n", "from thop import profile\n", "\n", "class 
YourModule(nn.Module):\n", "    pass  # your definition\n", "\n", "def count_your_model(model, x, y):\n", "    pass  # your rule here\n", "\n", "input = torch.randn(1, 3, 224, 224)\n", "macs, params = profile(model, inputs=(input, ),\n", "                       custom_ops={YourModule: count_your_model})\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Improve output readability\n", "\n", "\n", "Call `thop.clever_format` to get a more readable output format." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MACs: 4.134G\n", "Params: 25.557M\n" ] } ], "source": [ "from thop import clever_format\n", "\n", "macs, params = clever_format([macs, params], \"%.3f\")\n", "print(\"MACs: \", macs)\n", "print(\"Params: \", params)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Benchmark" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from dataclasses import dataclass, asdict\n", "\n", "@dataclass\n", "class Info:\n", "    params: float  # number of parameters (stored as a scaled value)\n", "    macs: float  # number of MACs (stored as a scaled value)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/pc/.local/lib/python3.8/site-packages/torchvision/models/googlenet.py:77: FutureWarning: The default weight initialization of GoogleNet will be changed in future releases of torchvision. If you wish to keep the old behavior (which leads to long initialization times due to scipy/scipy#11299), please set init_weights=True.\n", " warnings.warn('The default weight initialization of GoogleNet will be changed in future releases of '\n", "/home/pc/.local/lib/python3.8/site-packages/torchvision/models/inception.py:80: FutureWarning: The default weight initialization of inception_v3 will be changed in future releases of torchvision. 
If you wish to keep the old behavior (which leads to long initialization times due to scipy/scipy#11299), please set init_weights=True.\n", " warnings.warn('The default weight initialization of inception_v3 will be changed in future releases of '\n" ] } ], "source": [ "import torch\n", "from torchvision import models\n", "from thop.profile import profile\n", "\n", "model_names = sorted(\n", "    name\n", "    for name in models.__dict__\n", "    if name.islower()\n", "    and not name.startswith(\"__\") # and \"inception\" in name\n", "    and callable(models.__dict__[name])\n", ")\n", "\n", "# print(\"%s | %s | %s\" % (\"Model\", \"Params(M)\", \"FLOPs(G)\"))\n", "# print(\"---|---|---\")\n", "\n", "device = \"cpu\"\n", "if torch.cuda.is_available():\n", "    device = \"cuda\"\n", "\n", "bunch = {}\n", "for name in model_names:\n", "    model = models.__dict__[name]().to(device)\n", "    dsize = (1, 3, 224, 224)\n", "    if \"inception\" in name:\n", "        dsize = (1, 3, 299, 299)\n", "    inputs = torch.randn(dsize).to(device)\n", "    total_ops, total_params = profile(model, (inputs,), verbose=False)\n", "    # Scaled by 2**20 and 2**30 (binary units), so the values are slightly\n", "    # lower than the decimal M/G figures printed by clever_format.\n", "    bunch[name] = asdict(Info(total_params / (2 ** 20), total_ops / (2 ** 30)))\n", "    # print(\n", "    #     \"%s | %.2f | %.2f\" % (name, total_params / (1000 ** 2), total_ops / (1000 ** 3))\n", "    # )" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<table>\n",
"<thead>\n",
"<tr><th></th><th>Params(M)</th><th>MACs(G)</th></tr>\n",
"</thead>\n",
"<tbody>\n",
"<tr><th>alexnet</th><td>58.27</td><td>0.67</td></tr>\n",
"<tr><th>densenet121</th><td>7.61</td><td>2.70</td></tr>\n",
"<tr><th>densenet161</th><td>27.35</td><td>7.31</td></tr>\n",
"<tr><th>densenet169</th><td>13.49</td><td>3.20</td></tr>\n",
"<tr><th>densenet201</th><td>19.09</td><td>4.09</td></tr>\n",
"<tr><th>googlenet</th><td>6.32</td><td>1.41</td></tr>\n",
"<tr><th>inception_v3</th><td>22.73</td><td>5.35</td></tr>\n",
"<tr><th>mnasnet0_5</th><td>2.12</td><td>0.11</td></tr>\n",
"<tr><th>mnasnet0_75</th><td>3.02</td><td>0.22</td></tr>\n",
"<tr><th>mnasnet1_0</th><td>4.18</td><td>0.31</td></tr>\n",
"<tr><th>mnasnet1_3</th><td>5.99</td><td>0.52</td></tr>\n",
"<tr><th>mobilenet_v2</th><td>3.34</td><td>0.30</td></tr>\n",
"<tr><th>mobilenet_v3_large</th><td>5.23</td><td>0.22</td></tr>\n",
"<tr><th>mobilenet_v3_small</th><td>2.43</td><td>0.06</td></tr>\n",
"<tr><th>resnet101</th><td>42.49</td><td>7.33</td></tr>\n",
"<tr><th>resnet152</th><td>57.40</td><td>10.81</td></tr>\n",
"<tr><th>resnet18</th><td>11.15</td><td>1.70</td></tr>\n",
"<tr><th>resnet34</th><td>20.79</td><td>3.43</td></tr>\n",
"<tr><th>resnet50</th><td>24.37</td><td>3.85</td></tr>\n",
"<tr><th>resnext101_32x8d</th><td>84.68</td><td>15.40</td></tr>\n",
"<tr><th>resnext50_32x4d</th><td>23.87</td><td>3.99</td></tr>\n",
"<tr><th>shufflenet_v2_x0_5</th><td>1.30</td><td>0.04</td></tr>\n",
"<tr><th>shufflenet_v2_x1_0</th><td>2.17</td><td>0.14</td></tr>\n",
"<tr><th>shufflenet_v2_x1_5</th><td>3.34</td><td>0.29</td></tr>\n",
"<tr><th>shufflenet_v2_x2_0</th><td>7.05</td><td>0.56</td></tr>\n",
"<tr><th>squeezenet1_0</th><td>1.19</td><td>0.76</td></tr>\n",
"<tr><th>squeezenet1_1</th><td>1.18</td><td>0.33</td></tr>\n",
"<tr><th>vgg11</th><td>126.71</td><td>7.09</td></tr>\n",
"<tr><th>vgg11_bn</th><td>126.71</td><td>7.11</td></tr>\n",
"<tr><th>vgg13</th><td>126.88</td><td>10.53</td></tr>\n",
"<tr><th>vgg13_bn</th><td>126.89</td><td>10.58</td></tr>\n",
"<tr><th>vgg16</th><td>131.95</td><td>14.41</td></tr>\n",
"<tr><th>vgg16_bn</th><td>131.96</td><td>14.46</td></tr>\n",
"<tr><th>vgg19</th><td>137.01</td><td>18.28</td></tr>\n",
"<tr><th>vgg19_bn</th><td>137.02</td><td>18.34</td></tr>\n",
"<tr><th>wide_resnet101_2</th><td>121.01</td><td>21.27</td></tr>\n",
"<tr><th>wide_resnet50_2</th><td>65.69</td><td>10.67</td></tr>\n",
"</tbody>\n",
"</table>\n", "
" ], "text/plain": [ " Params(M) MACs(G)\n", "alexnet 58.27 0.67\n", "densenet121 7.61 2.70\n", "densenet161 27.35 7.31\n", "densenet169 13.49 3.20\n", "densenet201 19.09 4.09\n", "googlenet 6.32 1.41\n", "inception_v3 22.73 5.35\n", "mnasnet0_5 2.12 0.11\n", "mnasnet0_75 3.02 0.22\n", "mnasnet1_0 4.18 0.31\n", "mnasnet1_3 5.99 0.52\n", "mobilenet_v2 3.34 0.30\n", "mobilenet_v3_large 5.23 0.22\n", "mobilenet_v3_small 2.43 0.06\n", "resnet101 42.49 7.33\n", "resnet152 57.40 10.81\n", "resnet18 11.15 1.70\n", "resnet34 20.79 3.43\n", "resnet50 24.37 3.85\n", "resnext101_32x8d 84.68 15.40\n", "resnext50_32x4d 23.87 3.99\n", "shufflenet_v2_x0_5 1.30 0.04\n", "shufflenet_v2_x1_0 2.17 0.14\n", "shufflenet_v2_x1_5 3.34 0.29\n", "shufflenet_v2_x2_0 7.05 0.56\n", "squeezenet1_0 1.19 0.76\n", "squeezenet1_1 1.18 0.33\n", "vgg11 126.71 7.09\n", "vgg11_bn 126.71 7.11\n", "vgg13 126.88 10.53\n", "vgg13_bn 126.89 10.58\n", "vgg16 131.95 14.41\n", "vgg16_bn 131.96 14.46\n", "vgg19 137.01 18.28\n", "vgg19_bn 137.02 18.34\n", "wide_resnet101_2 121.01 21.27\n", "wide_resnet50_2 65.69 10.67" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "\n", "df = pd.DataFrame(bunch).T\n", "df.columns = [\"Params(M)\", \"MACs(G)\"]\n", "df.round(2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3.8.10 64-bit", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" } } }, "nbformat": 4, "nbformat_minor": 2 }