{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# QAT\n", "\n", "载入库:" ] },
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [
"from torch import nn, jit\n",
"from torch.ao.quantization.quantize import convert\n",
"from torchvision.models.quantization import mobilenet_v2\n",
"\n",
"\n",
"def create_model(num_classes=10, quantize=False, pretrained=False):\n",
"    '''Build a MobileNetV2 whose final classifier layer emits ``num_classes`` outputs.\n",
"\n",
"    Args:\n",
"        num_classes: output width of the replacement ``nn.Linear`` layer.\n",
"        quantize: passed through unchanged to ``mobilenet_v2``.\n",
"        pretrained: passed through unchanged to ``mobilenet_v2``.\n",
"\n",
"    Returns:\n",
"        The model returned by ``mobilenet_v2`` with ``classifier[1]`` replaced.\n",
"    '''\n",
"    model = mobilenet_v2(pretrained=pretrained, quantize=quantize)\n",
"    # Size the new head from the backbone's ``last_channel`` attribute.\n",
"    head = nn.Linear(model.last_channel, num_classes)\n",
"    model.classifier[1] = head\n",
"    return model" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
"|===========================================================================|\n",
"| PyTorch CUDA memory summary, device ID 0 |\n",
"|---------------------------------------------------------------------------|\n",
"| CUDA OOMs: 0 | cudaMalloc retries: 0 |\n",
"|===========================================================================|\n",
"| Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed |\n",
"|---------------------------------------------------------------------------|\n",
"| Allocated memory | 0 B | 0 B | 0 B | 0 B |\n",
"| from large pool | 0 B | 0 B | 0 B | 0 B |\n",
"| from small pool | 0 B | 0 B | 0 B | 0 B |\n",
"|---------------------------------------------------------------------------|\n",
"| Active memory | 0 B | 0 B | 0 B | 0 B |\n",
"| from large pool | 0 B | 0 B | 0 B | 0 B |\n",
"| from small pool | 0 B | 0 B | 0 B | 0 B |\n",
"|---------------------------------------------------------------------------|\n",
"| GPU reserved memory | 0 B | 0 B | 0 B | 0 B |\n",
"| from large pool | 0 B | 0 B | 0 B | 0 B |\n",
"| from small pool | 0 B | 0 B | 0 B | 0 B |\n",
"|---------------------------------------------------------------------------|\n",
"| 
Non-releasable memory | 0 B | 0 B | 0 B | 0 B |\n",
"| from large pool | 0 B | 0 B | 0 B | 0 B |\n",
"| from small pool | 0 B | 0 B | 0 B | 0 B |\n",
"|---------------------------------------------------------------------------|\n",
"| Allocations | 0 | 0 | 0 | 0 |\n",
"| from large pool | 0 | 0 | 0 | 0 |\n",
"| from small pool | 0 | 0 | 0 | 0 |\n",
"|---------------------------------------------------------------------------|\n",
"| Active allocs | 0 | 0 | 0 | 0 |\n",
"| from large pool | 0 | 0 | 0 | 0 |\n",
"| from small pool | 0 | 0 | 0 | 0 |\n",
"|---------------------------------------------------------------------------|\n",
"| GPU reserved segments | 0 | 0 | 0 | 0 |\n",
"| from large pool | 0 | 0 | 0 | 0 |\n",
"| from small pool | 0 | 0 | 0 | 0 |\n",
"|---------------------------------------------------------------------------|\n",
"| Non-releasable allocs | 0 | 0 | 0 | 0 |\n",
"| from large pool | 0 | 0 | 0 | 0 |\n",
"| from small pool | 0 | 0 | 0 | 0 |\n",
"|---------------------------------------------------------------------------|\n",
"| Oversize allocations | 0 | 0 | 0 | 0 |\n",
"|---------------------------------------------------------------------------|\n",
"| Oversize GPU segments | 0 | 0 | 0 | 0 |\n",
"|===========================================================================|\n",
"\n" ] } ], "source": [
"import torch\n",
"torch.cuda.empty_cache()  # release cached GPU memory\n",
"print(torch.cuda.memory_summary())  # print the CUDA memory summary" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [
"# Configure warning filters\n",
"import warnings\n",
"warnings.filterwarnings(\n",
"    action='ignore',\n",
"    category=DeprecationWarning,\n",
"    module='.*'\n",
")\n",
"warnings.filterwarnings(\n",
"    action='ignore',\n",
"    module='torch.ao.quantization'\n",
")\n",
"# Load the project-local modules\n",
"from mod import load_mod\n",
"load_mod()\n",
"\n",
"from helper import evaluate, print_size_of_model, load_model\n",
"\n",
"def print_info(model, model_type, num_eval, criterion, data_iter=None):\n",
"    '''Report the size of ``model`` and its accuracy on an evaluation iterator.\n",
"\n",
"    Calls ``print_size_of_model(model)``, runs ``evaluate`` and prints ``top1.avg``.\n",
"\n",
"    Args:\n",
"        model: model handed to ``print_size_of_model`` and ``evaluate``.\n",
"        model_type: label printed in front of the result.\n",
"        num_eval: image count shown in the message; it is only printed and\n",
"            does not limit how much data is evaluated.\n",
"        criterion: loss object forwarded to ``evaluate``.\n",
"        data_iter: iterator forwarded to ``evaluate``. When ``None`` (the\n",
"            default) the notebook-level ``test_iter`` is looked up at call\n",
"            time, which is what this function always did before.\n",
"    '''\n",
"    if data_iter is None:\n",
"        # Backward-compatible fallback; requires the data-loading cell to have run.\n",
"        data_iter = test_iter\n",
"    print_size_of_model(model)\n",
"    # ``evaluate`` returns two meters; only the first one is reported here.\n",
"    top1, _ = evaluate(model, criterion, data_iter)\n",
"    print(f'\\n{model_type}:\\n\\t'\n",
"          f'在 {num_eval} 张图片上评估 accuracy 为: {top1.avg:2.5f}')\n",
"\n",
"\n",
"def create_qat_model(num_classes,\n",
"                     model_path,\n",
"                     quantize=False):\n",
"    '''Create a model with ``create_model`` and pass it through ``load_model``.\n",
"\n",
"    Args:\n",
"        num_classes: forwarded to ``create_model``.\n",
"        model_path: forwarded to ``load_model`` (presumably a checkpoint\n",
"            path; confirm against ``helper.load_model``).\n",
"        quantize: forwarded to ``create_model``.\n",
"\n",
"    Returns:\n",
"        Whatever ``load_model`` returns for the freshly created model.\n",
"    '''\n",
"    qat_model = create_model(quantize=quantize,\n",
"                             num_classes=num_classes)\n",
"    qat_model = load_model(qat_model, model_path)\n",
"    return qat_model" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "超参数设置:" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [
"saved_model_dir = 'models/'\n",
"float_model_file = 'mobilenet_pretrained_float.pth'\n",
"scripted_qat_model_file = 'mobilenet_qat_scripted_quantized.pth'\n",
"# Hyperparameters\n",
"float_model_path = saved_model_dir + float_model_file\n",
"batch_size = 16\n",
"num_classes = 10\n",
"device = 'cuda:0'\n",
"num_epochs = 50\n",
"learning_rate = 5e-5\n",
"ylim = [0.6, 1]\n",
"\n",
"# Evaluation criterion\n",
"criterion = nn.CrossEntropyLoss()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "加载数据集:" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Files already downloaded and verified\n", "Files already downloaded and verified\n" ] } ], "source": [
"from xinet import CV\n",
"\n",
"# Resize the CIFAR-10 images to 224 so they match the ImageNet input size\n",
"train_iter, test_iter = CV.load_data_cifar10(batch_size=batch_size,\n",
"                                             resize=224)\n",
"num_eval = sum(len(ys) for _, ys in test_iter)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "普通策略:" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss 0.169, train acc 0.946, test acc 0.912\n", "57.0 examples/sec on cuda:0\n" ] }, { "data": { "image/svg+xml": "\n\n\n \n \n \n \n 2022-03-26T07:35:50.076837\n image/svg+xml\n \n \n Matplotlib v3.4.0, https://matplotlib.org/\n \n \n \n \n 
\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "is_quantized_acc = False\n", "is_freeze = False\n", "param_group = True\n", "qat_model = create_qat_model(num_classes, float_model_path)\n", "CV.train_fine_tuning(qat_model, train_iter, test_iter,\n", " learning_rate=learning_rate,\n", " num_epochs=num_epochs,\n", " device=device,\n", " param_group=param_group,\n", " is_freeze=is_freeze,\n", " is_quantized_acc=is_quantized_acc,\n", " need_prepare=True,\n", " ylim=ylim)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "冻结前几次训练的量化器以及观测器:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "is_quantized_acc = False\n", "is_freeze = True\n", "param_group = True\n", "qat_model = create_qat_model(num_classes, float_model_path)\n", "CV.train_fine_tuning(qat_model, train_iter, test_iter,\n", " learning_rate=learning_rate,\n", " num_epochs=num_epochs,\n", " device=device,\n", " param_group=param_group,\n", " is_freeze=is_freeze,\n", " is_quantized_acc=is_quantized_acc,\n", " need_prepare=True,\n", " ylim=ylim)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "输出量化精度:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "is_quantized_acc = True\n", "is_freeze = False\n", "param_group = True\n", "model_path = saved_model_dir + float_model_file\n", "qat_model = create_qat_model(num_classes, model_path)\n", "CV.train_fine_tuning(qat_model, train_iter, test_iter,\n", " learning_rate=learning_rate,\n", " num_epochs=num_epochs,\n", " device=device,\n", " param_group=param_group,\n", " is_freeze=is_freeze,\n", " is_quantized_acc=is_quantized_acc,\n", " need_prepare=True,\n", " ylim=ylim)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "冻结前几次训练的观测器并且生成量化精度:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "is_quantized_acc = True\n", "is_freeze = True\n", 
"param_group = True\n", "qat_model = create_qat_model(num_classes, float_model_path)\n", "CV.train_fine_tuning(qat_model, train_iter, test_iter,\n", " learning_rate=learning_rate,\n", " num_epochs=num_epochs,\n", " device=device,\n", " param_group=param_group,\n", " is_freeze=is_freeze,\n", " is_quantized_acc=is_quantized_acc,\n", " need_prepare=True,\n", " ylim=ylim)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "interpreter": { "hash": "78526419bf48930935ba7e23437b2460cb231485716b036ebb8701887a294fa8" }, "kernelspec": { "display_name": "Python 3.10.0 ('torchx')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.0" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }