测试 QAT
测试 QAT#
import torch
import torch.nn as nn
import time
import torch.quantization
# 设置 warnings
import warnings
warnings.filterwarnings(
action='ignore',
category=DeprecationWarning,
module=r'.*'
)
warnings.filterwarnings(
action='default',
module=r'torch.quantization'
)
# 为可重复的结果指定随机种子
torch.manual_seed(191009)
from mod import load_mod
load_mod()
/media/pc/data/4tb/xinet/web/pytorch-book/docs
作为最后一个主要的设置步骤,我们为训练和测试集定义了数据加载器。
要使用整个 ImageNet 数据集运行本教程中的代码,请先按照 ImageNet Data 中的说明下载 ImageNet。将下载的文件解压缩到 data_path
文件夹中。
下载完数据后,我们将在下面展示一些函数,这些函数定义了用于读取数据的数据加载器。
from pytorch_book.datasets.imagenet import prepare_data_loaders
data_path = '/media/pc/data/4tb/xinet/datasets/imagenet2'
train_batch_size = 30
eval_batch_size = 50
data_loader, data_loader_test = prepare_data_loaders(data_path,
train_batch_size,
eval_batch_size)
训练模型
写一个通用函数 train_model()
来训练模型:
调度学习率
保存最佳模型
from tools import train_model
---------------------------------------------------------------------------
ModuleNotFoundError Traceback (most recent call last)
/media/pc/data/4tb/xinet/web/pytorch-book/docs/quantization/study/test.ipynb Cell 7' in <cell line: 1>()
----> <a href='vscode-notebook-cell://ssh-remote%2B10.16.11.3/media/pc/data/4tb/xinet/web/pytorch-book/docs/quantization/study/test.ipynb#ch0000008vscode-remote?line=0'>1</a> from tools import train_model
ModuleNotFoundError: No module named 'tools'