测试 QAT#

import torch
import torch.nn as nn

import time
import torch.quantization

# 设置 warnings
import warnings
warnings.filterwarnings(
    action='ignore',
    category=DeprecationWarning,
    module=r'.*'
)
warnings.filterwarnings(
    action='default',
    module=r'torch.quantization'
)

# 为可重复的结果指定随机种子
torch.manual_seed(191009)

from mod import load_mod
load_mod()
/media/pc/data/4tb/xinet/web/pytorch-book/docs

作为最后一个主要的设置步骤,我们为训练和测试集定义了数据加载器。

要使用整个 ImageNet 数据集运行本教程中的代码,请先按照 ImageNet Data 中的说明下载 ImageNet。将下载的文件解压缩到 data_path 文件夹中。

下载完数据后,我们将在下面展示一些函数,这些函数定义了用于读取数据的数据加载器。

from pytorch_book.datasets.imagenet import prepare_data_loaders
data_path = '/media/pc/data/4tb/xinet/datasets/imagenet2'

train_batch_size = 30
eval_batch_size = 50

data_loader, data_loader_test = prepare_data_loaders(data_path,
                                                     train_batch_size,
                                                     eval_batch_size)

训练模型

写一个通用函数 train_model() 来训练模型:

  • 调度学习率

  • 保存最佳模型

from tools import train_model
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
/media/pc/data/4tb/xinet/web/pytorch-book/docs/quantization/study/test.ipynb Cell 7' in <cell line: 1>()
----> <a href='vscode-notebook-cell://ssh-remote%2B10.16.11.3/media/pc/data/4tb/xinet/web/pytorch-book/docs/quantization/study/test.ipynb#ch0000008vscode-remote?line=0'>1</a> from tools import train_model

ModuleNotFoundError: No module named 'tools'