Different Training Strategies for QAT
Load the libraries:
import logging
from collections import namedtuple
import torch
from torch import nn, jit
from torch.ao.quantization import quantize_qat
from torchvision.models.quantization import mobilenet_v2
def create_model(num_classes=10,
                 quantize=False,
                 pretrained=False):
    '''Define the model.'''
    float_model = mobilenet_v2(pretrained=pretrained,
                               quantize=quantize)
    # Match ``num_classes``
    float_model.classifier[1] = nn.Linear(float_model.last_channel,
                                          num_classes)
    return float_model
def create_float_model(num_classes,
                       model_path):
    model = create_model(quantize=False,
                         num_classes=num_classes)
    model = load_model(model, model_path)
    return model
def set_cudnn(cuda_path=":/usr/local/cuda/bin",
              LD_LIBRARY_PATH="/usr/local/cuda/lib64"):
    import os
    os.environ["PATH"] += cuda_path
    os.environ["LD_LIBRARY_PATH"] = LD_LIBRARY_PATH
    torch.cuda.empty_cache()  # Free cached GPU memory
    print(torch.cuda.memory_summary())  # Print the GPU memory summary

set_cudnn()
|===========================================================================|
| PyTorch CUDA memory summary, device ID 0 |
|---------------------------------------------------------------------------|
| CUDA OOMs: 0 | cudaMalloc retries: 0 |
|===========================================================================|
| Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed |
|---------------------------------------------------------------------------|
| Allocated memory | 0 B | 0 B | 0 B | 0 B |
| from large pool | 0 B | 0 B | 0 B | 0 B |
| from small pool | 0 B | 0 B | 0 B | 0 B |
|---------------------------------------------------------------------------|
| Active memory | 0 B | 0 B | 0 B | 0 B |
| from large pool | 0 B | 0 B | 0 B | 0 B |
| from small pool | 0 B | 0 B | 0 B | 0 B |
|---------------------------------------------------------------------------|
| GPU reserved memory | 0 B | 0 B | 0 B | 0 B |
| from large pool | 0 B | 0 B | 0 B | 0 B |
| from small pool | 0 B | 0 B | 0 B | 0 B |
|---------------------------------------------------------------------------|
| Non-releasable memory | 0 B | 0 B | 0 B | 0 B |
| from large pool | 0 B | 0 B | 0 B | 0 B |
| from small pool | 0 B | 0 B | 0 B | 0 B |
|---------------------------------------------------------------------------|
| Allocations | 0 | 0 | 0 | 0 |
| from large pool | 0 | 0 | 0 | 0 |
| from small pool | 0 | 0 | 0 | 0 |
|---------------------------------------------------------------------------|
| Active allocs | 0 | 0 | 0 | 0 |
| from large pool | 0 | 0 | 0 | 0 |
| from small pool | 0 | 0 | 0 | 0 |
|---------------------------------------------------------------------------|
| GPU reserved segments | 0 | 0 | 0 | 0 |
| from large pool | 0 | 0 | 0 | 0 |
| from small pool | 0 | 0 | 0 | 0 |
|---------------------------------------------------------------------------|
| Non-releasable allocs | 0 | 0 | 0 | 0 |
| from large pool | 0 | 0 | 0 | 0 |
| from small pool | 0 | 0 | 0 | 0 |
|---------------------------------------------------------------------------|
| Oversize allocations | 0 | 0 | 0 | 0 |
|---------------------------------------------------------------------------|
| Oversize GPU segments | 0 | 0 | 0 | 0 |
|===========================================================================|
# Configure warnings
import warnings
warnings.filterwarnings(
    action='ignore',
    category=DeprecationWarning,
    module='.*'
)
warnings.filterwarnings(
    action='ignore',
    module='torch.ao.quantization'
)
# Load the custom helper modules
from mod import torchq
from torchq.helper import evaluate, print_size_of_model, load_model
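torchq.helper is a project-specific module whose source is not reproduced in this tutorial. Judging only from how the helpers are called below, print_size_of_model and load_model are assumed to behave roughly as sketched here, and evaluate is assumed to return top-1/top-5 accuracy meters (objects with an .avg attribute). The _sketch suffix is used so these illustrations do not shadow the real imports:

import os
import tempfile

def print_size_of_model_sketch(model):
    """Rough sketch: report the on-disk size of the model's state_dict in MB."""
    with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as tmp:
        torch.save(model.state_dict(), tmp.name)
        size_mb = os.path.getsize(tmp.name) / 1e6
    os.remove(tmp.name)
    print(f"Model size: {size_mb:.6f} MB")

def load_model_sketch(model, model_path):
    """Rough sketch: load a saved state_dict into the given model."""
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model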
def print_info(model, model_type, num_eval, criterion):
    '''Print the model size and its evaluation accuracy.'''
    print_size_of_model(model)
    top1, top5 = evaluate(model, criterion, test_iter)
    print(f'\n{model_type}:\n\t'
          f'Accuracy evaluated on {num_eval} images: {top1.avg:2.5f}')
Config = namedtuple('Config',
                    ["net",
                     "device",
                     "train_iter",
                     "test_iter",
                     "loss",
                     "trainer",
                     "num_epochs",
                     "logger",
                     "need_qconfig",
                     "is_freeze",
                     "is_quantized_acc",
                     "backend",
                     "ylim"])
Hyperparameter settings:
saved_model_dir = 'models/'
model_name = "mobilenet"
logfile = f"outputs/{model_name}.log"
float_model_file = f'{model_name}_pretrained_float.pth'
logging.basicConfig(filename=logfile, filemode='w')
logger = logging.getLogger(name=f"{model_name}Logger")
logger.setLevel(logging.DEBUG)
# scripted_qat_model_file = 'mobilenet_qat_scripted_quantized.pth'
# Hyperparameters
float_model_path = saved_model_dir + float_model_file
batch_size = 8
num_classes = 10
num_epochs = 50
learning_rate = 5e-5
ylim = [0.8, 1]
Load the dataset:
from torchq.xinet import CV
# Resize CIFAR-10 images to 224 so they match the ImageNet input size
train_iter, test_iter = CV.load_data_cifar10(batch_size=batch_size,
                                             resize=224)
num_eval = sum(len(ys) for _, ys in test_iter)
Files already downloaded and verified
Files already downloaded and verified
Print the float model's size and accuracy:
float_model = create_float_model(num_classes, float_model_path)
model_type = 'Float model'
criterion = nn.CrossEntropyLoss(reduction="none")
print_info(float_model, model_type, num_eval, criterion)
Model size: 9.187789 MB
Batch 0 ~ Acc@1 100.00 (100.00) Acc@5 100.00 (100.00)
Batch 500 ~ Acc@1 100.00 ( 95.08) Acc@5 100.00 ( 99.93)
Batch 1000 ~ Acc@1 100.00 ( 94.84) Acc@5 100.00 ( 99.91)
Float model:
	Accuracy evaluated on 10000 images: 94.91000
Plain strategy:
num_epochs = 30
ylim = [0.85, 1]
device = 'cuda:1'
param_group = True
# Quantization-related flags
is_freeze = False
is_quantized_acc = False
need_qconfig = True  # apply the QAT qconfig inside the training routine
# Provide positional arguments (note: this call omits two of Config's
# thirteen fields, so it raises the TypeError shown below)
config = Config(train_iter,
                test_iter,
                learning_rate,
                num_epochs,
                logger,
                device,
                is_freeze,
                is_quantized_acc,
                need_qconfig,
                param_group,
                ylim)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell 13, in <cell line: 12>()
---> 12 config = Config(train_iter,
     13                 test_iter,
     ...
     22                 ylim)

TypeError: Config.__new__() missing 2 required positional arguments: 'backend' and 'ylim'
"net",
"device",
"train_iter",
"test_iter",
"loss",
"trainer",
"num_epochs",
"logger",
"need_qconfig",
"is_freeze",
"is_quantized_acc",
"backend",
"ylim"
config
args = [train_iter,        # 0
        test_iter,         # 1
        learning_rate,     # 2
        num_epochs,        # 3
        device,            # 4
        is_freeze,         # 5
        is_quantized_acc,  # 6
        need_qconfig,      # 7
        param_group,       # 8
        ylim]              # 9
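For context, the eager-mode helper quantize_qat used below does three things: it prepares the model for QAT, calls the supplied training function with the model followed by the unpacked run arguments, and finally converts the trained model to int8. That is why args is passed as a flat list. A simplified sketch of that flow, not the actual library source:

def quantize_qat_flow_sketch(model, run_fn, run_args):
    """Simplified sketch of what torch.ao.quantization.quantize_qat does."""
    model.train()
    # Insert observers and fake-quantization modules according to model.qconfig.
    torch.ao.quantization.prepare_qat(model, inplace=True)
    # Fine-tune; here this amounts to CV.train_fine_tuning(model, *args).
    run_fn(model, *run_args)
    # Swap the fake-quantized modules for real int8 quantized modules.
    torch.ao.quantization.convert(model, inplace=True)
    return model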
qat_model = create_float_model(num_classes, float_model_path)
quantized_model = quantize_qat(qat_model, CV.train_fine_tuning, args)
Fusion strategy (fuse the modules before QAT):
qat_model = create_float_model(num_classes, float_model_path)
qat_model.fuse_model()  # fuse Conv + BN (+ ReLU) blocks
quantized_model = quantize_qat(qat_model, CV.train_fine_tuning, args)
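fuse_model is provided by torchvision's quantizable MobileNetV2 and merges each Conv + BatchNorm (+ ReLU) sequence into a single module, so the fused block is observed and quantized as one unit during QAT. For a custom model without such a helper, the same step can be done with fuse_modules; the tiny block below is hypothetical and only illustrates the call:

from torch.ao.quantization import fuse_modules

class TinyConvBlock(nn.Module):
    """Hypothetical Conv-BN-ReLU block, used only to demonstrate fusion."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(8)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

block = TinyConvBlock().eval()
# Fuse conv + bn + relu into a single module in place.
fuse_modules(block, [["conv", "bn", "relu"]], inplace=True)
print(block)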
Freeze the quantizer parameters and observers after the first few training epochs:
args[5] = True   # is_freeze
args[6] = False  # is_quantized_acc
qat_model = create_float_model(num_classes, float_model_path)
quantized_model = quantize_qat(qat_model, CV.train_fine_tuning, args)
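The is_freeze flag is handled inside CV.train_fine_tuning, whose source is not reproduced here. A common recipe (the one used in the official PyTorch QAT tutorial) is to let the observers and batch-norm statistics adapt for a few epochs and then freeze them; the sketch below uses arbitrary epoch thresholds and is only meant to show the relevant API calls:

for epoch in range(8):  # hypothetical number of fine-tuning epochs
    # ... one epoch of QAT fine-tuning on qat_model would run here ...
    if epoch >= 2:
        # Freeze batch-norm running statistics.
        qat_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
    if epoch >= 3:
        # Stop updating observer ranges (scale / zero-point become fixed).
        qat_model.apply(torch.ao.quantization.disable_observer)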
Report the quantized accuracy during training:
args[6] = True   # is_quantized_acc
args[5] = False  # is_freeze
qat_model = create_float_model(num_classes, float_model_path)
quantized_model = quantize_qat(qat_model, CV.train_fine_tuning, args)
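Similarly, is_quantized_acc is assumed to make the training loop report the accuracy of the converted int8 model instead of the fake-quantized float model. A rough sketch of how such a per-epoch evaluation can be done (again an assumption about what the training routine does, not its actual source):

import copy

def eval_quantized_sketch(model, criterion, data_iter):
    """Sketch: convert a copy of a QAT-prepared model to int8 and evaluate it."""
    eval_model = copy.deepcopy(model).cpu().eval()
    int8_model = torch.ao.quantization.convert(eval_model)
    top1, top5 = evaluate(int8_model, criterion, data_iter)
    return top1.avg, top5.avg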
Freeze the observers after the first few epochs and report the quantized accuracy:
args[5] = True  # is_freeze
args[6] = True  # is_quantized_acc
qat_model = create_float_model(num_classes, float_model_path)
quantized_model = quantize_qat(qat_model, CV.train_fine_tuning, args)
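Whichever strategy is used, quantize_qat returns a model that has already been converted to int8, so it can be evaluated and serialized directly. The sketch below reuses the scripted_qat_model_file name that is commented out in the hyperparameter cell; treat both the evaluation call and the file name as suggestions rather than part of the original recipe:

quantized_model = quantized_model.cpu().eval()
print_info(quantized_model, 'QAT int8 model', num_eval, criterion)

# TorchScript the quantized model so it can be deployed without the Python class.
scripted_qat_model_file = 'mobilenet_qat_scripted_quantized.pth'
scripted = jit.script(quantized_model)
jit.save(scripted, saved_model_dir + scripted_qat_model_file)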
See also: torch.nn.quantized.FloatFunctional, which the quantizable MobileNetV2 uses for its skip-connection additions.