{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# QAT 的不同训练策略\n", "\n", "载入库:" ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch import nn, jit\n", "from torch.ao.quantization import quantize_qat\n", "from torchvision.models.quantization import mobilenet_v2\n", "\n", "# Load the project-local modules\n", "from mod import torchq\n", "from torchq.xinet import CV\n", "from torchq.helper import evaluate, print_size_of_model, load_model\n", "\n", "\n", "def create_model(num_classes=10,\n", " quantize=False,\n", " pretrained=False):\n", " '''Build a quantizable torchvision MobileNetV2 with a ``num_classes``-way head.\n", "\n", " ``pretrained`` and ``quantize`` are forwarded unchanged to\n", " ``torchvision.models.quantization.mobilenet_v2``; afterwards the last\n", " classifier layer is replaced by ``nn.Linear(last_channel, num_classes)``.\n", " '''\n", " float_model = mobilenet_v2(pretrained=pretrained,\n", " quantize=quantize)\n", " # Replace the final Linear layer so the output size matches ``num_classes``\n", " float_model.classifier[1] = nn.Linear(float_model.last_channel,\n", " num_classes)\n", " return float_model\n", "\n", "\n", "def create_float_model(num_classes,\n", " model_path):\n", " '''Build the float model (``quantize=False``) and load ``model_path`` into it.\n", "\n", " Loading is delegated to the project helper ``load_model``; the model is\n", " created with the default ``pretrained=False``.\n", " '''\n", " model = create_model(quantize=False,\n", " num_classes=num_classes)\n", " model = load_model(model, model_path)\n", " return model\n" ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "|===========================================================================|\n", "| PyTorch CUDA memory summary, device ID 0 |\n", "|---------------------------------------------------------------------------|\n", "| CUDA OOMs: 0 | cudaMalloc retries: 0 |\n", "|===========================================================================|\n", "| Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed |\n", "|---------------------------------------------------------------------------|\n", "| Allocated memory | 0 B | 0 B | 0 B | 0 B |\n", "| from large pool | 0 B | 0 B | 0 B | 0 B |\n", "| from small pool | 0 B | 0 B | 0 B | 0 B |\n", "|---------------------------------------------------------------------------|\n", "| Active memory | 0 B | 0 B | 0 B | 0 B |\n", "| from large pool | 0 B | 0 B | 0 B 
| 0 B |\n", "| from small pool | 0 B | 0 B | 0 B | 0 B |\n", "|---------------------------------------------------------------------------|\n", "| GPU reserved memory | 0 B | 0 B | 0 B | 0 B |\n", "| from large pool | 0 B | 0 B | 0 B | 0 B |\n", "| from small pool | 0 B | 0 B | 0 B | 0 B |\n", "|---------------------------------------------------------------------------|\n", "| Non-releasable memory | 0 B | 0 B | 0 B | 0 B |\n", "| from large pool | 0 B | 0 B | 0 B | 0 B |\n", "| from small pool | 0 B | 0 B | 0 B | 0 B |\n", "|---------------------------------------------------------------------------|\n", "| Allocations | 0 | 0 | 0 | 0 |\n", "| from large pool | 0 | 0 | 0 | 0 |\n", "| from small pool | 0 | 0 | 0 | 0 |\n", "|---------------------------------------------------------------------------|\n", "| Active allocs | 0 | 0 | 0 | 0 |\n", "| from large pool | 0 | 0 | 0 | 0 |\n", "| from small pool | 0 | 0 | 0 | 0 |\n", "|---------------------------------------------------------------------------|\n", "| GPU reserved segments | 0 | 0 | 0 | 0 |\n", "| from large pool | 0 | 0 | 0 | 0 |\n", "| from small pool | 0 | 0 | 0 | 0 |\n", "|---------------------------------------------------------------------------|\n", "| Non-releasable allocs | 0 | 0 | 0 | 0 |\n", "| from large pool | 0 | 0 | 0 | 0 |\n", "| from small pool | 0 | 0 | 0 | 0 |\n", "|---------------------------------------------------------------------------|\n", "| Oversize allocations | 0 | 0 | 0 | 0 |\n", "|---------------------------------------------------------------------------|\n", "| Oversize GPU segments | 0 | 0 | 0 | 0 |\n", "|===========================================================================|\n", "\n" ] } ], "source": [ "torch.cuda.empty_cache() # clear the GPU cache\n", "# NOTE(review): with no argument, memory_summary() reports the current device\n", "# (device ID 0 in the output), while the QAT cells below train on 'cuda:1'.\n", "print(torch.cuda.memory_summary()) # print the GPU memory summary" ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [], "source": [ "# Configure warning filters\n", "import warnings\n", "warnings.filterwarnings(\n", " action='ignore',\n", 
" category=DeprecationWarning,\n", " module='.*'\n", ")\n", "warnings.filterwarnings(\n", " action='ignore',\n", " module='torch.ao.quantization'\n", ")\n", "\n", "\n", "def print_info(model, model_type, num_eval, criterion, data_iter=None):\n", "    '''Print the model size and its top-1 accuracy on an evaluation set.\n", "\n", "    Args:\n", "        model: model to inspect.\n", "        model_type: label printed above the accuracy line.\n", "        num_eval: number of evaluated images; only used in the message.\n", "        criterion: loss function forwarded to ``evaluate``.\n", "        data_iter: evaluation data iterator. Defaults to the notebook-level\n", "            ``test_iter`` so existing calls keep working.\n", "    '''\n", "    if data_iter is None:\n", "        # Backward-compatible fallback to the global test set.\n", "        data_iter = test_iter\n", "    print_size_of_model(model)\n", "    top1, _ = evaluate(model, criterion, data_iter)  # second value (top-5) is not reported\n", "    print(f'\\n{model_type}:\\n\\t'\n", "          f'在 {num_eval} 张图片上评估 accuracy 为: {top1.avg:2.5f}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "超参数设置:" ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [], "source": [ "saved_model_dir = 'models/'\n", "float_model_file = 'mobilenet_pretrained_float.pth'\n", "# scripted_qat_model_file = 'mobilenet_qat_scripted_quantized.pth'\n", "# Hyperparameters\n", "float_model_path = saved_model_dir + float_model_file\n", "batch_size = 16\n", "num_classes = 10\n", "num_epochs = 50\n", "learning_rate = 5e-5\n", "ylim = [0.8, 1]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "加载数据集:" ] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Files already downloaded and verified\n", "Files already downloaded and verified\n" ] } ], "source": [ "# Resize CIFAR-10 images to 224 so they match the ImageNet input size\n", "train_iter, test_iter = CV.load_data_cifar10(batch_size=batch_size,\n", " resize=224)\n", "num_eval = sum(len(ys) for _, ys in test_iter)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "打印浮点模型信息:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "模型大小:9.187789 MB\n", "\n", "浮点模型:\n", "\t在 10000 张图片上评估 accuracy 为: 94.91000\n" ] } ], "source": [ "float_model = create_float_model(num_classes, float_model_path)\n", "model_type = '浮点模型'\n", "criterion = nn.CrossEntropyLoss(reduction=\"none\")\n", "print_info(float_model, model_type, num_eval, criterion, data_iter=test_iter)" ] }, { "cell_type": "markdown", "metadata": {}, 
"source": [ "普通策略:" ] }, { "cell_type": "code", "execution_count": 62, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss 0.013, train acc 0.996, test acc 0.950\n", "352.4 examples/sec on cuda:1\n" ] }, { "data": { "image/svg+xml": "\n\n\n \n \n \n \n 2022-05-07T18:16:24.860564\n image/svg+xml\n \n \n Matplotlib v3.5.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "ename": "", "evalue": "", "output_type": "error", "traceback": [ "\u001b[1;31mThe Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click here for more info. View Jupyter log for further details." ] } ], "source": [ "num_epochs = 1\n", "ylim = [0.85, 1]\n", "# Prefer the second GPU as before; fall back to the first GPU or the CPU\n", "# instead of failing on machines with fewer devices.\n", "if torch.cuda.device_count() > 1:\n", "    device = 'cuda:1'\n", "elif torch.cuda.is_available():\n", "    device = 'cuda:0'\n", "else:\n", "    device = 'cpu'\n", "param_group = True\n", "\n", "# Quantization options\n", "is_freeze = False\n", "is_quantized_acc = False\n", "need_qconfig = True  # do the QAT quantization-config setup\n", "\n", "# Positional arguments for CV.train_fine_tuning. Later cells overwrite\n", "# args[5] (is_freeze) and args[6] (is_quantized_acc), so keep this order.\n", "args = [train_iter,\n", "        test_iter,\n", "        learning_rate,\n", "        num_epochs,\n", "        device,\n", "        is_freeze,\n", "        is_quantized_acc,\n", "        need_qconfig,\n", "        param_group,\n", "        ylim]\n", "\n", "qat_model = create_float_model(num_classes, float_model_path)\n", "qat_model.fuse_model(is_qat=True)  # fuse Conv/BN/ReLU before QAT\n", "quantized_model = quantize_qat(qat_model, CV.train_fine_tuning, args)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Summarize parameter dtypes instead of printing one line per parameter\n", "feature_dtypes = [param.dtype for param in quantized_model.features.parameters()]\n", "{dtype: feature_dtypes.count(dtype) for dtype in set(feature_dtypes)}" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "QuantizableMobileNetV2(\n", " (features): Sequential(\n", " (0): ConvNormActivation(\n", " (0): ConvBnReLU2d(\n", " (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (1): Identity()\n", " (2): Identity()\n", " )\n", " (1): QuantizableInvertedResidual(\n", " (conv): Sequential(\n", " (0): ConvNormActivation(\n", " (0): ConvBnReLU2d(\n", " (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)\n", " (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (1): Identity()\n", " (2): Identity()\n", " )\n", " (1): ConvBn2d(\n", " (0): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (2): Identity()\n", " )\n", " (skip_add): FloatFunctional(\n", " (activation_post_process): Identity()\n", " )\n", " )\n", " (2): QuantizableInvertedResidual(\n", " (conv): Sequential(\n", " (0): ConvNormActivation(\n", " (0): ConvBnReLU2d(\n", " (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (1): Identity()\n", " (2): Identity()\n", " )\n", " (1): ConvNormActivation(\n", " (0): ConvBnReLU2d(\n", " (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=96, bias=False)\n", " (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, 
affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (1): Identity()\n", " (2): Identity()\n", " )\n", " (2): ConvBn2d(\n", " (0): Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (3): Identity()\n", " )\n", " (skip_add): FloatFunctional(\n", " (activation_post_process): Identity()\n", " )\n", " )\n", " (3): QuantizableInvertedResidual(\n", " (conv): Sequential(\n", " (0): ConvNormActivation(\n", " (0): ConvBnReLU2d(\n", " (0): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (1): Identity()\n", " (2): Identity()\n", " )\n", " (1): ConvNormActivation(\n", " (0): ConvBnReLU2d(\n", " (0): Conv2d(144, 144, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=144, bias=False)\n", " (1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (1): Identity()\n", " (2): Identity()\n", " )\n", " (2): ConvBn2d(\n", " (0): Conv2d(144, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (3): Identity()\n", " )\n", " (skip_add): FloatFunctional(\n", " (activation_post_process): Identity()\n", " )\n", " )\n", " (4): QuantizableInvertedResidual(\n", " (conv): Sequential(\n", " (0): ConvNormActivation(\n", " (0): ConvBnReLU2d(\n", " (0): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (1): Identity()\n", " (2): Identity()\n", " )\n", " (1): ConvNormActivation(\n", " (0): ConvBnReLU2d(\n", " (0): Conv2d(144, 144, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=144, bias=False)\n", " (1): 
BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (1): Identity()\n", " (2): Identity()\n", " )\n", " (2): ConvBn2d(\n", " (0): Conv2d(144, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (3): Identity()\n", " )\n", " (skip_add): FloatFunctional(\n", " (activation_post_process): Identity()\n", " )\n", " )\n", " (5): QuantizableInvertedResidual(\n", " (conv): Sequential(\n", " (0): ConvNormActivation(\n", " (0): ConvBnReLU2d(\n", " (0): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (1): Identity()\n", " (2): Identity()\n", " )\n", " (1): ConvNormActivation(\n", " (0): ConvBnReLU2d(\n", " (0): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192, bias=False)\n", " (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (1): Identity()\n", " (2): Identity()\n", " )\n", " (2): ConvBn2d(\n", " (0): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (3): Identity()\n", " )\n", " (skip_add): FloatFunctional(\n", " (activation_post_process): Identity()\n", " )\n", " )\n", " (6): QuantizableInvertedResidual(\n", " (conv): Sequential(\n", " (0): ConvNormActivation(\n", " (0): ConvBnReLU2d(\n", " (0): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (1): Identity()\n", " (2): Identity()\n", " )\n", " (1): ConvNormActivation(\n", " (0): ConvBnReLU2d(\n", " (0): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), 
groups=192, bias=False)\n", " (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (1): Identity()\n", " (2): Identity()\n", " )\n", " (2): ConvBn2d(\n", " (0): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (3): Identity()\n", " )\n", " (skip_add): FloatFunctional(\n", " (activation_post_process): Identity()\n", " )\n", " )\n", " (7): QuantizableInvertedResidual(\n", " (conv): Sequential(\n", " (0): ConvNormActivation(\n", " (0): ConvBnReLU2d(\n", " (0): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (1): Identity()\n", " (2): Identity()\n", " )\n", " (1): ConvNormActivation(\n", " (0): ConvBnReLU2d(\n", " (0): Conv2d(192, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=192, bias=False)\n", " (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (1): Identity()\n", " (2): Identity()\n", " )\n", " (2): ConvBn2d(\n", " (0): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (3): Identity()\n", " )\n", " (skip_add): FloatFunctional(\n", " (activation_post_process): Identity()\n", " )\n", " )\n", " (8): QuantizableInvertedResidual(\n", " (conv): Sequential(\n", " (0): ConvNormActivation(\n", " (0): ConvBnReLU2d(\n", " (0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (1): Identity()\n", " (2): Identity()\n", " )\n", " (1): ConvNormActivation(\n", " (0): ConvBnReLU2d(\n", " (0): Conv2d(384, 384, kernel_size=(3, 3), 
stride=(1, 1), padding=(1, 1), groups=384, bias=False)\n", " (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (1): Identity()\n", " (2): Identity()\n", " )\n", " (2): ConvBn2d(\n", " (0): Conv2d(384, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (3): Identity()\n", " )\n", " (skip_add): FloatFunctional(\n", " (activation_post_process): Identity()\n", " )\n", " )\n", " (9): QuantizableInvertedResidual(\n", " (conv): Sequential(\n", " (0): ConvNormActivation(\n", " (0): ConvBnReLU2d(\n", " (0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (1): Identity()\n", " (2): Identity()\n", " )\n", " (1): ConvNormActivation(\n", " (0): ConvBnReLU2d(\n", " (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384, bias=False)\n", " (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (1): Identity()\n", " (2): Identity()\n", " )\n", " (2): ConvBn2d(\n", " (0): Conv2d(384, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (3): Identity()\n", " )\n", " (skip_add): FloatFunctional(\n", " (activation_post_process): Identity()\n", " )\n", " )\n", " (10): QuantizableInvertedResidual(\n", " (conv): Sequential(\n", " (0): ConvNormActivation(\n", " (0): ConvBnReLU2d(\n", " (0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (1): Identity()\n", " (2): Identity()\n", " )\n", " (1): ConvNormActivation(\n", " (0): ConvBnReLU2d(\n", " (0): 
Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384, bias=False)\n", " (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (1): Identity()\n", " (2): Identity()\n", " )\n", " (2): ConvBn2d(\n", " (0): Conv2d(384, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (3): Identity()\n", " )\n", " (skip_add): FloatFunctional(\n", " (activation_post_process): Identity()\n", " )\n", " )\n", " (11): QuantizableInvertedResidual(\n", " (conv): Sequential(\n", " (0): ConvNormActivation(\n", " (0): ConvBnReLU2d(\n", " (0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (1): Identity()\n", " (2): Identity()\n", " )\n", " (1): ConvNormActivation(\n", " (0): ConvBnReLU2d(\n", " (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384, bias=False)\n", " (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (1): Identity()\n", " (2): Identity()\n", " )\n", " (2): ConvBn2d(\n", " (0): Conv2d(384, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (3): Identity()\n", " )\n", " (skip_add): FloatFunctional(\n", " (activation_post_process): Identity()\n", " )\n", " )\n", " (12): QuantizableInvertedResidual(\n", " (conv): Sequential(\n", " (0): ConvNormActivation(\n", " (0): ConvBnReLU2d(\n", " (0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (1): Identity()\n", " (2): Identity()\n", " )\n", " (1): ConvNormActivation(\n", " 
(0): ConvBnReLU2d(\n", " (0): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576, bias=False)\n", " (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (1): Identity()\n", " (2): Identity()\n", " )\n", " (2): ConvBn2d(\n", " (0): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (3): Identity()\n", " )\n", " (skip_add): FloatFunctional(\n", " (activation_post_process): Identity()\n", " )\n", " )\n", " (13): QuantizableInvertedResidual(\n", " (conv): Sequential(\n", " (0): ConvNormActivation(\n", " (0): ConvBnReLU2d(\n", " (0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (1): Identity()\n", " (2): Identity()\n", " )\n", " (1): ConvNormActivation(\n", " (0): ConvBnReLU2d(\n", " (0): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576, bias=False)\n", " (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (1): Identity()\n", " (2): Identity()\n", " )\n", " (2): ConvBn2d(\n", " (0): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (3): Identity()\n", " )\n", " (skip_add): FloatFunctional(\n", " (activation_post_process): Identity()\n", " )\n", " )\n", " (14): QuantizableInvertedResidual(\n", " (conv): Sequential(\n", " (0): ConvNormActivation(\n", " (0): ConvBnReLU2d(\n", " (0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (1): Identity()\n", " (2): Identity()\n", " )\n", " 
(1): ConvNormActivation(\n", " (0): ConvBnReLU2d(\n", " (0): Conv2d(576, 576, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=576, bias=False)\n", " (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (1): Identity()\n", " (2): Identity()\n", " )\n", " (2): ConvBn2d(\n", " (0): Conv2d(576, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (3): Identity()\n", " )\n", " (skip_add): FloatFunctional(\n", " (activation_post_process): Identity()\n", " )\n", " )\n", " (15): QuantizableInvertedResidual(\n", " (conv): Sequential(\n", " (0): ConvNormActivation(\n", " (0): ConvBnReLU2d(\n", " (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (1): Identity()\n", " (2): Identity()\n", " )\n", " (1): ConvNormActivation(\n", " (0): ConvBnReLU2d(\n", " (0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)\n", " (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (1): Identity()\n", " (2): Identity()\n", " )\n", " (2): ConvBn2d(\n", " (0): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (3): Identity()\n", " )\n", " (skip_add): FloatFunctional(\n", " (activation_post_process): Identity()\n", " )\n", " )\n", " (16): QuantizableInvertedResidual(\n", " (conv): Sequential(\n", " (0): ConvNormActivation(\n", " (0): ConvBnReLU2d(\n", " (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (1): 
Identity()\n", " (2): Identity()\n", " )\n", " (1): ConvNormActivation(\n", " (0): ConvBnReLU2d(\n", " (0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)\n", " (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (1): Identity()\n", " (2): Identity()\n", " )\n", " (2): ConvBn2d(\n", " (0): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (3): Identity()\n", " )\n", " (skip_add): FloatFunctional(\n", " (activation_post_process): Identity()\n", " )\n", " )\n", " (17): QuantizableInvertedResidual(\n", " (conv): Sequential(\n", " (0): ConvNormActivation(\n", " (0): ConvBnReLU2d(\n", " (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (1): Identity()\n", " (2): Identity()\n", " )\n", " (1): ConvNormActivation(\n", " (0): ConvBnReLU2d(\n", " (0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)\n", " (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (1): Identity()\n", " (2): Identity()\n", " )\n", " (2): ConvBn2d(\n", " (0): Conv2d(960, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (3): Identity()\n", " )\n", " (skip_add): FloatFunctional(\n", " (activation_post_process): Identity()\n", " )\n", " )\n", " (18): ConvNormActivation(\n", " (0): ConvBnReLU2d(\n", " (0): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(1280, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (1): Identity()\n", " (2): 
Identity()\n", " )\n", " )\n", " (classifier): Sequential(\n", " (0): Dropout(p=0.2, inplace=False)\n", " (1): Linear(in_features=1280, out_features=10, bias=True)\n", " )\n", " (quant): QuantStub()\n", " (dequant): DeQuantStub()\n", ")" ] }, "execution_count": 56, "metadata": {}, "output_type": "execute_result" } ], "source": [ "quantized_model" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "qmodel_file = 'mobilenet_quantization_quantized.pth'\n", "torch.save(quantized_model.state_dict(), saved_model_dir + qmodel_file)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "冻结前几次训练的量化器以及观测器:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss 0.004, train acc 0.999, test acc 0.951\n", "237.7 examples/sec on cuda:2\n" ] }, { "data": { "image/svg+xml": "\n\n\n \n \n \n \n 2022-03-30T16:36:13.226117\n image/svg+xml\n \n \n Matplotlib v3.4.0, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n 
\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "args[5] = True\n", "args[6] = False\n", "qat_model = create_float_model(num_classes, float_model_path)\n", "quantized_model = quantize_qat(qat_model, CV.train_fine_tuning, args)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "输出量化精度:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss 0.004, train acc 0.999, test acc 0.953\n", "239.0 examples/sec on cuda:2\n" ] }, { "data": { "image/svg+xml": "\n\n\n \n \n \n \n 2022-03-30T19:35:47.050989\n image/svg+xml\n \n \n Matplotlib v3.4.0, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "args[6] = True\n", "args[5] = False\n", "qat_model = create_float_model(num_classes, float_model_path)\n", "quantized_model = quantize_qat(qat_model, CV.train_fine_tuning, args)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "冻结前几次训练的观测器并且生成量化精度:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss 0.004, train acc 0.999, test acc 0.951\n", "238.1 examples/sec on cuda:2\n" ] }, { "data": { "image/svg+xml": "\n\n\n \n \n \n \n 2022-03-30T22:31:41.566133\n image/svg+xml\n \n \n Matplotlib v3.4.0, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "args[5] = True  # is_freeze\n", "args[6] = True  # is_quantized_acc\n", "qat_model = create_float_model(num_classes, float_model_path)\n", "# Fuse Conv/BN/ReLU exactly as the baseline strategy does, so strategies differ only in their flags\n", "qat_model.fuse_model(is_qat=True)\n", "quantized_model = quantize_qat(qat_model, CV.train_fine_tuning, args)" ] } ], "metadata": { "interpreter": { "hash": "ccd751c8c176f1a7084878738c6c59984a17d1189ffe2fae146e3d74e2010826" }, "kernelspec": { "display_name": "Python 3.10.4 (conda)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }