{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Custom Quantization" ] },
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Forward compatibility for annotations (PEP 563 postponed evaluation)\n", "from __future__ import annotations" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Quantization Workflow" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "```{important}\n", "To convert a model into a **quantizable model**:\n", "\n", "1. Replace additions with {class}`~torch.nn.quantized.FloatFunctional`\n", "2. Fuse the following module sequences with {func}`~torch.ao.quantization.fuse_modules.fuse_modules` or {func}`~torch.ao.quantization.fuse_modules.fuse_modules_qat`:\n", "\n", "   - conv, bn\n", "   - conv, bn, relu\n", "   - conv, relu\n", "   - linear, bn\n", "   - linear, relu\n", "\n", "3. Insert {class}`~torch.ao.quantization.stubs.QuantStub` at the beginning of the network and {class}`~torch.ao.quantization.stubs.DeQuantStub` at the end\n", "4. Replace {class}`torch.nn.ReLU6` with {class}`torch.nn.ReLU` (the fusion patterns above match `ReLU`, not `ReLU6`)\n", "```\n", "\n", "```{rubric} Import libraries\n", "```" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "'''Adapted from torchvision/models/quantization/mobilenetv2.py\n", "'''\n", "from typing import Any\n", "from torch import Tensor\n", "from torch import nn\n", "\n", "from torchvision._internally_replaced_utils import load_state_dict_from_url\n", "from torchvision.ops.misc import ConvNormActivation\n", "from torchvision.models.quantization.utils import _fuse_modules, _replace_relu, quantize_model\n", "from torch.ao.quantization import QuantStub, DeQuantStub" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "```{rubric} Fusing modules\n", "```\n", "\n", "````{note}\n", "\n", "{func}`~torchvision.models.quantization.utils._fuse_modules` provides a unified interface to {func}`~torch.ao.quantization.fuse_modules.fuse_modules` and {func}`~torch.ao.quantization.fuse_modules.fuse_modules_qat`. It dispatches on `is_qat` and falls back to `model.training` when `is_qat` is `None`.\n", "\n", "```python\n", "from typing import Any\n", "\n", "from torch import nn\n", "from torch.ao.quantization import fuse_modules_qat, fuse_modules\n", "\n", "\n", "def _fuse_modules(\n", "    model: nn.Module,\n", "    modules_to_fuse: list[str] | list[list[str]],\n", "    is_qat: bool | None,\n", "    **kwargs: Any\n", "):\n", "    if is_qat is None:\n", "        is_qat = model.training\n", "    method = fuse_modules_qat if is_qat else fuse_modules\n", "    return method(model, modules_to_fuse, **kwargs)\n", "```\n", "\n", "\n", "Compared with the plain `torch.*` operations, {class}`~torch.nn.quantized.FloatFunctional` ops add a post-processing step, for example:\n", "\n", "```python\n", "import torch\n", "from torch import Tensor\n", "\n", "\n", "class FloatFunctional(torch.nn.Module):\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.activation_post_process = torch.nn.Identity()\n", "\n", "    def forward(self, x):\n", "        raise RuntimeError(\"FloatFunctional is not intended to use the \" +\n", "                           \"'forward'. Please use the underlying operation\")\n", "\n", "    r\"\"\"Operation equivalent to ``torch.add(Tensor, Tensor)``\"\"\"\n", "    def add(self, x: Tensor, y: Tensor) -> Tensor:\n", "        r = torch.add(x, y)\n", "        r = self.activation_post_process(r)\n", "        return r\n", "```\n", "\n", "`self.activation_post_process` is initialized to `torch.nn.Identity()`, an identity mapping. As a result, {meth}`~torch.nn.quantized.FloatFunctional.add` is equivalent to {func}`torch.add` in the float model.\n", "````\n", "\n", "```{tip}\n", "In eager-mode quantization, {func}`~torch.ao.quantization.prepare` replaces `self.activation_post_process` with an observer that records the range of the op's output. {func}`~torch.ao.quantization.convert` then swaps {class}`~torch.nn.quantized.FloatFunctional` for {class}`~torch.nn.quantized.QFunctional`. In other words, `activation_post_process` is the hook through which quantization attaches post-processing to the op.\n", "```" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## `MobileNetV2`-specific quantization\n", "\n", "The following uses {class}`~torchvision.models.MobileNetV2` as an example to show how to convert it into the quantizable model {class}`~torchvision.models.quantization.QuantizableMobileNetV2`." ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, model_urls\n", "\n", "quant_model_urls = {\n", "    \"mobilenet_v2_qnnpack\": \"https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth\"\n", "}\n", "\n", "\n", "class QuantizableInvertedResidual(InvertedResidual):\n", "    def __init__(self, *args: Any, **kwargs: Any) -> None:\n", "        super().__init__(*args, **kwargs)\n", "        self.skip_add = nn.quantized.FloatFunctional()\n", "\n", "    def forward(self, x: Tensor) -> Tensor:\n", "        if self.use_res_connect:\n", "            return self.skip_add.add(x, self.conv(x))\n", "        else:\n", "            return self.conv(x)\n", "\n", "    def fuse_model(self, is_qat: bool | None = None) -> None:\n", "        for idx in range(len(self.conv)):\n", "            if type(self.conv[idx]) is nn.Conv2d:\n", "                _fuse_modules(self.conv,\n", "                              [str(idx),\n", "                               str(idx + 1)],\n", "                              is_qat,\n", "                              inplace=True)\n", "\n", "\n", "class QuantizableMobileNetV2(MobileNetV2):\n", "    def __init__(self, *args: Any, **kwargs: Any) -> None:\n", "        \"\"\"\n", "        MobileNet V2 main class\n", "\n", "        Args:\n", "           Inherits args from floating point MobileNetV2\n", "        \"\"\"\n", "        super().__init__(*args, **kwargs)\n", "        self.quant = QuantStub()\n", "        self.dequant = DeQuantStub()\n", "\n", "    def forward(self, x: Tensor) -> Tensor:\n", "        x = self.quant(x)\n", "        x = self._forward_impl(x)\n", "        x = self.dequant(x)\n", "        return x\n", "\n", "    def fuse_model(self, is_qat: bool | None = None) -> None:\n", "        for m in self.modules():\n", "            if type(m) is ConvNormActivation:\n", "                _fuse_modules(m, [\"0\", \"1\", \"2\"], is_qat, inplace=True)\n", "            if type(m) is QuantizableInvertedResidual:\n", "                m.fuse_model(is_qat)\n", "\n", "\n", "def mobilenet_v2(\n", "    pretrained: bool = False,\n", "    progress: bool = True,\n", "    quantize: bool = False,\n", "    **kwargs: Any,\n", ") -> QuantizableMobileNetV2:\n", "    \"\"\"\n", "    Constructs a MobileNetV2 architecture from \"MobileNetV2: Inverted Residuals and Linear Bottlenecks\".\n", "\n", "    Note that quantize = True returns a quantized model with 8-bit weights.\n", "    Quantized models only support inference and run on CPUs. GPU inference is not yet supported.\n", "\n", "    Args:\n", "        pretrained (bool): If True, returns a model pre-trained on ImageNet\n", "        progress (bool): If True, displays a progress bar of the download to stderr\n", "        quantize (bool): If True, returns a quantized model, else returns a float model\n", "    \"\"\"\n", "    model = QuantizableMobileNetV2(block=QuantizableInvertedResidual, **kwargs)\n", "    _replace_relu(model)\n", "\n", "    if quantize:\n", "        # TODO use pretrained as a string to specify the backend\n", "        backend = \"qnnpack\"\n", "        # quantize_model (torchvision ~0.12; may differ in other releases) quantizes `model` in place:\n", "        # it sets the quantized engine to `backend`, calls model.eval(), assigns a qconfig, calls\n", "        # model.fuse_model() with no arguments, then runs prepare, one forward pass on random input, and convert.\n", "        quantize_model(model, backend)\n", "    else:\n", "        assert pretrained in [True, False]\n", "\n", "    if pretrained:\n", "        if quantize:\n", "            model_url = quant_model_urls[\"mobilenet_v2_\" + backend]\n", "        else:\n", "            model_url = model_urls[\"mobilenet_v2\"]\n", "\n", "        state_dict = load_state_dict_from_url(model_url, progress=progress)\n", "        model.load_state_dict(state_dict)\n", "    return model" ] } ], "metadata": { "interpreter": { "hash": "78526419bf48930935ba7e23437b2460cb231485716b036ebb8701887a294fa8" }, "kernelspec": { "display_name": "Python 3.10.0 ('torchx')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.0" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }