{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 量化菜谱\n", "\n", "{guilabel}`参考`:[量化菜谱](https://pytorch.org/tutorials/recipes/quantization.html)\n", "\n", "量化是一种将模型参数中的 32 位浮点数转换为 8 位整数的技术。通过量化,模型大小和内存占用可以减少到原来大小的 1/4,推理速度可以提高 2-4 倍,而精度保持不变。\n", "\n", "量化模型的方法或工作流程大致有三种:后训练动态量化(post training dynamic quantization)、后训练静态量化(post training dynamic quantization)和量化感知训练(quantization aware training)。但是,如果您想要使用的模型已经有量化版本,您可以直接使用它,而不需要经过上面三个工作流中的任何一个。例如,`torchvision` 库已经包括了模型 MobileNet v2、ResNet 18、ResNet 50、Inception v3、GoogleNet 等的量化版本。因此,我们将最后一种方法作为另一个工作流,尽管很简单。\n", "\n", "```{note}\n", "量化支持可用于[有限的一组算子](https://pytorch.org/blog/introduction-to-quantization-on-pytorch/#device-and-operator-support)。\n", "```\n", "\n", "PyTorch 支持四种量化工作流。\n", "\n", "## 使用预先训练的量化网络\n", "\n", "要得到网络的量化模型,只需做:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from torchvision.models.quantization import resnet18 as qresnet18\n", "\n", "model_quantized = qresnet18(pretrained=True,\n", " quantize=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "为了比较非量化的模型与其量化版本的大小差异:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "原始模型 46.83 MB\n", "量化模型 11.84 MB\n" ] } ], "source": [ "from torchvision.models import resnet18\n", "\n", "model = resnet18(pretrained=True)\n", "\n", "import os\n", "import torch\n", "\n", "def print_model_size(mdl, name=''):\n", " torch.save(mdl.state_dict(), \"tmp.pt\")\n", " size = os.path.getsize(\"tmp.pt\")/1e6\n", " print(f'{name} {size: .2f} MB')\n", " os.remove('tmp.pt')\n", "\n", "print_model_size(model, '原始模型')\n", "print_model_size(model_quantized, '量化模型')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 后训练动态量化\n", "\n", "要应用动态量化,它将模型中的所有权重从 32 位的浮点数转换为 8 位的整数,但在对这些激活执行计算之前不会将激活转换为 int8,只需调用 {func}`~torch.ao.quantization.quantize.quantize_dynamic`:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from torch.quantization.quantize import quantize_dynamic\n", "\n", "model_dynamic_quantized = quantize_dynamic(\n", " model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "其中 `qconfig_spec` 指定要应用量化的 `model` 中的子模块名称列表。\n", "\n", "```{warning}\n", "动态量化的一个重要限制是,虽然它是最简单的工作流程,如果你没有预先训练的量化模型准备使用,它目前在 `qconfig_spec` 中只支持 `nn.Linear` 和 `nn.LSTM`,这意味着你将不得不使用静态量化或量化感知训练,稍后讨论,以量化其他模块,如 `nn.Conv2d`。\n", "```\n", "\n", "使用训练后动态量化的另外三个例子是 [Bert 例子](https://pytorch.org/tutorials/intermediate/dynamic_quantization_bert_tutorial.html)、[LSTM 模型例子](https://pytorch.org/tutorials/advanced/dynamic_quantization_tutorial.html#test-dynamic-quantization) 和另一个 [LSTM demo 例子](https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html#do-the-quantization)。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 后训练静态量化\n", "\n", "该方法将权值和激活提前转换为 8 位整数,这样就不会像动态量化那样在推理过程中对激活进行实时转换,从而显著提高了性能。\n", "\n", "要在模型上应用静态量化,运行以下代码:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/media/workspace/anaconda3/envs/torchx/lib/python3.10/site-packages/torch/ao/quantization/observer.py:177: UserWarning: Please use quant_min and quant_max to specify the range for observers. reduce_range will be deprecated in a future release of PyTorch.\n", " warnings.warn(\n", "/media/workspace/anaconda3/envs/torchx/lib/python3.10/site-packages/torch/ao/quantization/observer.py:1124: UserWarning: must run observer before calling calculate_qparams. Returning default scale and zero point \n", " warnings.warn(\n" ] } ], "source": [ "from torch.quantization.quantize import prepare, convert\n", "from torch.quantization.qconfig import get_default_qconfig\n", "from torchvision.models.quantization import resnet18 as qresnet18\n", "\n", "backend = \"fbgemm\" # 若为 x86,否则为 'qnnpack' \n", "model.qconfig = get_default_qconfig(backend)\n", "torch.backends.quantized.engine = backend\n", "model_static_quantized = prepare(model, inplace=False)\n", "model_static_quantized = convert(model_static_quantized, inplace=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "显示静态量化模型大小:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " 11.93 MB\n" ] } ], "source": [ "print_model_size(model_static_quantized)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[这里](https://pytorch.org/docs/stable/quantization.html#quantization-api-summary)有一个完整的模型定义和静态量化的例子。[这里](https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html)有一个专门的静态量化教程。\n", "\n", "```{note}\n", "为了使模型运行在通常有 ARM 架构的移动设备上,你需要使用 `'qnnpack'` 作为后台;要在 x86 架构的计算机上运行模型,请使用 `'fbgemm'`。\n", "```\n", "\n", "## 量化感知训练(QAT)\n", "\n", "量化感知训练(Quantization aware training)在模型训练过程中对所有的权值和激活量都插入伪量化,比训练后的量化方法具有更高的推理精度。它通常用于 CNN 的模型中。\n", "\n", "要启用量化感知训练的模型,请在模型定义的 `__init__` 方法中定义 `QuantStub` 和 `DeQuantStub`,以将张量从浮点类型转换为量化类型,反之亦然。" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "from torch import nn\n", "from torch.quantization.stubs import QuantStub, DeQuantStub\n", "\n", "class MyModel(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", " self.quant = QuantStub()\n", " self.dequant = DeQuantStub()\n", "\n", " def forward(self, x):\n", " x = self.quant(x)\n", " # 其他运算\n", " x = self.dequant(x)\n", " return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "然后在模型定义的前向方法的开头和结尾分别调用 `x = self.quant(x)` 和 `x = self.dequant(x)`。\n", "\n", "要进行量化感知训练,请使用下面的代码片段:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/media/workspace/anaconda3/envs/torchx/lib/python3.10/site-packages/torch/ao/quantization/utils.py:210: UserWarning: must run observer before calling calculate_qparams. Returning default values.\n", " warnings.warn(\n" ] } ], "source": [ "from torch.quantization.quantize import prepare_qat, convert\n", "from torch.quantization.qconfig import get_default_qat_qconfig\n", "\n", "model = MyModel()\n", "model.qconfig = get_default_qat_qconfig(backend)\n", "model_qat = prepare_qat(model, inplace=False)\n", "# QAT\n", "model_qat = convert(model_qat.eval(), inplace=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "有关量化感知训练的更多详细示例,请参阅[此处](https://pytorch.org/docs/master/quantization.html#quantization-aware-training)和[此处](https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html#quantization-aware-training)。\n", "\n", "在使用上面的步骤之一生成量化模型之后,在模型可以在移动设备上运行之前,它需要进一步转换为 TorchScript 格式,然后针对移动端应用程序进行优化。请参阅[脚本和优化移动端食谱](https://pytorch.org/tutorials/recipes/script_optimized.html)的详细信息。" ] } ], "metadata": { "interpreter": { "hash": "7a45eadec1f9f49b0fdfd1bc7d360ac982412448ce738fa321afc640e3212175" }, "kernelspec": { "display_name": "Python 3.10.4 ('torchx')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }