{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 快速入门\n", "\n", "参考:\n", "\n", "1. [量化实践](https://pytorch.org/blog/quantization-in-practice/)\n", "2. [fx graph 模式 POST TRAINING STATIC QUANTIZATION](https://pytorch.org/tutorials/prototype/fx_graph_mode_ptq_static.html)\n", "\n", "本教程介绍基于 {mod}`torch.fx` 在 graph 模式下进行训练后静态量化的步骤。FX Graph 模式量化的优点:可以在模型上完全自动地执行量化,尽管可能需要一些努力使模型与 FX Graph 模式量化兼容(象征性地用 {mod}`torch.fx` 跟踪),将有单独的教程来展示如何使我们想量化的模型的一部分与 FX Graph 模式量化兼容。也有 [FX Graph 模式后训练动态量化](https://pytorch.org/tutorials/prototype/fx_graph_mode_ptq_dynamic.html) 教程。FX Graph 模式 API 如下所示:\n", "\n", "\n", "```python\n", "import torch\n", "from torch.quantization import get_default_qconfig\n", "# Note that this is temporary, \n", "# we'll expose these functions to torch.quantization after official releasee\n", "from torch.quantization.quantize_fx import prepare_fx, convert_fx\n", "float_model.eval()\n", "qconfig = get_default_qconfig(\"fbgemm\")\n", "qconfig_dict = {\"\": qconfig}\n", "def calibrate(model, data_loader):\n", " model.eval()\n", " with torch.no_grad():\n", " for image, target in data_loader:\n", " model(image)\n", "prepared_model = prepare_fx(float_model, qconfig_dict) # fuse modules and insert observers\n", "calibrate(prepared_model, valset) # run calibration on sample data\n", "quantized_model = convert_fx(prepared_model) # convert the calibrated model to a quantized model\n", "```\n", "\n", "## FX Graph 模式量化的动机\n", "\n", "目前 PyTorch 存在 eager 模式量化:[Static Quantization with Eager Mode in PyTorch](https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html)。\n", "\n", "可以看到,该过程涉及到多个手动步骤,包括:\n", "\n", "- 显式地 quantize 和 dequantize activations,当浮点和量化运算混合在模型中时,这是非常耗时的。\n", "- 显式融合模块,这需要手动识别卷积序列、 batch norms 以及 relus 和其他融合模式。\n", "- PyTorch 张量运算需要特殊处理(如 `add`、`concat` 等)。\n", "- 函数式没有 first class 支持(`functional.conv2d` 和 `functional.linear` 不会被量化)\n", "\n", "这些需要的修改大多来自于 Eager 模式量化的潜在限制。Eager 模式在模块级工作,因为它不能检查实际运行的代码(在 `forward` 函数中),量化是通过模块交换实现的,不知道在 Eager 模式下 `forward` 函数中模块是如何使用的。因此,它需要用户手动插入 `QuantStub` 和 `DeQuantStub`,以标记他们想要 quantize 或 dequantize 的点。在图模式中,可以检查在 `forward` 函数中执行的实际代码(例如 `aten` 函数调用),量化是通过模块和 graph 操作实现的。由于图模式对运行的代码具有完全的可见性,能够自动地找出要融合哪些模块,在哪里插入 observer 调用,quantize/dequantize 函数等,能够自动化整个量化过程。\n", "\n", "FX Graph 模式量化的优点是:\n", "\n", "- 简化量化流程,最小化手动步骤\n", "- 开启了进行更高级别优化的可能性,如自动精度选择(automatic precision selection)\n", "\n", "## 定义辅助函数和 Prepare Dataset\n", "\n", "首先进行必要的导入,定义一些辅助函数并准备数据。这些步骤与 PyTorch 中 [使用 Eager 模式的静态量化](https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html) 相同。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "要使用整个 ImageNet 数据集运行本教程中的代码,首先按照 [ImageNet Data](http://www.image-net.org/download) 中的说明下载 ImageNet。将下载的文件解压缩到 `data_path` 文件夹中。\n", "\n", "下载 {mod}`torchvision resnet18 模型 ` 并将其重命名为 `models/resnet18_pretrained_float.pth`。" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from torch_book.data import ImageNet\n", "\n", "\n", "root = \"/media/pc/data/4tb/lxw/datasets/ILSVRC\"\n", "saved_model_dir = 'models/'\n", "\n", "dataset = ImageNet(root)\n", "trainset = dataset.loader(batch_size=30, split=\"train\")\n", "valset = dataset.loader(batch_size=50, split=\"val\")" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import copy\n", "from torchvision import models\n", "\n", "model_name = \"resnet18\"\n", "float_model = getattr(models, model_name)(pretrained=True)\n", "float_model.eval()\n", "\n", "# deepcopy the model since we need to keep the original model around\n", "model_to_quantize = copy.deepcopy(float_model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 评估模式的模型\n", "\n", "对于训练后量化,需要将模型设置为评估模式。" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "model_to_quantize.eval();" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 使用 `qconfig_dict` 指定如何量化模型\n", "\n", "```python\n", "qconfig_dict = {\"\" : default_qconfig}\n", "```\n", "\n", "使用与 Eager 模式量化中相同的 `qconfig`, `qconfig` 只是用于激活和权重的 observers 的命名元组。`qconfig_dict` 是具有以下配置的字典:\n", "\n", "```python\n", "qconfig = {\n", " \" : qconfig_global,\n", " \"sub\" : qconfig_sub,\n", " \"sub.fc\" : qconfig_fc,\n", " \"sub.conv\": None\n", "}\n", "qconfig_dict = {\n", " # qconfig? means either a valid qconfig or None\n", " # optional, global config\n", " \"\": qconfig?,\n", " # optional, used for module and function types\n", " # could also be split into module_types and function_types if we prefer\n", " \"object_type\": [\n", " (torch.nn.Conv2d, qconfig?),\n", " (torch.nn.functional.add, qconfig?),\n", " ...,\n", " ],\n", " # optional, used for module names\n", " \"module_name\": [\n", " (\"foo.bar\", qconfig?)\n", " ...,\n", " ],\n", " # optional, matched in order, first match takes precedence\n", " \"module_name_regex\": [\n", " (\"foo.*bar.*conv[0-9]+\", qconfig?)\n", " ...,\n", " ],\n", " # priority (in increasing order): global, object_type, module_name_regex, module_name\n", " # qconfig == None means fusion and quantization should be skipped for anything\n", " # matching the rule\n", "\n", " # **api subject to change**\n", " # optional: specify the path for standalone modules\n", " # These modules are symbolically traced and quantized as one unit\n", " # so that the call to the submodule appears as one call_module\n", " # node in the forward graph of the GraphModule\n", " \"standalone_module_name\": [\n", " \"submodule.standalone\"\n", " ],\n", " \"standalone_module_class\": [\n", " StandaloneModuleClass\n", " ]\n", "}\n", "```\n", "\n", "可以在 [`qconfig` 文件](https://github.com/pytorch/pytorch/blob/master/torch/quantization/qconfig.py) 中找到与 `qconfig` 相关的实用函数:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from torch.quantization import get_default_qconfig, quantize_jit\n", "\n", "qconfig = get_default_qconfig(\"fbgemm\")\n", "qconfig_dict = {\"\": qconfig}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 为静态后训练量化模型做准备" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "import warnings\n", "from torch.quantization.quantize_fx import prepare_fx\n", "\n", "warnings.filterwarnings('ignore')\n", "\n", "prepared_model = prepare_fx(model_to_quantize, qconfig_dict)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`prepare_fx` 将 BatchNorm 模块折叠到前面的 Conv2d 模块中,并在模型中的适当位置插入 observers。" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "graph():\n", " %x : torch.Tensor [#users=1] = placeholder[target=x]\n", " %activation_post_process_0 : [#users=1] = call_module[target=activation_post_process_0](args = (%x,), kwargs = {})\n", " %conv1 : [#users=1] = call_module[target=conv1](args = (%activation_post_process_0,), kwargs = {})\n", " %activation_post_process_1 : [#users=1] = call_module[target=activation_post_process_1](args = (%conv1,), kwargs = {})\n", " %maxpool : [#users=1] = call_module[target=maxpool](args = (%activation_post_process_1,), kwargs = {})\n", " %activation_post_process_2 : [#users=2] = call_module[target=activation_post_process_2](args = (%maxpool,), kwargs = {})\n", " %layer1_0_conv1 : [#users=1] = call_module[target=layer1.0.conv1](args = (%activation_post_process_2,), kwargs = {})\n", " %activation_post_process_3 : [#users=1] = call_module[target=activation_post_process_3](args = (%layer1_0_conv1,), kwargs = {})\n", " %layer1_0_conv2 : [#users=1] = call_module[target=layer1.0.conv2](args = (%activation_post_process_3,), kwargs = {})\n", " %activation_post_process_4 : [#users=1] = call_module[target=activation_post_process_4](args = (%layer1_0_conv2,), kwargs = {})\n", " %add : [#users=1] = call_function[target=operator.add](args = (%activation_post_process_4, %activation_post_process_2), kwargs = {})\n", " %layer1_0_relu_1 : [#users=1] = call_module[target=layer1.0.relu](args = (%add,), kwargs = {})\n", " %activation_post_process_5 : [#users=2] = call_module[target=activation_post_process_5](args = (%layer1_0_relu_1,), kwargs = {})\n", " %layer1_1_conv1 : [#users=1] = call_module[target=layer1.1.conv1](args = (%activation_post_process_5,), kwargs = {})\n", " %activation_post_process_6 : [#users=1] = call_module[target=activation_post_process_6](args = (%layer1_1_conv1,), kwargs = {})\n", " %layer1_1_conv2 : [#users=1] = call_module[target=layer1.1.conv2](args = (%activation_post_process_6,), kwargs = {})\n", " %activation_post_process_7 : [#users=1] = call_module[target=activation_post_process_7](args = (%layer1_1_conv2,), kwargs = {})\n", " %add_1 : [#users=1] = call_function[target=operator.add](args = (%activation_post_process_7, %activation_post_process_5), kwargs = {})\n", " %layer1_1_relu_1 : [#users=1] = call_module[target=layer1.1.relu](args = (%add_1,), kwargs = {})\n", " %activation_post_process_8 : [#users=2] = call_module[target=activation_post_process_8](args = (%layer1_1_relu_1,), kwargs = {})\n", " %layer2_0_conv1 : [#users=1] = call_module[target=layer2.0.conv1](args = (%activation_post_process_8,), kwargs = {})\n", " %activation_post_process_9 : [#users=1] = call_module[target=activation_post_process_9](args = (%layer2_0_conv1,), kwargs = {})\n", " %layer2_0_conv2 : [#users=1] = call_module[target=layer2.0.conv2](args = (%activation_post_process_9,), kwargs = {})\n", " %activation_post_process_10 : [#users=1] = call_module[target=activation_post_process_10](args = (%layer2_0_conv2,), kwargs = {})\n", " %layer2_0_downsample_0 : [#users=1] = call_module[target=layer2.0.downsample.0](args = (%activation_post_process_8,), kwargs = {})\n", " %activation_post_process_11 : [#users=1] = call_module[target=activation_post_process_11](args = (%layer2_0_downsample_0,), kwargs = {})\n", " %add_2 : [#users=1] = call_function[target=operator.add](args = (%activation_post_process_10, %activation_post_process_11), kwargs = {})\n", " %layer2_0_relu_1 : [#users=1] = call_module[target=layer2.0.relu](args = (%add_2,), kwargs = {})\n", " %activation_post_process_12 : [#users=2] = call_module[target=activation_post_process_12](args = (%layer2_0_relu_1,), kwargs = {})\n", " %layer2_1_conv1 : [#users=1] = call_module[target=layer2.1.conv1](args = (%activation_post_process_12,), kwargs = {})\n", " %activation_post_process_13 : [#users=1] = call_module[target=activation_post_process_13](args = (%layer2_1_conv1,), kwargs = {})\n", " %layer2_1_conv2 : [#users=1] = call_module[target=layer2.1.conv2](args = (%activation_post_process_13,), kwargs = {})\n", " %activation_post_process_14 : [#users=1] = call_module[target=activation_post_process_14](args = (%layer2_1_conv2,), kwargs = {})\n", " %add_3 : [#users=1] = call_function[target=operator.add](args = (%activation_post_process_14, %activation_post_process_12), kwargs = {})\n", " %layer2_1_relu_1 : [#users=1] = call_module[target=layer2.1.relu](args = (%add_3,), kwargs = {})\n", " %activation_post_process_15 : [#users=2] = call_module[target=activation_post_process_15](args = (%layer2_1_relu_1,), kwargs = {})\n", " %layer3_0_conv1 : [#users=1] = call_module[target=layer3.0.conv1](args = (%activation_post_process_15,), kwargs = {})\n", " %activation_post_process_16 : [#users=1] = call_module[target=activation_post_process_16](args = (%layer3_0_conv1,), kwargs = {})\n", " %layer3_0_conv2 : [#users=1] = call_module[target=layer3.0.conv2](args = (%activation_post_process_16,), kwargs = {})\n", " %activation_post_process_17 : [#users=1] = call_module[target=activation_post_process_17](args = (%layer3_0_conv2,), kwargs = {})\n", " %layer3_0_downsample_0 : [#users=1] = call_module[target=layer3.0.downsample.0](args = (%activation_post_process_15,), kwargs = {})\n", " %activation_post_process_18 : [#users=1] = call_module[target=activation_post_process_18](args = (%layer3_0_downsample_0,), kwargs = {})\n", " %add_4 : [#users=1] = call_function[target=operator.add](args = (%activation_post_process_17, %activation_post_process_18), kwargs = {})\n", " %layer3_0_relu_1 : [#users=1] = call_module[target=layer3.0.relu](args = (%add_4,), kwargs = {})\n", " %activation_post_process_19 : [#users=2] = call_module[target=activation_post_process_19](args = (%layer3_0_relu_1,), kwargs = {})\n", " %layer3_1_conv1 : [#users=1] = call_module[target=layer3.1.conv1](args = (%activation_post_process_19,), kwargs = {})\n", " %activation_post_process_20 : [#users=1] = call_module[target=activation_post_process_20](args = (%layer3_1_conv1,), kwargs = {})\n", " %layer3_1_conv2 : [#users=1] = call_module[target=layer3.1.conv2](args = (%activation_post_process_20,), kwargs = {})\n", " %activation_post_process_21 : [#users=1] = call_module[target=activation_post_process_21](args = (%layer3_1_conv2,), kwargs = {})\n", " %add_5 : [#users=1] = call_function[target=operator.add](args = (%activation_post_process_21, %activation_post_process_19), kwargs = {})\n", " %layer3_1_relu_1 : [#users=1] = call_module[target=layer3.1.relu](args = (%add_5,), kwargs = {})\n", " %activation_post_process_22 : [#users=2] = call_module[target=activation_post_process_22](args = (%layer3_1_relu_1,), kwargs = {})\n", " %layer4_0_conv1 : [#users=1] = call_module[target=layer4.0.conv1](args = (%activation_post_process_22,), kwargs = {})\n", " %activation_post_process_23 : [#users=1] = call_module[target=activation_post_process_23](args = (%layer4_0_conv1,), kwargs = {})\n", " %layer4_0_conv2 : [#users=1] = call_module[target=layer4.0.conv2](args = (%activation_post_process_23,), kwargs = {})\n", " %activation_post_process_24 : [#users=1] = call_module[target=activation_post_process_24](args = (%layer4_0_conv2,), kwargs = {})\n", " %layer4_0_downsample_0 : [#users=1] = call_module[target=layer4.0.downsample.0](args = (%activation_post_process_22,), kwargs = {})\n", " %activation_post_process_25 : [#users=1] = call_module[target=activation_post_process_25](args = (%layer4_0_downsample_0,), kwargs = {})\n", " %add_6 : [#users=1] = call_function[target=operator.add](args = (%activation_post_process_24, %activation_post_process_25), kwargs = {})\n", " %layer4_0_relu_1 : [#users=1] = call_module[target=layer4.0.relu](args = (%add_6,), kwargs = {})\n", " %activation_post_process_26 : [#users=2] = call_module[target=activation_post_process_26](args = (%layer4_0_relu_1,), kwargs = {})\n", " %layer4_1_conv1 : [#users=1] = call_module[target=layer4.1.conv1](args = (%activation_post_process_26,), kwargs = {})\n", " %activation_post_process_27 : [#users=1] = call_module[target=activation_post_process_27](args = (%layer4_1_conv1,), kwargs = {})\n", " %layer4_1_conv2 : [#users=1] = call_module[target=layer4.1.conv2](args = (%activation_post_process_27,), kwargs = {})\n", " %activation_post_process_28 : [#users=1] = call_module[target=activation_post_process_28](args = (%layer4_1_conv2,), kwargs = {})\n", " %add_7 : [#users=1] = call_function[target=operator.add](args = (%activation_post_process_28, %activation_post_process_26), kwargs = {})\n", " %layer4_1_relu_1 : [#users=1] = call_module[target=layer4.1.relu](args = (%add_7,), kwargs = {})\n", " %activation_post_process_29 : [#users=1] = call_module[target=activation_post_process_29](args = (%layer4_1_relu_1,), kwargs = {})\n", " %avgpool : [#users=1] = call_module[target=avgpool](args = (%activation_post_process_29,), kwargs = {})\n", " %activation_post_process_30 : [#users=1] = call_module[target=activation_post_process_30](args = (%avgpool,), kwargs = {})\n", " %flatten : [#users=1] = call_function[target=torch.flatten](args = (%activation_post_process_30, 1), kwargs = {})\n", " %activation_post_process_31 : [#users=1] = call_module[target=activation_post_process_31](args = (%flatten,), kwargs = {})\n", " %fc : [#users=1] = call_module[target=fc](args = (%activation_post_process_31,), kwargs = {})\n", " %activation_post_process_32 : [#users=1] = call_module[target=activation_post_process_32](args = (%fc,), kwargs = {})\n", " return activation_post_process_32\n" ] } ], "source": [ "print(prepared_model.graph)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 校准\n", "\n", "将 observers 插入模型后,运行校准函数。校准的目的就是通过一些样本运行代表性的工作负载(例如样本的训练数据集)以便 observers 在模型中能够观测到张量的统计数据,以后使用这些信息来计算量化参数。" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "def calibrate(model, data_loader, samples=500):\n", " model.eval()\n", " with torch.no_grad():\n", " k = 0\n", " for image, _ in data_loader:\n", " if k > samples:\n", " break\n", " model(image)\n", " k += len(image)\n", "\n", "calibrate(prepared_model, trainset) # run calibration on sample data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 将模型转换为量化模型\n", "\n", "`convert_fx` 采用校准模型并产生量化模型。" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "GraphModule(\n", " (conv1): QuantizedConvReLU2d(3, 64, kernel_size=(7, 7), stride=(2, 2), scale=0.03267836198210716, zero_point=0, padding=(3, 3))\n", " (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", " (layer1): Module(\n", " (0): Module(\n", " (conv1): QuantizedConvReLU2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.019193191081285477, zero_point=0, padding=(1, 1))\n", " (conv2): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.051562923938035965, zero_point=75, padding=(1, 1))\n", " )\n", " (1): Module(\n", " (conv1): QuantizedConvReLU2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.019093887880444527, zero_point=0, padding=(1, 1))\n", " (conv2): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.06979087740182877, zero_point=78, padding=(1, 1))\n", " )\n", " )\n", " (layer2): Module(\n", " (0): Module(\n", " (conv1): QuantizedConvReLU2d(64, 128, kernel_size=(3, 3), stride=(2, 2), scale=0.01557458657771349, zero_point=0, padding=(1, 1))\n", " (conv2): QuantizedConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.050476107746362686, zero_point=68, padding=(1, 1))\n", " (downsample): Module(\n", " (0): QuantizedConv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), scale=0.039443813264369965, zero_point=60)\n", " )\n", " )\n", " (1): Module(\n", " (conv1): QuantizedConvReLU2d(128, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.016193654388189316, zero_point=0, padding=(1, 1))\n", " (conv2): QuantizedConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.05214320868253708, zero_point=68, padding=(1, 1))\n", " )\n", " )\n", " (layer3): Module(\n", " (0): Module(\n", " (conv1): QuantizedConvReLU2d(128, 256, kernel_size=(3, 3), stride=(2, 2), scale=0.018163194879889488, zero_point=0, padding=(1, 1))\n", " (conv2): QuantizedConv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), scale=0.05316956341266632, zero_point=51, padding=(1, 1))\n", " (downsample): Module(\n", " (0): QuantizedConv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), scale=0.01836947724223137, zero_point=107)\n", " )\n", " )\n", " (1): Module(\n", " (conv1): QuantizedConvReLU2d(256, 256, kernel_size=(3, 3), stride=(1, 1), scale=0.013543782755732536, zero_point=0, padding=(1, 1))\n", " (conv2): QuantizedConv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), scale=0.048523254692554474, zero_point=71, padding=(1, 1))\n", " )\n", " )\n", " (layer4): Module(\n", " (0): Module(\n", " (conv1): QuantizedConvReLU2d(256, 512, kernel_size=(3, 3), stride=(2, 2), scale=0.014485283754765987, zero_point=0, padding=(1, 1))\n", " (conv2): QuantizedConv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), scale=0.0517863854765892, zero_point=64, padding=(1, 1))\n", " (downsample): Module(\n", " (0): QuantizedConv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), scale=0.04331441596150398, zero_point=58)\n", " )\n", " )\n", " (1): Module(\n", " (conv1): QuantizedConvReLU2d(512, 512, kernel_size=(3, 3), stride=(1, 1), scale=0.021167000755667686, zero_point=0, padding=(1, 1))\n", " (conv2): QuantizedConv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), scale=0.22766999900341034, zero_point=45, padding=(1, 1))\n", " )\n", " )\n", " (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n", " (fc): QuantizedLinear(in_features=512, out_features=1000, scale=0.27226582169532776, zero_point=35, qscheme=torch.per_channel_affine)\n", ")\n", "\n", "\n", "\n", "def forward(self, x : torch.Tensor):\n", " conv1_input_scale_0 = self.conv1_input_scale_0\n", " conv1_input_zero_point_0 = self.conv1_input_zero_point_0\n", " quantize_per_tensor = torch.quantize_per_tensor(x, conv1_input_scale_0, conv1_input_zero_point_0, torch.quint8); x = conv1_input_scale_0 = conv1_input_zero_point_0 = None\n", " conv1 = self.conv1(quantize_per_tensor); quantize_per_tensor = None\n", " maxpool = self.maxpool(conv1); conv1 = None\n", " layer1_0_conv1 = getattr(self.layer1, \"0\").conv1(maxpool)\n", " layer1_0_conv2 = getattr(self.layer1, \"0\").conv2(layer1_0_conv1); layer1_0_conv1 = None\n", " layer1_0_relu_scale_0 = self.layer1_0_relu_scale_0\n", " layer1_0_relu_zero_point_0 = self.layer1_0_relu_zero_point_0\n", " add_relu = torch.ops.quantized.add_relu(layer1_0_conv2, maxpool, layer1_0_relu_scale_0, layer1_0_relu_zero_point_0); layer1_0_conv2 = maxpool = layer1_0_relu_scale_0 = layer1_0_relu_zero_point_0 = None\n", " layer1_1_conv1 = getattr(self.layer1, \"1\").conv1(add_relu)\n", " layer1_1_conv2 = getattr(self.layer1, \"1\").conv2(layer1_1_conv1); layer1_1_conv1 = None\n", " layer1_1_relu_scale_0 = self.layer1_1_relu_scale_0\n", " layer1_1_relu_zero_point_0 = self.layer1_1_relu_zero_point_0\n", " add_relu_1 = torch.ops.quantized.add_relu(layer1_1_conv2, add_relu, layer1_1_relu_scale_0, layer1_1_relu_zero_point_0); layer1_1_conv2 = add_relu = layer1_1_relu_scale_0 = layer1_1_relu_zero_point_0 = None\n", " layer2_0_conv1 = getattr(self.layer2, \"0\").conv1(add_relu_1)\n", " layer2_0_conv2 = getattr(self.layer2, \"0\").conv2(layer2_0_conv1); layer2_0_conv1 = None\n", " layer2_0_downsample_0 = getattr(getattr(self.layer2, \"0\").downsample, \"0\")(add_relu_1); add_relu_1 = None\n", " layer2_0_relu_scale_0 = self.layer2_0_relu_scale_0\n", " layer2_0_relu_zero_point_0 = self.layer2_0_relu_zero_point_0\n", " add_relu_2 = torch.ops.quantized.add_relu(layer2_0_conv2, layer2_0_downsample_0, layer2_0_relu_scale_0, layer2_0_relu_zero_point_0); layer2_0_conv2 = layer2_0_downsample_0 = layer2_0_relu_scale_0 = layer2_0_relu_zero_point_0 = None\n", " layer2_1_conv1 = getattr(self.layer2, \"1\").conv1(add_relu_2)\n", " layer2_1_conv2 = getattr(self.layer2, \"1\").conv2(layer2_1_conv1); layer2_1_conv1 = None\n", " layer2_1_relu_scale_0 = self.layer2_1_relu_scale_0\n", " layer2_1_relu_zero_point_0 = self.layer2_1_relu_zero_point_0\n", " add_relu_3 = torch.ops.quantized.add_relu(layer2_1_conv2, add_relu_2, layer2_1_relu_scale_0, layer2_1_relu_zero_point_0); layer2_1_conv2 = add_relu_2 = layer2_1_relu_scale_0 = layer2_1_relu_zero_point_0 = None\n", " layer3_0_conv1 = getattr(self.layer3, \"0\").conv1(add_relu_3)\n", " layer3_0_conv2 = getattr(self.layer3, \"0\").conv2(layer3_0_conv1); layer3_0_conv1 = None\n", " layer3_0_downsample_0 = getattr(getattr(self.layer3, \"0\").downsample, \"0\")(add_relu_3); add_relu_3 = None\n", " layer3_0_relu_scale_0 = self.layer3_0_relu_scale_0\n", " layer3_0_relu_zero_point_0 = self.layer3_0_relu_zero_point_0\n", " add_relu_4 = torch.ops.quantized.add_relu(layer3_0_conv2, layer3_0_downsample_0, layer3_0_relu_scale_0, layer3_0_relu_zero_point_0); layer3_0_conv2 = layer3_0_downsample_0 = layer3_0_relu_scale_0 = layer3_0_relu_zero_point_0 = None\n", " layer3_1_conv1 = getattr(self.layer3, \"1\").conv1(add_relu_4)\n", " layer3_1_conv2 = getattr(self.layer3, \"1\").conv2(layer3_1_conv1); layer3_1_conv1 = None\n", " layer3_1_relu_scale_0 = self.layer3_1_relu_scale_0\n", " layer3_1_relu_zero_point_0 = self.layer3_1_relu_zero_point_0\n", " add_relu_5 = torch.ops.quantized.add_relu(layer3_1_conv2, add_relu_4, layer3_1_relu_scale_0, layer3_1_relu_zero_point_0); layer3_1_conv2 = add_relu_4 = layer3_1_relu_scale_0 = layer3_1_relu_zero_point_0 = None\n", " layer4_0_conv1 = getattr(self.layer4, \"0\").conv1(add_relu_5)\n", " layer4_0_conv2 = getattr(self.layer4, \"0\").conv2(layer4_0_conv1); layer4_0_conv1 = None\n", " layer4_0_downsample_0 = getattr(getattr(self.layer4, \"0\").downsample, \"0\")(add_relu_5); add_relu_5 = None\n", " layer4_0_relu_scale_0 = self.layer4_0_relu_scale_0\n", " layer4_0_relu_zero_point_0 = self.layer4_0_relu_zero_point_0\n", " add_relu_6 = torch.ops.quantized.add_relu(layer4_0_conv2, layer4_0_downsample_0, layer4_0_relu_scale_0, layer4_0_relu_zero_point_0); layer4_0_conv2 = layer4_0_downsample_0 = layer4_0_relu_scale_0 = layer4_0_relu_zero_point_0 = None\n", " layer4_1_conv1 = getattr(self.layer4, \"1\").conv1(add_relu_6)\n", " layer4_1_conv2 = getattr(self.layer4, \"1\").conv2(layer4_1_conv1); layer4_1_conv1 = None\n", " layer4_1_relu_scale_0 = self.layer4_1_relu_scale_0\n", " layer4_1_relu_zero_point_0 = self.layer4_1_relu_zero_point_0\n", " add_relu_7 = torch.ops.quantized.add_relu(layer4_1_conv2, add_relu_6, layer4_1_relu_scale_0, layer4_1_relu_zero_point_0); layer4_1_conv2 = add_relu_6 = layer4_1_relu_scale_0 = layer4_1_relu_zero_point_0 = None\n", " avgpool = self.avgpool(add_relu_7); add_relu_7 = None\n", " flatten = torch.flatten(avgpool, 1); avgpool = None\n", " fc = self.fc(flatten); flatten = None\n", " dequantize_14 = fc.dequantize(); fc = None\n", " return dequantize_14\n", " \n" ] } ], "source": [ "from torch.quantization.quantize_fx import convert_fx\n", "\n", "quantized_model = convert_fx(prepared_model)\n", "print(quantized_model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 评估\n", "\n", "现在可以打印量化模型的大小和精度。" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Size of model before quantization\n", "模型大小(MB):46.873073 MB\n", "Size of model after quantization\n", "模型大小(MB):11.853109 MB\n" ] } ], "source": [ "from torch_book.contrib.helper import evaluate, print_size_of_model\n", "from torch import nn\n", "\n", "\n", "criterion = nn.CrossEntropyLoss()\n", "\n", "print(\"Size of model before quantization\")\n", "print_size_of_model(float_model)\n", "print(\"Size of model after quantization\")\n", "print_size_of_model(quantized_model)\n", "top1, top5 = evaluate(quantized_model, criterion, valset)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[before serilaization] Evaluation accuracy on test dataset: 69.37, 88.89\n" ] } ], "source": [ "print(f\"[before serilaization] Evaluation accuracy on test dataset: {top1.avg:2.2f}, {top5.avg:2.2f}\")" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "fx_graph_mode_model_file_path = saved_model_dir + f\"{model_name}_fx_graph_mode_quantized.pth\"\n", "\n", "# this does not run due to some erros loading convrelu module:\n", "# ModuleAttributeError: 'ConvReLU2d' object has no attribute '_modules'\n", "# save the whole model directly\n", "# torch.save(quantized_model, fx_graph_mode_model_file_path)\n", "# loaded_quantized_model = torch.load(fx_graph_mode_model_file_path)\n", "\n", "# save with state_dict\n", "# torch.save(quantized_model.state_dict(), fx_graph_mode_model_file_path)\n", "# import copy\n", "# model_to_quantize = copy.deepcopy(float_model)\n", "# prepared_model = prepare_fx(model_to_quantize, {\"\": qconfig})\n", "# loaded_quantized_model = convert_fx(prepared_model)\n", "# loaded_quantized_model.load_state_dict(torch.load(fx_graph_mode_model_file_path))\n", "\n", "# save with script\n", "torch.jit.save(torch.jit.script(quantized_model), fx_graph_mode_model_file_path)\n", "loaded_quantized_model = torch.jit.load(fx_graph_mode_model_file_path)\n", "\n", "top1, top5 = evaluate(loaded_quantized_model, criterion, valset)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[after serialization/deserialization] Evaluation accuracy on test dataset: 69.37, 88.89\n" ] } ], "source": [ "print(f\"[after serialization/deserialization] Evaluation accuracy on test dataset: {top1.avg:2.2f}, {top5.avg:2.2f}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "如果希望获得更好的精度或性能,请尝试更改 `qconfig_dict`。\n", "\n", "## 调试量化模型\n", "\n", "还可以打印量化的 un-quantized conv 的权重来查看区别,首先显式地调用 `fuse` 来融合模型中的 conv 和 bn:注意,`fuse_fx` 只在 `eval` 模式下工作。" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(0.0007, grad_fn=)\n" ] } ], "source": [ "from torch.quantization.quantize_fx import fuse_fx\n", "\n", "fused = fuse_fx(float_model)\n", "\n", "conv1_weight_after_fuse = fused.conv1[0].weight[0]\n", "conv1_weight_after_quant = quantized_model.conv1.weight().dequantize()[0]\n", "\n", "print(torch.max(abs(conv1_weight_after_fuse - conv1_weight_after_quant)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 基线浮点模型和 Eager 模式量化的比较" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Size of baseline model\n", "模型大小(MB):46.874273 MB\n", "Baseline Float Model Evaluation accuracy: 69.76, 89.08\n" ] } ], "source": [ "scripted_float_model_file = \"resnet18_scripted.pth\"\n", "\n", "print(\"Size of baseline model\")\n", "print_size_of_model(float_model)\n", "\n", "top1, top5 = evaluate(float_model, criterion, valset)\n", "print(\"Baseline Float Model Evaluation accuracy: %2.2f, %2.2f\"%(top1.avg, top5.avg))\n", "torch.jit.save(torch.jit.script(float_model), saved_model_dir + scripted_float_model_file)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "在本节中,将量化模型与 FX Graph 模式的量化模型与在 Eager 模式下量化的模型进行比较。FX Graph 模式和 Eager 模式产生的量化模型非常相似,因此期望精度和 speedup 也很相似。" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Size of Fx graph mode quantized model\n", "模型大小(MB):11.855297 MB\n", "FX graph mode quantized model Evaluation accuracy on test dataset: 69.37, 88.89\n", "Size of eager mode quantized model\n", "模型大小(MB):11.850395 MB\n", "eager mode quantized model Evaluation accuracy on test dataset: 69.50, 88.88\n" ] } ], "source": [ "print(\"Size of Fx graph mode quantized model\")\n", "print_size_of_model(quantized_model)\n", "top1, top5 = evaluate(quantized_model, criterion, valset)\n", "print(\"FX graph mode quantized model Evaluation accuracy on test dataset: %2.2f, %2.2f\"%(top1.avg, top5.avg))\n", "\n", "from torchvision.models.quantization.resnet import resnet18\n", "eager_quantized_model = resnet18(pretrained=True, quantize=True).eval()\n", "print(\"Size of eager mode quantized model\")\n", "eager_quantized_model = torch.jit.script(eager_quantized_model)\n", "print_size_of_model(eager_quantized_model)\n", "top1, top5 = evaluate(eager_quantized_model, criterion, valset)\n", "print(\"eager mode quantized model Evaluation accuracy on test dataset: %2.2f, %2.2f\"%(top1.avg, top5.avg))\n", "eager_mode_model_file = \"resnet18_eager_mode_quantized.pth\"\n", "torch.jit.save(eager_quantized_model, saved_model_dir + eager_mode_model_file)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "可以看到 FX Graph 模式和 Eager 模式量化模型的模型大小和精度是非常相似的。\n", "\n", "在 AIBench 中运行模型(单线程)会得到如下结果:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```log\n", "Scripted Float Model:\n", "Self CPU time total: 192.48ms\n", "\n", "Scripted Eager Mode Quantized Model:\n", "Self CPU time total: 50.76ms\n", "\n", "Scripted FX Graph Mode Quantized Model:\n", "Self CPU time total: 50.63ms\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "可以看到,对于 resnet18, FX Graph 模式和 Eager 模式量化模型都比浮点模型获得了相似的速度,大约比浮点模型快 2-4 倍。但是浮点模型上的实际加速可能会因模型、设备、构建、输入批大小、线程等而不同。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "interpreter": { "hash": "7a45eadec1f9f49b0fdfd1bc7d360ac982412448ce738fa321afc640e3212175" }, "kernelspec": { "display_name": "Python 3.10.4 ('torchx')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }