{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# QAT 的不同训练策略\n",
    "\n",
    "载入库:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "from collections import namedtuple\n",
    "import torch\n",
    "from torch import nn, jit\n",
    "from torch.ao.quantization import quantize_qat\n",
    "from torchvision.models.quantization import mobilenet_v2\n",
    "\n",
    "\n",
    "def create_model(num_classes=10,\n",
    "                 quantize=False,\n",
    "                 pretrained=False):\n",
    "    '''Build a quantizable MobileNetV2 whose last layer has ``num_classes`` outputs.\n",
    "\n",
    "    Args:\n",
    "        num_classes: output width of the replacement classifier layer.\n",
    "        quantize: forwarded unchanged to ``mobilenet_v2``.\n",
    "        pretrained: forwarded unchanged to ``mobilenet_v2``.\n",
    "\n",
    "    Returns:\n",
    "        The model with ``classifier[1]`` replaced by a fresh ``nn.Linear``.\n",
    "    '''\n",
    "    float_model = mobilenet_v2(pretrained=pretrained,\n",
    "                               quantize=quantize)\n",
    "    # Swap the final layer so its output width matches ``num_classes``.\n",
    "    float_model.classifier[1] = nn.Linear(float_model.last_channel,\n",
    "                                          num_classes)\n",
    "    return float_model\n",
    "\n",
    "\n",
    "def create_float_model(num_classes,\n",
    "                       model_path):\n",
    "    '''Create the float model and pass it through ``load_model``.\n",
    "\n",
    "    ``load_model`` is imported from ``torchq.helper`` in a later cell, so\n",
    "    that cell must have run before this function is called.\n",
    "\n",
    "    Args:\n",
    "        num_classes: forwarded to ``create_model``.\n",
    "        model_path: forwarded to ``load_model`` together with the model.\n",
    "\n",
    "    Returns:\n",
    "        Whatever ``load_model`` returns for the freshly created model.\n",
    "    '''\n",
    "    model = create_model(quantize=False,\n",
    "                         num_classes=num_classes)\n",
    "    model = load_model(model, model_path)\n",
    "    return model\n",
    "\n",
    "\n",
    "def set_cudnn(cuda_path=\":/usr/local/cuda/bin\",\n",
    "              LD_LIBRARY_PATH=\"/usr/local/cuda/lib64\"):\n",
    "    '''Add the CUDA directories to this process's environment variables.\n",
    "\n",
    "    Args:\n",
    "        cuda_path: ``os.pathsep``-separated directories appended to ``PATH``\n",
    "            (a leading separator, as in the default, is accepted).\n",
    "        LD_LIBRARY_PATH: value assigned to ``LD_LIBRARY_PATH``.\n",
    "    '''\n",
    "    import os\n",
    "    # ``PATH`` may be unset; treat that as empty instead of raising KeyError.\n",
    "    current = os.environ.get(\"PATH\", \"\")\n",
    "    known = current.split(os.pathsep)\n",
    "    # Cells get re-run: append only directories that are not on PATH yet,\n",
    "    # so repeated calls do not keep growing the variable.\n",
    "    missing = [entry for entry in cuda_path.split(os.pathsep)\n",
    "               if entry and entry not in known]\n",
    "    if missing:\n",
    "        os.environ[\"PATH\"] = os.pathsep.join(\n",
    "            ([current] if current else []) + missing)\n",
    "    os.environ[\"LD_LIBRARY_PATH\"] = LD_LIBRARY_PATH"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "|===========================================================================|\n",
      "| PyTorch CUDA memory summary, device ID 0 |\n",
      "|---------------------------------------------------------------------------|\n",
      "| CUDA OOMs: 0 | cudaMalloc retries: 0 |\n",
      "|===========================================================================|\n",
      "| Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed |\n",
      "|---------------------------------------------------------------------------|\n",
      "| Allocated memory | 0 B | 0 B | 0 B | 0 B |\n",
      "| from large pool | 0 B | 0 B | 0 B | 0 B |\n",
      "| from small pool | 0 B | 0 B | 0 B | 0 B |\n",
      "|---------------------------------------------------------------------------|\n",
      "| Active memory | 0 B | 0 B | 0 B | 0 B |\n",
      "| from large pool | 0 B | 0 B | 0 B | 0 B |\n",
      "| from small pool | 0 B | 0 B | 0 B | 0 B |\n",
      "|---------------------------------------------------------------------------|\n",
      "| GPU reserved memory | 0 B | 0 B | 0 B | 0 B |\n",
      "| from large pool | 0 B | 0 B | 0 B | 0 B |\n",
      "| from small pool | 0 B | 0 B | 0 B | 0 B |\n",
      "|---------------------------------------------------------------------------|\n",
      "| Non-releasable memory | 0 B | 0 B | 0 B | 0 B |\n",
      "| from large pool | 0 B | 0 B | 0 B | 0 B |\n",
      "| from small pool | 0 B | 0 B | 0 B | 0 B |\n",
      "|---------------------------------------------------------------------------|\n",
      "| Allocations | 0 | 0 | 0 | 0 |\n",
      "| from large pool | 0 | 0 | 0 | 0 |\n",
      "| from small pool | 0 | 0 | 0 | 0 |\n",
      "|---------------------------------------------------------------------------|\n",
      "| Active allocs | 0 | 0 | 0 | 0 |\n",
      "| from large pool | 0 | 0 | 0 | 0 |\n",
      "| from small pool | 0 | 0 | 0 | 0 |\n",
      "|---------------------------------------------------------------------------|\n",
      "| GPU reserved segments | 0 | 0 | 0 | 0 |\n",
      "| from large pool | 0 | 0 | 0 | 0 |\n",
      "| from small pool | 0 | 0 | 0 | 0 |\n",
      "|---------------------------------------------------------------------------|\n",
      "| Non-releasable allocs | 0 | 0 | 0 | 0 |\n",
      "| from large pool | 0 | 0 | 0 | 0 |\n",
      "| from small pool | 0 | 0 | 0 | 0 |\n",
      "|---------------------------------------------------------------------------|\n",
      "| Oversize allocations | 0 | 0 | 0 | 0 |\n",
      "|---------------------------------------------------------------------------|\n",
      "| Oversize GPU segments | 0 | 0 | 0 | 0 |\n",
      "|===========================================================================|\n",
      "\n"
     ]
    }
   ],
   "source": [
    "torch.cuda.empty_cache()  # clear the GPU cache\n",
    "print(torch.cuda.memory_summary())  # print the GPU memory summary\n",
    "set_cudnn()"
   ]
}, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# 设置 warnings\n", "import warnings\n", "warnings.filterwarnings(\n", " action='ignore',\n", " category=DeprecationWarning,\n", " module='.*'\n", ")\n", "warnings.filterwarnings(\n", " action='ignore',\n", " module='torch.ao.quantization'\n", ")\n", "# 载入自定义模块\n", "from mod import torchq\n", "\n", "from torchq.helper import evaluate, print_size_of_model, load_model\n", "\n", "def print_info(model, model_type, num_eval, criterion):\n", " '''打印信息'''\n", " print_size_of_model(model)\n", " top1, top5 = evaluate(model, criterion, test_iter)\n", " print(f'\\n{model_type}:\\n\\t'\n", " f'在 {num_eval} 张图片上评估 accuracy 为: {top1.avg:2.5f}')" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "Config = namedtuple('Config',\n", " [\"net\",\n", " \"device\",\n", " \"train_iter\",\n", " \"test_iter\",\n", " \"loss\",\n", " \"trainer\",\n", " \"num_epochs\",\n", " \"logger\",\n", " \"need_qconfig\",\n", " \"is_freeze\",\n", " \"is_quantized_acc\",\n", " \"backend\",\n", " \"ylim\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "超参数设置:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "saved_model_dir = 'models/'\n", "model_name = \"mobilenet\"\n", "logfile = f\"outputs/{model_name}.log\"\n", "float_model_file = f'{model_name}_pretrained_float.pth'\n", "logging.basicConfig(filename=logfile, filemode='w')\n", "logger = logging.getLogger(name=f\"{model_name}Logger\")\n", "logger.setLevel(logging.DEBUG)\n", "# scripted_qat_model_file = 'mobilenet_qat_scripted_quantized.pth'\n", "# 超参数\n", "float_model_path = saved_model_dir + float_model_file\n", "batch_size = 8\n", "num_classes = 10\n", "num_epochs = 50\n", "learning_rate = 5e-5\n", "ylim = [0.8, 1]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "加载数据集:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": 
[
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "from torchq.xinet import CV\n",
    "\n",
    "# CIFAR-10 images are resized to 224 to match the ImageNet input size\n",
    "train_iter, test_iter = CV.load_data_cifar10(batch_size=batch_size,\n",
    "                                             resize=224)\n",
    "num_eval = sum(len(ys) for _, ys in test_iter)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "打印浮点模型信息:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "模型大小:9.187789 MB\n",
      "Batch 0 ~ Acc@1 100.00 (100.00)\t Acc@5 100.00 (100.00)\n",
      "Batch 500 ~ Acc@1 100.00 ( 95.08)\t Acc@5 100.00 ( 99.93)\n",
      "Batch 1000 ~ Acc@1 100.00 ( 94.84)\t Acc@5 100.00 ( 99.91)\n",
      "\n",
      "浮点模型:\n",
      "\t在 10000 张图片上评估 accuracy 为: 94.91000\n"
     ]
    }
   ],
   "source": [
    "float_model = create_float_model(num_classes, float_model_path)\n",
    "model_type = '浮点模型'\n",
    "criterion = nn.CrossEntropyLoss(reduction=\"none\")\n",
    "print_info(float_model, model_type, num_eval, criterion)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "普通策略:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs = 30\n",
    "ylim = [0.85, 1]\n",
    "device = 'cuda:1'\n",
    "param_group = True\n",
    "\n",
    "# Quantization flags\n",
    "is_freeze = False\n",
    "is_quantized_acc = False\n",
    "need_qconfig = True  # perform the QAT qconfig setup\n",
    "\n",
    "# ``Config`` declares 13 fields (including net, loss, trainer and backend)\n",
    "# that do not correspond to the 11 values collected here, so calling\n",
    "# ``Config(...)`` with them raised ``TypeError``. Use a record type whose\n",
    "# fields name exactly the values that are passed, in the same order.\n",
    "# NOTE(review): confirm this order against ``CV.train_fine_tuning``.\n",
    "RunArgs = namedtuple('RunArgs',\n",
    "                     ['train_iter',\n",
    "                      'test_iter',\n",
    "                      'learning_rate',\n",
    "                      'num_epochs',\n",
    "                      'logger',\n",
    "                      'device',\n",
    "                      'is_freeze',\n",
    "                      'is_quantized_acc',\n",
    "                      'need_qconfig',\n",
    "                      'param_group',\n",
    "                      'ylim'])\n",
    "config = RunArgs(train_iter=train_iter,\n",
    "                 test_iter=test_iter,\n",
    "                 learning_rate=learning_rate,\n",
    "                 num_epochs=num_epochs,\n",
    "                 logger=logger,\n",
    "                 device=device,\n",
    "                 is_freeze=is_freeze,\n",
    "                 is_quantized_acc=is_quantized_acc,\n",
    "                 need_qconfig=need_qconfig,\n",
    "                 param_group=param_group,\n",
    "                 ylim=ylim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Field names declared by ``Config``, for comparison with ``config`` below.\n",
    "Config._fields"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline positional arguments; index 5 is ``is_freeze`` and index 6 is\n",
    "# ``is_quantized_acc``. The strategy cells below copy this list.\n",
    "args = [train_iter,\n",
    "        test_iter,\n",
    "        learning_rate,\n",
    "        num_epochs,\n",
    "        device,\n",
    "        is_freeze,\n",
    "        is_quantized_acc,\n",
    "        need_qconfig,\n",
    "        param_group,\n",
    "        ylim]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "qat_model = create_float_model(num_classes, float_model_path)\n",
    "quantized_model = quantize_qat(qat_model, CV.train_fine_tuning, config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "qat_model = create_float_model(num_classes, float_model_path)\n",
    "qat_model.fuse_model()  # add module fusion\n",
    "quantized_model = quantize_qat(qat_model, CV.train_fine_tuning, args)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "冻结前几次训练的量化器以及观测器:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Work on a copy so the baseline ``args`` keeps its original flags and the\n",
    "# cells above give the same result when they are re-run.\n",
    "run_args = list(args)\n",
    "run_args[5] = True   # is_freeze\n",
    "run_args[6] = False  # is_quantized_acc\n",
    "qat_model = create_float_model(num_classes, float_model_path)\n",
    "quantized_model = quantize_qat(qat_model, CV.train_fine_tuning, run_args)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "输出量化精度:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copy of the baseline arguments with only the accuracy flag switched on.\n",
    "run_args = list(args)\n",
    "run_args[6] = True   # is_quantized_acc\n",
    "run_args[5] = False  # is_freeze\n",
    "qat_model = create_float_model(num_classes, float_model_path)\n",
    "quantized_model = quantize_qat(qat_model, CV.train_fine_tuning, run_args)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "冻结前几次训练的观测器并且生成量化精度:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copy of the baseline arguments with both flags switched on.\n",
    "run_args = list(args)\n",
    "run_args[5] = True  # is_freeze\n",
    "run_args[6] = True  # is_quantized_acc\n",
    "qat_model = create_float_model(num_classes, float_model_path)\n",
    "quantized_model = quantize_qat(qat_model, CV.train_fine_tuning, run_args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.nn.quantized.FloatFunctional"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "ccd751c8c176f1a7084878738c6c59984a17d1189ffe2fae146e3d74e2010826"
  },
  "kernelspec": {
   "display_name": "Python 3.10.4 (conda)",
   "language": "python",
   "name": "python3"
  },
"language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }