{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# PTQ and QAT in Practice\n", "\n", "This article shows how to use PyTorch to convert a floating-point model into a PTQ or QAT model.\n", "\n", "## Background\n", "\n", "{guilabel}`Goal`: quickly convert a floating-point model into a PTQ or QAT model.\n", "\n", "### Audience\n", "\n", "This tutorial is intended for algorithm engineers who can already write modules such as CNNs in PyTorch.\n", "\n", "### Environment\n", "\n", "This article uses Python 3.10.0 (please verify other versions yourself); so far it has only been tested on Linux." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Check the versions of `torch` and `torchvision`:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch: 1.11.0 \n", "torchvision: 0.12.0\n" ] } ], "source": [ "import torch\n", "import torchvision\n", "\n", "print(f'torch: {torch.__version__} \\n'\n", "      f'torchvision: {torchvision.__version__}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Configure the warning filters:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# configure warnings\n", "import warnings\n", "warnings.filterwarnings(\n", "    action='ignore',\n", "    category=DeprecationWarning,\n", "    module='.*'\n", ")\n", "warnings.filterwarnings(\n", "    action='ignore',\n", "    module='torch.ao.quantization'\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Overview: PTQ and QAT\n", "\n", "Reference: [Quantization](https://pytorch.org/docs/master/quantization.html)\n", "\n", "`Post-training quantization`\n", ": Abbreviated PTQ (Post Training Quantization): weights and activations are quantized, and calibration with data is required after training.\n", "\n", "`Static quantization-aware training`\n", ": Abbreviated QAT (static quantization aware training): weights and activations are quantized, and the quantization numerics are modeled during training.\n", "\n", "`Floating-point model`\n", ": A model whose **weights** and **activations** are both of floating-point type (e.g. {data}`torch.float32`, {data}`torch.float64`).\n", "\n", "`Quantized model`\n", ": A model whose **weights** and **activations** are both of quantized type (e.g. {data}`torch.qint32`, {data}`torch.qint8`, {data}`torch.quint8`, {data}`torch.quint2x4`, {data}`torch.quint4x2`).\n", "\n", "\n", "The examples below show how to convert a floating-point model into a quantized model.\n", "\n", "For illustration, define the following modules:\n", "\n", "```{rubric} Define a simple floating-point module\n", "```" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ 
"from torch import nn, Tensor\n", "\n", "\n", "class M(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", " self.conv = torch.nn.Conv2d(1, 1, 1)\n", " self.relu = torch.nn.ReLU()\n", "\n", " def _forward_impl(self, x: Tensor) -> Tensor:\n", " '''提供便捷函数'''\n", " x = self.conv(x)\n", " x = self.relu(x)\n", " return x\n", "\n", " def forward(self, x: Tensor) -> Tensor:\n", " x= self._forward_impl(x)\n", " return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```{rubric} 定义可量化模块\n", "```\n", "\n", "将浮点模块 `M` 转换为可量化模块 `QM`(量化流程的最关键的一步)。" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "from torch.ao.quantization import QuantStub, DeQuantStub\n", "\n", "\n", "class QM(M):\n", " '''\n", " Args:\n", " is_print: 为了测试需求,打印一些信息\n", " '''\n", " def __init__(self, is_print: bool=False):\n", " super().__init__()\n", " self.is_print = is_print\n", " self.quant = QuantStub() # 将张量从浮点转换为量化\n", " self.dequant = DeQuantStub() # 将张量从量化转换为浮点\n", "\n", " def forward(self, x: Tensor) -> Tensor:\n", " # 手动指定张量将在量化模型中从浮点模块转换为量化模块的位置\n", " x = self.quant(x)\n", " if self.is_print:\n", " print('量化前的类型:', x.dtype)\n", " x = self._forward_impl(x)\n", " if self.is_print:\n", " print('量化中的类型:',x.dtype)\n", " # 在量化模型中手动指定张量从量化到浮点的转换位置\n", " x = self.dequant(x)\n", " if self.is_print:\n", " print('量化后的类型:', x.dtype)\n", " return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "简单测试前向过程的激活数据类型:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "量化前的类型: torch.float32\n", "量化中的类型: torch.float32\n", "量化后的类型: torch.float32\n" ] } ], "source": [ "input_fp32 = torch.randn(4, 1, 4, 4) # 输入的数据\n", "\n", "m = QM(is_print=True)\n", "x = m(input_fp32)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "查看权重的数据类型:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ 
"torch.float32" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m.conv.weight.dtype" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "可以看出,此时模块 `m` 是浮点模块。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### PTQ 简介\n", "\n", "当内存带宽和计算空间都很重要时,通常会使用训练后量化,而 CNN 就是其典型的用例。训练后量化对模型的 **权重** 和 **激活** 进行量化。它在可能的情况下将 **激活** 融合到前面的层中。它需要用具有代表性的数据集进行 **校准**,以确定激活的最佳量化参数。\n", "\n", "```{rubric} 示意图\n", "```\n", "\n", "```\n", "# 原始模型\n", "# 全部的张量和计算均在浮点上进行\n", "previous_layer_fp32 -- linear_fp32 -- activation_fp32 -- next_layer_fp32\n", " /\n", " linear_weight_fp32\n", "\n", "# 静态量化模型\n", "# weights 和 activations 在 int8 上\n", "previous_layer_int8 -- linear_with_activation_int8 -- next_layer_int8\n", " /\n", " linear_weight_int8\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "直接创建浮点模块的实例:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# 创建浮点模型实例\n", "model_fp32 = QM(is_print=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "要使 PTQ 生效,必须将模型设置为 `eval` 模式:\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "QM(\n", " (conv): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))\n", " (relu): ReLU()\n", " (quant): QuantStub()\n", " (dequant): DeQuantStub()\n", ")" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_fp32.eval()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "查看此时的数据类型:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "量化前的类型: torch.float32\n", "量化中的类型: torch.float32\n", "量化后的类型: torch.float32\n", "激活和权重的数据类型分别为:torch.float32, torch.float32\n" ] } ], "source": [ "input_fp32 = torch.randn(4, 1, 4, 4)\n", "\n", "x = model_fp32(input_fp32)\n", "print('激活和权重的数据类型分别为:'\n", " f'{x.dtype}, 
{model_fp32.conv.weight.dtype}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```{rubric} Define the observers\n", "```\n", "\n", "Assign the instance attribute `qconfig`, which specifies which observers to attach:\n", "\n", "- Use [`'fbgemm'`](https://github.com/pytorch/FBGEMM) for x86 with AVX2 (without AVX2, some operations have inefficient implementations); use [`'qnnpack'`](https://github.com/pytorch/pytorch/tree/master/aten/src/ATen/native/quantized/cpu/qnnpack) for ARM CPUs (typically found in mobile/embedded devices).\n", "- Other quantization settings, such as symmetric or asymmetric quantization and `MinMax` or `L2Norm` calibration techniques, can also be specified here." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Check the dtypes at this point:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "量化前的类型: torch.float32\n", "量化中的类型: torch.float32\n", "量化后的类型: torch.float32\n", "激活和权重的数据类型分别为:torch.float32, torch.float32\n" ] } ], "source": [ "input_fp32 = torch.randn(4, 1, 4, 4)\n", "\n", "x = model_fp32(input_fp32)\n", "# print the activation and weight dtypes\n", "print('激活和权重的数据类型分别为:'\n", "      f'{x.dtype}, {model_fp32.conv.weight.dtype}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```{rubric} Fuse activation layers\n", "```\n", "\n", "Where applicable, fuse activations into the preceding layer (this has to be done manually, depending on the model architecture). Common fusions include `conv + relu` and `conv + batchnorm + relu`." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "QM(\n", "  (conv): ConvReLU2d(\n", "    (0): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))\n", "    (1): ReLU()\n", "  )\n", "  (relu): Identity()\n", "  (quant): QuantStub()\n", "  (dequant): DeQuantStub()\n", ")" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32,\n", "                                                      [['conv', 'relu']])\n", "\n", "model_fp32_fused" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that `ConvReLU2d` in `model_fp32_fused` fuses two layers of `model_fp32`: `conv` 
and `relu`.\n", "\n", "Check the dtypes at this point:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "量化前的类型: torch.float32\n", "量化中的类型: torch.float32\n", "量化后的类型: torch.float32\n", "激活和权重的数据类型分别为:torch.float32, torch.float32\n" ] } ], "source": [ "input_fp32 = torch.randn(4, 1, 4, 4)\n", "\n", "x = model_fp32_fused(input_fp32)\n", "# print the activation and weight dtypes\n", "print('激活和权重的数据类型分别为:'\n", "      f'{x.dtype}, {model_fp32.conv.weight.dtype}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```{rubric} Enable the observers\n", "```\n", "\n", "Insert observers into the fused module; they observe the activation tensors during calibration." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "model_fp32_prepared = torch.ao.quantization.prepare(model_fp32_fused)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```{rubric} Calibrate the prepared model\n", "```\n", "Calibrate the prepared model to determine the quantization parameters of the activations. In a real-world setting, calibration should be done with a representative dataset; random input is used here only for illustration." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "量化前的类型: torch.float32\n", "量化中的类型: torch.float32\n", "量化后的类型: torch.float32\n", "激活和权重的数据类型分别为:torch.float32, torch.float32\n" ] } ], "source": [ "input_fp32 = torch.randn(4, 1, 4, 4)\n", "\n", "x = model_fp32_prepared(input_fp32)\n", "# print the activation and weight dtypes\n", "print('激活和权重的数据类型分别为:'\n", "      f'{x.dtype}, {model_fp32.conv.weight.dtype}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```{rubric} Convert the model\n", "```\n", "\n", "```{note}\n", "Conversion quantizes the weights, computes and stores the scale and zero-point values to be used with each activation tensor, and replaces key operators with their quantized implementations.\n", "```\n", "\n", "Convert the calibrated model into a quantized model:" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "QM(\n", "  (conv): QuantizedConvReLU2d(1, 1, kernel_size=(1, 1), stride=(1, 1), scale=0.010650944896042347, zero_point=0)\n", "  (relu): Identity()\n", "  (quant): Quantize(scale=tensor([0.0351]), zero_point=tensor([74]), dtype=torch.quint8)\n", "  (dequant): DeQuantize()\n", ")" ] }, "execution_count": 17, "metadata": 
{}, "output_type": "execute_result" } ], "source": [ "model_int8 = torch.ao.quantization.convert(model_fp32_prepared)\n", "model_int8" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Check the dtype of the weights:" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.qint8" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_int8.conv.weight().dtype" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Each weight element now takes 1 byte instead of the 4 bytes of FP32:" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_int8.conv.weight().element_size()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Run the model; the relevant computations now happen in 8-bit integers (activations as {data}`torch.quint8`, weights as {data}`torch.qint8`)." ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "量化前的类型: torch.quint8\n", "量化中的类型: torch.quint8\n", "量化后的类型: torch.float32\n" ] }, { "data": { "text/plain": [ "torch.float32" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "res = model_int8(input_fp32)\n", "res.dtype" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To learn more about post-training static quantization, see the [static quantization tutorial](https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Overview of QAT\n", "\n", "QAT simulates the effects of quantization **during training**, which allows higher accuracy than other quantization methods. During training, all computations are done in floating point, and fake_quant modules model the effects of INT8 quantization through clamping and rounding. After model conversion, weights and activations are quantized, and activations are fused into the preceding layer where possible. QAT is commonly used with CNNs and yields higher accuracy than PTQ.\n", "\n", "\n", "```{rubric} Diagram\n", "```\n", "\n", "```\n", "# original model\n", "# all tensors and computations are in floating point\n", "previous_layer_fp32 -- linear_fp32 -- activation_fp32 -- next_layer_fp32\n", "                    /\n", "    linear_weight_fp32\n", "\n", "# model with fake_quants for modeling quantization numerics during training\n", 
"previous_layer_fp32 -- fq -- linear_fp32 -- activation_fp32 -- fq -- next_layer_fp32\n", " /\n", " linear_weight_fp32 -- fq\n", "\n", "# 量化模型\n", "# weights 和 activations 在 int8 上\n", "previous_layer_int8 -- linear_with_activation_int8 -- next_layer_int8\n", " /\n", " linear_weight_int8\n", "```\n", "\n", "定义比 `M` 稍微复杂一点的浮点模块:" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "class M2(M):\n", " def __init__(self):\n", " super().__init__()\n", " self.bn = torch.nn.BatchNorm2d(1)\n", "\n", " def _forward_impl(self, x: Tensor) -> Tensor:\n", " '''提供便捷函数'''\n", " x = self.conv(x)\n", " x = self.bn(x)\n", " x = self.relu(x)\n", " return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "同样需要定义可量化模块:" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "class QM2(M2, QM):\n", " def __init__(self):\n", " super().__init__()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "创建浮点模型实例:" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "QM2(\n", " (conv): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))\n", " (relu): ReLU()\n", " (quant): QuantStub()\n", " (dequant): DeQuantStub()\n", " (bn): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", ")" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 创建模型实例\n", "model_fp32 = QM2()\n", "model_fp32" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "模型必须设置为训练模式,以便 QAT 可用:" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "model_fp32.train();" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "添加量化配置(与 PTQ 相同相似):" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "model_fp32.qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "```{rubric} 融合 QAT 模块\n", "```\n", "\n", "QAT 的模块融合与 PTQ 相同相似:" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "from torch.ao.quantization import fuse_modules_qat\n", "\n", "model_fp32_fused = fuse_modules_qat(model_fp32,\n", " [['conv', 'bn', 'relu']])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```{rubric} 准备 QAT 模型\n", "```\n", "\n", "这将在模型中插入观测者和伪量化模块,它们将在校准期间观测权重和激活的张量。" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "model_fp32_prepared = torch.quantization.prepare_qat(model_fp32_fused)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```{rubric} 训练 QAT 模型\n", "```\n", "\n", "```python\n", "# 下文会编写实际的例子,此处没有显示\n", "training_loop(model_fp32_prepared)\n", "```\n", "\n", "将观测到的模型转换为量化模型。需要:\n", "\n", "- 量化权重,计算和存储用于每个激活张量的尺度(scale)和偏差(bias)值,\n", "- 在适当的地方融合模块,并用量化实现替换关键算子。" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "model_fp32_prepared.eval()\n", "model_int8 = torch.quantization.convert(model_fp32_prepared)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "运行模型,相关的计算将在 {data}`torch.qint8` 中发生。" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "res = model_int8(input_fp32)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "要了解更多关于量化意识训练的信息,请参阅 [QAT 教程](https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html)。\n", "\n", "### PTQ/QAT 统一的量化流程\n", "\n", "PTQ 和 QAT 的量化流程十分相似,为了统一接口,可以使用 `torchvision` 提供的函数 {func}`~torchvision.models.quantization.utils._fuse_modules`。\n", "\n", "下面利用函数 {func}`~torchvision.models.quantization.utils._fuse_modules` 可量化模块 `QM2`。" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "from typing import Any\n", "from torch.ao.quantization import fuse_modules, fuse_modules_qat\n", "from torch.ao.quantization import 
get_default_qconfig, get_default_qat_qconfig\n", "from torch.ao.quantization import quantize, quantize_qat\n", "\n", "def _fuse_modules(\n", "    model: nn.Module, modules_to_fuse: list[str] | list[list[str]], is_qat: bool | None, **kwargs: Any\n", "):\n", "    # fall back to the model's current mode when is_qat is not given\n", "    if is_qat is None:\n", "        is_qat = model.training\n", "    method = fuse_modules_qat if is_qat else fuse_modules\n", "    return method(model, modules_to_fuse, **kwargs)\n", "\n", "\n", "class QM3(QM2):\n", "    '''Quantizable model\n", "    Args:\n", "        is_qat: whether to use QAT mode\n", "    '''\n", "    def __init__(self, is_qat: bool | None = None, backend='fbgemm'):\n", "        super().__init__()\n", "        self.is_qat = is_qat\n", "        # choose the mode and the observers (qconfig)\n", "        if is_qat:\n", "            self.train()\n", "            self.qconfig = get_default_qat_qconfig(backend)\n", "        else:\n", "            self.eval()\n", "            self.qconfig = get_default_qconfig(backend)\n", "\n", "    def fuse_model(self) -> nn.Module:\n", "        '''Fuse modules and return the fused model'''\n", "        if self.is_qat:\n", "            modules_to_fuse = ['bn', 'relu']\n", "        else:\n", "            modules_to_fuse = ['conv', 'bn', 'relu']\n", "        return _fuse_modules(self,\n", "                             modules_to_fuse,\n", "                             self.is_qat,\n", "                             inplace=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With the quantizable module `QM3`, switching between PTQ and QAT is straightforward.\n", "\n", "For example, PTQ looks like this:" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "def run_fn(model, num_epochs):\n", "    # feed random inputs through the model (stands in for calibration / training)\n", "    for _ in range(num_epochs):\n", "        input_fp32 = torch.randn(4, 1, 4, 4)\n", "        model(input_fp32)\n", "\n", "num_epochs = 10\n", "ptq_model = QM3(is_qat=False)\n", "model_fused = ptq_model.fuse_model()\n", "quanted_model = quantize(model_fused, run_fn, [num_epochs])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And QAT looks like this:" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "num_epochs = 10\n", "qat_model = QM3(is_qat=True)\n", "model_fused = qat_model.fuse_model()\n", "quanted_model = quantize_qat(model_fused, run_fn, [num_epochs])" ] }, { "cell_type": "markdown", "metadata": {}, 
"source": [ "### PTQ/QAT 量化策略\n", "\n", "对于通用量化技术,需要了解:\n", "\n", "1. 将任何需要输出再量化请求的运算(因此有额外的参数)从函数形式转换为模块形式(例如,使用 {class}`torch.nn.ReLU` 而不是 {func}`torch.nn.functional.relu`)。\n", "1. 通过在子模块上指定 `.qconfig` 属性或指定 `qconfig_dict` 来指定模型的哪些部分需要量化。例如,设置 `model.conv1.qconfig = None` 表示 `model.conv1` 层不量化,设置 `model.linear1.qconfig = custom_qconfig` 表示 `model.linear1` 将使用 `custom_qconfig` 而不是全局 `qconfig`。\n", "\n", "对于量化激活的静态量化技术(即对模型的权重和激活均进行量化,包括 PTQ 和 QAT),用户还需要做以下工作:\n", "\n", "1. 指定量化和反量化激活的位置。这是使用 {class}`~torch.ao.quantization.stubs.QuantStub` 和 {class}`~torch.ao.quantization.stubs.DeQuantStub` 模块完成的。\n", "1. 使用 {class}`~torch.nn.quantized.FloatFunctional` 将需要对量化进行特殊处理的张量运算封装到模块中。例如像 {func}`add` 和 {func}`cat` 这样需要特殊处理来确定输出量化参数的运算。\n", "1. 融合模块:将运算/模块组合成单个模块,获得更高的 accuracy 和性能。这是使用 {func}`~torch.ao.quantization.fuse_modules.fuse_modules` API 完成的,该 API 接受要融合的模块列表。目前支持以下融合:`[Conv, Relu]`、 `[Conv, BatchNorm]`、 `[Conv, BatchNorm, Relu]` 和 `[Linear, Relu]`。\n", "\n", "示例:\n", "\n", "```{figure} images/resnet.png\n", ":align: center\n", ":class: w3-border\n", "\n", "倒置残差块的转换前后对比\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## PTQ 和 QAT 实战\n", "\n", "\n", "```{rubric} 模型对比\n", "```\n", "\n", "类型|大小(MB)|accuracy($\\%$)\n", ":-|:-|:-\n", "浮点|9.188|94.91\n", "浮点融合|8.924|94.91\n", "QAT|2.657|94.41\n", "\n", "```{rubric} 不同 QConfig 的静态 PTQ 模型\n", "```\n", "\n", "accuracy($\\%$)|激活|权重|\n", ":-|:-|:-\n", "|51.11|{data}`~torch.ao.quantization.observer.MinMaxObserver`.`with_args(quant_min=0, quant_max=127)`|{data}`~torch.ao.quantization.observer.MinMaxObserver`.`with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)`\n", "80.42|{data}`~torch.ao.quantization.observer.HistogramObserver`.`with_args(quant_min=0, quant_max=127)`|{data}`~torch.ao.quantization.observer.PerChannelMinMaxObserver`.`with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric)`\n", "\n", "为了提供一致的量化工具接口,我们使用 Python 包 `torchq`。\n", "\n", "本地载入临时 `torchq` 包:" ] }, { 
"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from mod import torchq" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```{tip}\n", "本文使用 `torchq` 的 `'0.0.1-alpha'` 版本。\n", "```\n", "\n", "更方便的是:使用 `pip` 安装:\n", "\n", "```shell\n", "pip install torchq==0.0.1-alpha\n", "```\n", "\n", "接着,便可以直接导入:\n", "\n", "```python\n", "import torchq\n", "```\n", "\n", "```{tip}\n", "本文使用 `torchq` 的 `'0.0.1-alpha'` 版本。\n", "```\n", "\n", "可以看出 PTQ 和 QAT 需要用户自定义的内容主要集中在: **模块融合** 和 **算子替换**。\n", "\n", "{func}`~torchvision.models.quantization.utils._fuse_modules` 提供了 {func}`~torch.ao.quantization.fuse_modules.fuse_modules` 和 {func}`~torch.ao.quantization.fuse_modules.fuse_modules_qat` 的统一接口。下面以 MobileNetV2 为例,简述如何使用 {func}`~torchvision.models.quantization.utils._fuse_modules` 函数和 {class}`~torch.nn.quantized.FloatFunctional` 类定制可量化的模块。" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "'''参考 torchvision/models/quantization/mobilenetv2.py\n", "'''\n", "from typing import Any\n", "from torch import Tensor\n", "from torch import nn\n", "\n", "from torchvision._internally_replaced_utils import load_state_dict_from_url\n", "from torchvision.ops.misc import ConvNormActivation\n", "from torchvision.models.quantization.utils import _fuse_modules, _replace_relu, quantize_model\n", "from torch.ao.quantization import QuantStub, DeQuantStub\n", "from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, model_urls\n", "\n", "quant_model_urls = {\n", " \"mobilenet_v2_qnnpack\": \"https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth\"\n", "}\n", "\n", "\n", "class QuantizableInvertedResidual(InvertedResidual):\n", " def __init__(self, *args: Any, **kwargs: Any) -> None:\n", " super().__init__(*args, **kwargs)\n", " self.skip_add = nn.quantized.FloatFunctional()\n", "\n", " def forward(self, x: Tensor) -> Tensor:\n", " if self.use_res_connect:\n", " return 
self.skip_add.add(x, self.conv(x))\n", "        else:\n", "            return self.conv(x)\n", "\n", "    def fuse_model(self, is_qat: bool | None = None) -> None:\n", "        # fuse each bare Conv2d with the BatchNorm2d that follows it\n", "        for idx in range(len(self.conv)):\n", "            if type(self.conv[idx]) is nn.Conv2d:\n", "                _fuse_modules(self.conv,\n", "                              [str(idx),\n", "                               str(idx + 1)],\n", "                              is_qat,\n", "                              inplace=True)\n", "\n", "\n", "class QuantizableMobileNetV2(MobileNetV2):\n", "    def __init__(self, *args: Any, **kwargs: Any) -> None:\n", "        \"\"\"\n", "        MobileNet V2 main class\n", "\n", "        Args:\n", "            Inherits the arguments of the floating-point MobileNetV2\n", "        \"\"\"\n", "        super().__init__(*args, **kwargs)\n", "        self.quant = QuantStub()\n", "        self.dequant = DeQuantStub()\n", "\n", "    def forward(self, x: Tensor) -> Tensor:\n", "        x = self.quant(x)\n", "        x = self._forward_impl(x)\n", "        x = self.dequant(x)\n", "        return x\n", "\n", "    def fuse_model(self, is_qat: bool | None=None) -> None:\n", "        for m in self.modules():\n", "            if type(m) is ConvNormActivation:\n", "                _fuse_modules(m, [\"0\", \"1\", \"2\"], is_qat, inplace=True)\n", "            if type(m) is QuantizableInvertedResidual:\n", "                m.fuse_model(is_qat)\n", "\n", "\n", "def mobilenet_v2(\n", "    pretrained: bool = False,\n", "    progress: bool = True,\n", "    quantize: bool = False,\n", "    **kwargs: Any,\n", ") -> QuantizableMobileNetV2:\n", "    \"\"\"\n", "    Constructs a MobileNetV2 architecture from\n", "    `MobileNetV2: Inverted Residuals and Linear Bottlenecks <https://arxiv.org/abs/1801.04381>`_.\n", "\n", "    Note that quantize = True returns a quantized model with 8-bit weights. Quantized models only\n", "    support inference and run on CPUs. GPU inference is not yet supported.\n", "\n", "    Args:\n", "        pretrained (bool): if True, returns a model pre-trained on ImageNet.\n", "        progress (bool): if True, displays a progress bar of the download to stderr\n", "        quantize (bool): if True, returns a quantized model, otherwise a floating-point model\n", "    \"\"\"\n", "    model = QuantizableMobileNetV2(block=QuantizableInvertedResidual, **kwargs)\n", "    _replace_relu(model)\n", "\n", "    if quantize:\n", "        # TODO use pretrained as a string to specify the backend\n", "        backend = \"qnnpack\"\n", "        quantize_model(model, backend)\n", "    else:\n", "        assert pretrained in [True, False]\n", "\n", "    if pretrained:\n", "        if quantize:\n", "            model_url = 
quant_model_urls[\"mobilenet_v2_\" + backend]\n", "        else:\n", "            model_url = model_urls[\"mobilenet_v2\"]\n", "\n", "        state_dict = load_state_dict_from_url(model_url, progress=progress)\n", "        model.load_state_dict(state_dict)\n", "    return model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Preparation\n", "\n", "The following uses CIFAR-10 as an example to illustrate the PTQ/QAT workflow.\n", "\n", "Define a few [helper functions](https://github.com/pytorch/examples/blob/master/imagenet/main.py) to help evaluate the model." ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "from torchq.helper import evaluate, print_size_of_model, load_model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Set the hyperparameters:" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [], "source": [ "saved_model_dir = 'models/'\n", "float_model_file = 'mobilenet_pretrained_float.pth'\n", "scripted_float_model_file = 'mobilenet_float_scripted.pth'\n", "scripted_ptq_model_file = 'mobilenet_ptq_scripted.pth'\n", "scripted_quantized_model_file = 'mobilenet_quantization_scripted_quantized.pth'\n", "scripted_qat_model_file = 'mobilenet_qat_scripted_quantized.pth'\n", "\n", "learning_rate = 5e-5\n", "num_epochs = 30\n", "batch_size = 16\n", "num_classes = 10\n", "\n", "# loss used as the evaluation criterion\n", "criterion = nn.CrossEntropyLoss()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Define the datasets and data loaders:" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Files already downloaded and verified\n", "Files already downloaded and verified\n" ] } ], "source": [ "from torchq.xinet import CV\n", "\n", "# resize CIFAR-10 images to 224 to match the ImageNet input size\n", "train_iter, test_iter = CV.load_data_cifar10(batch_size=batch_size,\n", "                                             resize=224)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Check the number of batches in each dataset:" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", 
"text": [ "训练、测试批次分别为: 3125 625\n" ] } ], "source": [ "print('训练、测试批次分别为:',\n", " len(train_iter), len(test_iter))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "获取训练和测试数据集的大小:" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(50000, 10000)" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "num_train = sum(len(ys) for _, ys in train_iter)\n", "num_eval = sum(len(ys) for _, ys in test_iter)\n", "num_train, num_eval" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 微调浮点模型\n", "\n", "配置浮点模型:" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [], "source": [ "#from torchvision.models.quantization import mobilenet_v2\n", "\n", "# 定义模型\n", "def create_model(quantize=False,\n", " num_classes=10,\n", " pretrained=False):\n", " float_model = mobilenet_v2(pretrained=pretrained,\n", " quantize=quantize)\n", " # 匹配 ``num_classes``\n", " float_model.classifier[1] = nn.Linear(float_model.last_channel,\n", " num_classes)\n", " return float_model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "定义模型:" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [], "source": [ "float_model = create_model(pretrained=True,\n", " quantize=False,\n", " num_classes=num_classes)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "定义微调的函数 {class}`torchq.xinet.CV`.{func}`train_fine_tuning` 用于模型。\n", "\n", "微调浮点模型:" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss 0.012, train acc 0.996, test acc 0.949\n", "276.9 examples/sec on cuda:2\n" ] }, { "data": { "image/svg+xml": "\n\n\n \n \n \n \n 2022-03-29T17:34:51.353391\n image/svg+xml\n \n \n Matplotlib v3.4.0, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n 
", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "CV.train_fine_tuning(float_model, train_iter, test_iter,\n", " learning_rate=learning_rate,\n", " num_epochs=num_epochs,\n", " device='cuda:2',\n", " param_group=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "保存模型:" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [], "source": [ "torch.save(float_model.state_dict(), saved_model_dir + float_model_file)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 配置可量化模型\n", "\n", "加载浮点模型:" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [], "source": [ "float_model = create_model(quantize=False,\n", " num_classes=num_classes)\n", "float_model = load_model(float_model, saved_model_dir + float_model_file)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "查看浮点模型的信息:" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [], "source": [ "def print_info(model,\n", " model_type='浮点模型',\n", " test_iter=test_iter,\n", " criterion=criterion, num_eval=num_eval):\n", " '''打印信息'''\n", " print_size_of_model(model)\n", " top1, top5 = evaluate(model, criterion, test_iter)\n", " print(f'\\n{model_type}:\\n\\t'\n", " f'在 {num_eval} 张图片上评估 accuracy 为: {top1.avg:2.5f}')" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "模型大小:9.187789 MB\n", "\n", "浮点模型:\n", "\t在 10000 张图片上评估 accuracy 为: 94.91000\n" ] } ], "source": [ "print_info(float_model, model_type='浮点模型')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "可以先查看融合前的 inverted residual 块:" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Sequential(\n", " (0): ConvNormActivation(\n", " (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)\n", " (1): BatchNorm2d(32, 
eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", "    (2): ReLU()\n", "  )\n", "  (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", "  (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", ")" ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "float_model.features[1].conv" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Fuse the modules:" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [], "source": [ "float_model.fuse_model(is_qat=None)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Look at the inverted residual block after fusion:" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Sequential(\n", "  (0): ConvNormActivation(\n", "    (0): ConvReLU2d(\n", "      (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)\n", "      (1): ReLU()\n", "    )\n", "    (1): Identity()\n", "    (2): Identity()\n", "  )\n", "  (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))\n", "  (2): Identity()\n", ")" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "float_model.features[1].conv" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To obtain a \"baseline\" accuracy, evaluate the non-quantized model with fused modules:" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "baseline 模型大小\n", "模型大小:8.923757 MB\n" ] } ], "source": [ "model_type = '融合后的浮点模型'  # label: fused floating-point model\n", "print(\"baseline 模型大小\")  # label: baseline model size\n", "print_size_of_model(float_model)\n", "\n", "top1, top5 = evaluate(float_model, criterion, test_iter)" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "融合后的浮点模型:\n", "\t在 10000 张图片上评估 accuracy 为: 94.91\n" ] } ], "source": [ "from torch import jit\n", "print(f'\\n{model_type}:\\n\\t在 {num_eval} 张图片上评估 accuracy 为: 
{top1.avg:2.2f}')\n", "# Save the scripted baseline model\n", "jit.save(jit.script(float_model), saved_model_dir + scripted_float_model_file)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is the baseline for the comparisons that follow. Next, try different quantization methods.\n", "\n", "### PTQ in practice" ] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [], "source": [ "# Load the model\n", "myModel = create_model(pretrained=False,\n", "                       quantize=False,\n", "                       num_classes=num_classes)\n", "myModel = load_model(myModel,\n", "                     saved_model_dir + float_model_file)\n", "myModel.eval()\n", "\n", "# Fuse the modules\n", "myModel.fuse_model()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Specify the quantization configuration (start with simple min/max range estimation and per-tensor quantization of the weights):" ] }, { "cell_type": "code", "execution_count": 62, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, quant_min=0, quant_max=127){}, weight=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric){})" ] }, "execution_count": 62, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from torch.ao.quantization.qconfig import default_qconfig\n", "\n", "myModel.qconfig = default_qconfig\n", "myModel.qconfig" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Prepare the model for calibration:" ] }, { "cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "PTQ 准备:插入观测者\n", "\n", " 查看观测者插入后的 inverted residual \n", "\n", " Sequential(\n", " (0): ConvNormActivation(\n", " (0): ConvReLU2d(\n", " (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)\n", " (1): ReLU()\n", " (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)\n", " )\n", " (1): Identity()\n", " (2): Identity()\n", " )\n", " (1): Conv2d(\n", " 32, 16, kernel_size=(1, 1), stride=(1, 1)\n", " (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)\n", " )\n", " (2): Identity()\n", ")\n" ] } ], "source": [ "from torch.ao.quantization.quantize import prepare\n", "\n", 
"print('PTQ 准备:插入观测者')\n", "prepare(myModel, inplace=True)\n", "print('\\n 查看观测者插入后的 inverted residual \\n\\n',\n", "      myModel.features[1].conv)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Calibrate with the dataset:" ] }, { "cell_type": "code", "execution_count": 65, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "PTQ:校准完成!\n" ] } ], "source": [ "num_calibration_batches = 200 # calibrate on a subset of the training set\n", "evaluate(myModel, criterion, train_iter, neval_batches=num_calibration_batches)\n", "print('\\nPTQ:校准完成!')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Convert to a quantized model:" ] }, { "cell_type": "code", "execution_count": 66, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "PTQ:转换完成!\n" ] } ], "source": [ "from torch.ao.quantization.quantize import convert\n", "\n", "convert(myModel, inplace=True)\n", "print('PTQ:转换完成!')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "After fusion and quantization, look at the inverted residual block again:" ] }, { "cell_type": "code", "execution_count": 67, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Sequential(\n", " (0): ConvNormActivation(\n", " (0): QuantizedConvReLU2d(32, 32, kernel_size=(3, 3), stride=(1, 1), scale=0.1370350867509842, zero_point=0, padding=(1, 1), groups=32)\n", " (1): Identity()\n", " (2): Identity()\n", " )\n", " (1): QuantizedConv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), scale=0.20581312477588654, zero_point=69)\n", " (2): Identity()\n", ")" ] }, "execution_count": 67, "metadata": {}, "output_type": "execute_result" } ], "source": [ "myModel.features[1].conv" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Size of the quantized model:" ] }, { "cell_type": "code", "execution_count": 68, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "模型大小:2.356113 MB\n" ] } ], "source": [ "print_size_of_model(myModel)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Evaluate:" ] }, { "cell_type": "code", "execution_count": 
69, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "PTQ 模型:\n", "\t在 10000 张图片上评估 accuracy 为: 51.11\n" ] } ], "source": [ "model_type = 'PTQ 模型'\n", "top1, top5 = evaluate(myModel, criterion, test_iter)\n", "print(f'\\n{model_type}:\\n\\t在 {num_eval} 张图片上评估 accuracy 为: {top1.avg:2.2f}')\n", "# jit.save(jit.script(myModel), saved_model_dir + scripted_ptq_model_file)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With a simple min/max observer determining the quantization parameters, the model shrinks from 8.92 MB to about 2.36 MB, close to a 4x reduction. Accuracy, however, drops to $51.11\\%$.\n", "\n", "Accuracy can be improved significantly by using a different quantization configuration. Repeat the same exercise with the recommended configuration for quantizing on x86 architectures (the `fbgemm` backend). This configuration does the following:\n", "\n", "- Quantizes the weights on a per-channel basis.\n", "- Uses a histogram observer, which collects a histogram of the activations and then picks the quantization parameters in an optimal manner." ] }, { "cell_type": "code", "execution_count": 70, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.HistogramObserver'>, reduce_range=True){}, weight=functools.partial(<class 'torch.ao.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric){})" ] }, "execution_count": 70, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from torch.ao.quantization import get_default_qconfig\n", "\n", "per_channel_quantized_model = create_model(quantize=False,\n", "                                           num_classes=num_classes)\n", "per_channel_quantized_model = load_model(per_channel_quantized_model,\n", "                                         saved_model_dir + float_model_file)\n", "per_channel_quantized_model.eval()\n", "per_channel_quantized_model.fuse_model()\n", "per_channel_quantized_model.qconfig = get_default_qconfig('fbgemm')\n", "per_channel_quantized_model.qconfig" ] }, { "cell_type": "code", "execution_count": 71, "metadata": {}, "outputs": [], "source": [ "num_calibration_batches = 200 # calibrate on only 200 batches\n", "prepare(per_channel_quantized_model, inplace=True)\n", "evaluate(per_channel_quantized_model, criterion,\n", "         train_iter, num_calibration_batches)\n", "\n", "model_type = 'PTQ 模型(直方图观测器)'\n", "convert(per_channel_quantized_model, inplace=True)\n", "top1, top5 = evaluate(per_channel_quantized_model, criterion, test_iter)" ] }, { "cell_type": "code", "execution_count": 72, "metadata": {}, "outputs": [ { "name": 
"stdout", "output_type": "stream", "text": [ "\n", "PTQ 模型(直方图观测器):\n", "\t在 10000 张图片上评估 accuracy 为: 80.42\n" ] } ], "source": [ "print(f'\\n{model_type}:\\n\\t在 {num_eval} 张图片上评估 accuracy 为: {top1.avg:2.2f}')\n", "jit.save(jit.script(per_channel_quantized_model),\n", "         saved_model_dir + scripted_quantized_model_file)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Changing only the quantization configuration raises the accuracy to $80.42\\%$. Even so, that is still about 14.5 percentage points below the $94.91\\%$ baseline.\n", "\n", "### QAT in practice\n", "\n", "With QAT, all weights and activations are “fake quantized” during both the forward and backward passes of training: floating-point values are rounded to mimic int8 values, but all computations are still carried out in floating point. Every weight adjustment during training is therefore made while “aware” that the model will ultimately be quantized, so after quantization this method usually yields higher accuracy than dynamic quantization or post-training static quantization.\n", "\n", "The overall workflow for QAT is very similar to the one above:\n", "\n", "- The same model can be used as before: no additional preparation is needed for quantization-aware training.\n", "- A `qconfig` specifies which kind of fake quantization to insert after the weights and activations, instead of specifying observers." ] }, { "cell_type": "code", "execution_count": 74, "metadata": {}, "outputs": [], "source": [ "from torch.ao.quantization import get_default_qat_qconfig\n", "\n", "\n", "def create_qat_model(num_classes,\n", "                     model_path,\n", "                     quantize=False,\n", "                     backend='fbgemm'):\n", "    qat_model = create_model(quantize=quantize,\n", "                             num_classes=num_classes)\n", "    qat_model = load_model(qat_model, model_path)\n", "    qat_model.fuse_model()\n", "    qat_model.qconfig = get_default_qat_qconfig(backend=backend)\n", "    return qat_model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, `prepare_qat` inserts the fake-quantization modules, preparing the model for quantization-aware training:" ] }, { "cell_type": "code", "execution_count": 75, "metadata": {}, "outputs": [], "source": [ "from torch.ao.quantization.quantize import prepare_qat\n", "\n", "model_path = saved_model_dir + float_model_file\n", "qat_model = create_qat_model(num_classes, model_path)\n", "qat_model = prepare_qat(qat_model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The inverted residual block after QAT preparation; note the fake-quantization modules:" ] }, { "cell_type": "code", "execution_count": 76, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Sequential(\n", " (0): ConvNormActivation(\n", " (0): ConvBnReLU2d(\n", " 32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, 
bias=False\n", " (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (weight_fake_quant): FusedMovingAvgObsFakeQuantize(\n", " fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.qint8, quant_min=-128, quant_max=127, qscheme=torch.per_channel_symmetric, reduce_range=False\n", " (activation_post_process): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))\n", " )\n", " (activation_post_process): FusedMovingAvgObsFakeQuantize(\n", " fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.quint8, quant_min=0, quant_max=127, qscheme=torch.per_tensor_affine, reduce_range=True\n", " (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)\n", " )\n", " )\n", " (1): Identity()\n", " (2): Identity()\n", " )\n", " (1): ConvBn2d(\n", " 32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False\n", " (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (weight_fake_quant): FusedMovingAvgObsFakeQuantize(\n", " fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.qint8, quant_min=-128, quant_max=127, qscheme=torch.per_channel_symmetric, reduce_range=False\n", " (activation_post_process): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))\n", " )\n", " (activation_post_process): FusedMovingAvgObsFakeQuantize(\n", " fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.quint8, quant_min=0, quant_max=127, qscheme=torch.per_tensor_affine, reduce_range=True\n", " (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)\n", " )\n", " )\n", " (2): Identity()\n", ")" ] }, 
"execution_count": 76, "metadata": {}, "output_type": "execute_result" } ], "source": [ "qat_model.features[1].conv" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Training a quantized model to high accuracy requires accurate modeling of the numerics used at inference time. For quantization-aware training, the training loop is therefore modified as follows:\n", "\n", "- Switch batch normalization to use the running mean and variance towards the end of training, to better match the inference numerics.\n", "- Freeze the quantizer parameters (scale and zero point) and fine-tune the weights." ] }, { "cell_type": "code", "execution_count": 77, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss 0.013, train acc 0.996, test acc 0.948\n", "55.3 examples/sec on cuda:2\n" ] }, { "data": { "image/svg+xml": "[Figure: training curves (loss, train acc, test acc); Matplotlib v3.4.0, 2022-03-30T04:56:14.107754]", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "CV.train_fine_tuning(qat_model,\n", "                     train_iter,\n", "                     test_iter,\n", "                     learning_rate=learning_rate,\n", "                     num_epochs=30,\n", "                     device='cuda:2',\n", "                     param_group=True,\n", "                     is_freeze=False,\n", "                     is_quantized_acc=False,\n", "                     need_qconfig=False,\n", "                     ylim=[0.8, 1])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```{note}\n", "The loss curve is shifted up by 0.8 here for a better visual effect.\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Quantized models currently run only on CPU, so first move the model to the CPU and then convert it to the quantized version:" ] }, { "cell_type": "code", "execution_count": 78, "metadata": {}, "outputs": [], "source": [ "convert(qat_model.cpu().eval(), inplace=True)\n", "qat_model.eval();" ] }, { "cell_type": "code", "execution_count": 79, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "模型大小:2.656573 MB\n", "\n", "QAT 模型:\n", "\t在 10000 张图片上评估 accuracy 为: 94.41000\n" ] } ], "source": [ "print_info(qat_model,'QAT 模型')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Quantization-aware training reaches $94.41\\%$ accuracy on the full test set, close to the floating-point accuracy of $94.91\\%$.\n", "\n", "More on QAT:\n", "\n", "- QAT is a superset of post-training quantization techniques and allows more debugging. For example, one can analyze whether the model's accuracy is limited by weight quantization or by activation quantization.\n", "- The accuracy of the quantized model can also be simulated in floating point, because fake quantization models the numerics of the actual quantized arithmetic.\n", "- Post-training quantization can also be emulated easily.\n", "\n", "Save the QAT model:" ] }, { "cell_type": "code", "execution_count": 80, "metadata": {}, "outputs": [], "source": [ "jit.save(jit.script(qat_model), saved_model_dir + scripted_qat_model_file)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Summary\n", "\n", "{func}`~torch.ao.quantization.quantize.quantize` and {func}`~torch.ao.quantization.quantize.quantize_qat` can also be used to simplify the workflow.\n", "\n", "For example, the QAT workflow can be written as:\n", "\n", "```python\n", "from torch.ao.quantization.quantize import quantize_qat\n", "\n", "model_path = saved_model_dir + float_model_file\n", "qat_model = create_qat_model(num_classes, model_path)\n", "num_epochs = 30\n", "ylim = [0.8, 1]\n", "device = 'cuda:2'\n", "is_freeze = False\n", "is_quantized_acc = False\n", "need_qconfig = True # perform the QAT quantization configuration\n", "param_group = True\n", "\n", "# Positional arguments passed to CV.train_fine_tuning after the model\n", "args = [train_iter,\n", "        test_iter,\n", "        learning_rate,\n", "        num_epochs,\n", "        device,\n", "        is_freeze,\n", "        is_quantized_acc,\n", "        need_qconfig,\n", "        param_group,\n", "        ylim]\n", "\n", "quantized_model = quantize_qat(qat_model, CV.train_fine_tuning, args)\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In short, for both PTQ and QAT, only a module-fusion function and a calibration function need to be customized (for QAT, calibration happens during training; for PTQ, after training)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "interpreter": { "hash": "ccd751c8c176f1a7084878738c6c59984a17d1189ffe2fae146e3d74e2010826" }, "kernelspec": { "display_name": "Python 3.10.4 (conda)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" } }, "nbformat": 4, "nbformat_minor": 2 }