{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Testing QAT" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/media/pc/data/4tb/xinet/web/pytorch-book/docs\n" ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "\n", "import time\n", "import torch.quantization\n", "\n", "# Configure warnings: ignore DeprecationWarning everywhere,\n", "# but keep warnings raised from torch.quantization visible\n", "import warnings\n", "warnings.filterwarnings(\n", "    action='ignore',\n", "    category=DeprecationWarning,\n", "    module=r'.*'\n", ")\n", "warnings.filterwarnings(\n", "    action='default',\n", "    module=r'torch.quantization'\n", ")\n", "\n", "# Fix the random seed so results are reproducible\n", "torch.manual_seed(191009)\n", "\n", "from mod import load_mod\n", "load_mod()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As the last major setup step, we define the data loaders for the training and test sets.\n", "\n", "To run the code in this tutorial on the full ImageNet dataset, first download ImageNet by following the instructions at [ImageNet Data](http://www.image-net.org/download). Extract the downloaded files into the `data_path` folder.\n", "\n", "Once the data is downloaded, the helper imported below builds the data loaders used to read it. It is based on [these functions](https://github.com/pytorch/vision/blob/master/references/detection/train.py)." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "from pytorch_book.datasets.imagenet import prepare_data_loaders" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "data_path = '/media/pc/data/4tb/xinet/datasets/imagenet2'\n", "\n", "train_batch_size = 30\n", "eval_batch_size = 50\n", "\n", "data_loader, data_loader_test = prepare_data_loaders(\n", "    data_path, train_batch_size, eval_batch_size\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```{rubric} Training the model\n", "```\n", "\n", "We write a general-purpose function {func}`~tools.train_model` to train the model. It:\n", "\n", "- schedules the learning rate\n", "- saves the best model" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "ename": "ModuleNotFoundError", "evalue": "No module named 'tools'", "output_type": "error", "traceback": [ 
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", "\u001b[1;32m/media/pc/data/4tb/xinet/web/pytorch-book/docs/quantization/study/test.ipynb Cell 7'\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mtools\u001b[39;00m \u001b[39mimport\u001b[39;00m train_model\n", "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'tools'" ] } ], "source": [ "from tools import train_model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "interpreter": { "hash": "78526419bf48930935ba7e23437b2460cb231485716b036ebb8701887a294fa8" }, "kernelspec": { "display_name": "Python 3.10.0 ('torchx')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.0" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }