QAT 的不同训练策略

载入库:

import logging
from collections import namedtuple
import torch
from torch import nn, jit
from torch.ao.quantization import quantize_qat
from torchvision.models.quantization import mobilenet_v2


def create_model(num_classes=10,
                 quantize=False,
                 pretrained=False):
    '''Build a MobileNetV2 whose last linear layer emits ``num_classes`` logits.

    Args:
        num_classes: output size of the replacement classification layer.
        quantize: forwarded unchanged to ``mobilenet_v2``.
        pretrained: forwarded unchanged to ``mobilenet_v2``.

    Returns:
        The network, with ``classifier[1]`` swapped for a fresh ``nn.Linear``.
    '''
    net = mobilenet_v2(pretrained=pretrained, quantize=quantize)
    in_features = net.last_channel
    # Replace the stock head so the output width matches this dataset.
    net.classifier[1] = nn.Linear(in_features, num_classes)
    return net


def create_float_model(num_classes,
                       model_path):
    '''Create the non-quantized model and hand it to ``load_model``.

    Args:
        num_classes: size of the classification head (see ``create_model``).
        model_path: path passed to ``load_model`` together with the model
            (presumably a saved state dict -- confirm in torchq.helper).

    Returns:
        Whatever ``load_model`` returns for the freshly built model.
    '''
    return load_model(create_model(num_classes=num_classes, quantize=False),
                      model_path)

def set_cudnn(cuda_path=":/usr/local/cuda/bin",
              LD_LIBRARY_PATH="/usr/local/cuda/lib64"):
    '''Point the process environment at a CUDA toolkit installation.

    Appends ``cuda_path`` to ``PATH`` and sets ``LD_LIBRARY_PATH``.

    Args:
        cuda_path: directory (or separator-joined directories) to append to
            ``PATH``. The default starts with the path separator; if it is
            missing, one is inserted so the last existing entry is not
            corrupted.
        LD_LIBRARY_PATH: value assigned to ``LD_LIBRARY_PATH``, replacing
            any previous value.

    Calling this repeatedly (e.g. re-running a notebook cell) is safe:
    directories already on ``PATH`` are not appended again.
    '''
    import os

    current = os.environ.get("PATH", "")
    wanted = [entry for entry in cuda_path.split(os.pathsep) if entry]
    present = set(current.split(os.pathsep))
    if not all(entry in present for entry in wanted):
        if not current:
            # A leading separator would put the current directory on PATH.
            os.environ["PATH"] = cuda_path.lstrip(os.pathsep)
        else:
            needs_sep = not (cuda_path.startswith(os.pathsep)
                             or current.endswith(os.pathsep))
            os.environ["PATH"] = (current + (os.pathsep if needs_sep else "")
                                  + cuda_path)
    # NOTE: the dynamic loader usually reads LD_LIBRARY_PATH only at process
    # start-up, so this mainly affects subprocesses launched afterwards.
    os.environ["LD_LIBRARY_PATH"] = LD_LIBRARY_PATH
# Release cached-but-unused GPU memory, then print the allocator report.
torch.cuda.empty_cache()
memory_report = torch.cuda.memory_summary()
print(memory_report)
# Configure PATH / LD_LIBRARY_PATH for the CUDA toolkit.
set_cudnn()
|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |       0 B  |       0 B  |       0 B  |       0 B  |
|       from large pool |       0 B  |       0 B  |       0 B  |       0 B  |
|       from small pool |       0 B  |       0 B  |       0 B  |       0 B  |
|---------------------------------------------------------------------------|
| Active memory         |       0 B  |       0 B  |       0 B  |       0 B  |
|       from large pool |       0 B  |       0 B  |       0 B  |       0 B  |
|       from small pool |       0 B  |       0 B  |       0 B  |       0 B  |
|---------------------------------------------------------------------------|
| GPU reserved memory   |       0 B  |       0 B  |       0 B  |       0 B  |
|       from large pool |       0 B  |       0 B  |       0 B  |       0 B  |
|       from small pool |       0 B  |       0 B  |       0 B  |       0 B  |
|---------------------------------------------------------------------------|
| Non-releasable memory |       0 B  |       0 B  |       0 B  |       0 B  |
|       from large pool |       0 B  |       0 B  |       0 B  |       0 B  |
|       from small pool |       0 B  |       0 B  |       0 B  |       0 B  |
|---------------------------------------------------------------------------|
| Allocations           |       0    |       0    |       0    |       0    |
|       from large pool |       0    |       0    |       0    |       0    |
|       from small pool |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Active allocs         |       0    |       0    |       0    |       0    |
|       from large pool |       0    |       0    |       0    |       0    |
|       from small pool |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| GPU reserved segments |       0    |       0    |       0    |       0    |
|       from large pool |       0    |       0    |       0    |       0    |
|       from small pool |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |       0    |       0    |       0    |       0    |
|       from large pool |       0    |       0    |       0    |       0    |
|       from small pool |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Oversize allocations  |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Oversize GPU segments |       0    |       0    |       0    |       0    |
|===========================================================================|
# 设置 warnings
import warnings
warnings.filterwarnings(
    action='ignore',
    category=DeprecationWarning,
    module='.*'
)
warnings.filterwarnings(
    action='ignore',
    module='torch.ao.quantization'
)
# 载入自定义模块
from mod import torchq

from torchq.helper import evaluate, print_size_of_model, load_model

def print_info(model, model_type, num_eval, criterion, data_iter=None):
    '''Print the model size and its top-1 accuracy on an evaluation set.

    Args:
        model: model to report on.
        model_type: label printed in the header line.
        num_eval: number of evaluation images; only used in the message.
        criterion: loss object forwarded to ``evaluate``.
        data_iter: evaluation batches forwarded to ``evaluate``. When omitted,
            the module-level ``test_iter`` is used, which matches the earlier
            behaviour.
    '''
    if data_iter is None:
        # Backward-compatible fallback to the global set up by the
        # data-loading cell.
        data_iter = test_iter
    print_size_of_model(model)
    top1, _top5 = evaluate(model, criterion, data_iter)
    print(f'\n{model_type}\n\t'
          f'在 {num_eval} 张图片上评估 accuracy 为: {top1.avg:2.5f}')
# Immutable bundle of the settings for one training/quantization experiment.
Config = namedtuple(
    'Config',
    'net device train_iter test_iter loss trainer num_epochs logger '
    'need_qconfig is_freeze is_quantized_acc backend ylim',
)

超参数设置:

import os

# Paths and logging
saved_model_dir = 'models/'
model_name = "mobilenet"
logfile = f"outputs/{model_name}.log"
float_model_file = f'{model_name}_pretrained_float.pth'
float_model_path = os.path.join(saved_model_dir, float_model_file)

# basicConfig opens the log file right away. Create its directory first so a
# fresh checkout does not fail with FileNotFoundError.
_log_dir = os.path.dirname(logfile)
if _log_dir:
    os.makedirs(_log_dir, exist_ok=True)
logging.basicConfig(filename=logfile, filemode='w')
logger = logging.getLogger(name=f"{model_name}Logger")
logger.setLevel(logging.DEBUG)

# Hyperparameters
batch_size = 8
num_classes = 10
num_epochs = 50
learning_rate = 5e-5
ylim = [0.8, 1]

加载数据集:

from torchq.xinet import CV

# Resize CIFAR-10 images to 224 so their size matches the ImageNet-style
# input that MobileNetV2 was designed for.
train_iter, test_iter = CV.load_data_cifar10(
    batch_size=batch_size,
    resize=224,
)
# Total number of evaluation samples. NOTE: this iterates over the whole
# test loader once just to count labels.
num_eval = sum(len(labels) for _images, labels in test_iter)
Files already downloaded and verified
Files already downloaded and verified

打印浮点模型信息:

# Evaluate the non-quantized baseline.
criterion = nn.CrossEntropyLoss(reduction="none")
model_type = '浮点模型'
float_model = create_float_model(num_classes, float_model_path)
print_info(float_model, model_type, num_eval, criterion)
模型大小:9.187789 MB
Batch 0 ~ Acc@1 100.00 (100.00)	 Acc@5 100.00 (100.00)
Batch 500 ~ Acc@1 100.00 ( 95.08)	 Acc@5 100.00 ( 99.93)
Batch 1000 ~ Acc@1 100.00 ( 94.84)	 Acc@5 100.00 ( 99.91)

浮点模型:
	在 10000 张图片上评估 accuracy 为: 94.91000

普通策略:

num_epochs = 30
ylim = [0.85, 1]
device = 'cuda:1'
param_group = True

# Quantization options
is_freeze = False
is_quantized_acc = False
need_qconfig = True  # run the QAT qconfig setup

# Build Config by field name. A positional call breaks as soon as the
# argument order or count differs from Config's 13 fields: it either raises
# TypeError or silently puts values into the wrong fields. The network, its
# optimizer, the loss and the quantization backend are not created in this
# cell, so they are left as None.
# NOTE(review): Config has no learning_rate / param_group fields; those
# remain plain variables.
config = Config(net=None,
                device=device,
                train_iter=train_iter,
                test_iter=test_iter,
                loss=None,
                trainer=None,
                num_epochs=num_epochs,
                logger=logger,
                need_qconfig=need_qconfig,
                is_freeze=is_freeze,
                is_quantized_acc=is_quantized_acc,
                backend=None,
                ylim=ylim)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/media/pc/data/4tb/lxw/work/torch-quantization/docs/tutorial/qat-fuse.ipynb Cell 13' in <cell line: 12>()
      <a href='vscode-notebook-cell://ssh-remote%2Bxin/media/pc/data/4tb/lxw/work/torch-quantization/docs/tutorial/qat-fuse.ipynb#ch0000026vscode-remote?line=8'>9</a> need_qconfig = True  # 做一些 QAT 的量化配置工作
     <a href='vscode-notebook-cell://ssh-remote%2Bxin/media/pc/data/4tb/lxw/work/torch-quantization/docs/tutorial/qat-fuse.ipynb#ch0000026vscode-remote?line=10'>11</a> # 提供位置参数
---> <a href='vscode-notebook-cell://ssh-remote%2Bxin/media/pc/data/4tb/lxw/work/torch-quantization/docs/tutorial/qat-fuse.ipynb#ch0000026vscode-remote?line=11'>12</a> config = Config(train_iter,
     <a href='vscode-notebook-cell://ssh-remote%2Bxin/media/pc/data/4tb/lxw/work/torch-quantization/docs/tutorial/qat-fuse.ipynb#ch0000026vscode-remote?line=12'>13</a>                 test_iter,
     <a href='vscode-notebook-cell://ssh-remote%2Bxin/media/pc/data/4tb/lxw/work/torch-quantization/docs/tutorial/qat-fuse.ipynb#ch0000026vscode-remote?line=13'>14</a>                 learning_rate,
     <a href='vscode-notebook-cell://ssh-remote%2Bxin/media/pc/data/4tb/lxw/work/torch-quantization/docs/tutorial/qat-fuse.ipynb#ch0000026vscode-remote?line=14'>15</a>                 num_epochs,
     <a href='vscode-notebook-cell://ssh-remote%2Bxin/media/pc/data/4tb/lxw/work/torch-quantization/docs/tutorial/qat-fuse.ipynb#ch0000026vscode-remote?line=15'>16</a>                 logger,
     <a href='vscode-notebook-cell://ssh-remote%2Bxin/media/pc/data/4tb/lxw/work/torch-quantization/docs/tutorial/qat-fuse.ipynb#ch0000026vscode-remote?line=16'>17</a>                 device,
     <a href='vscode-notebook-cell://ssh-remote%2Bxin/media/pc/data/4tb/lxw/work/torch-quantization/docs/tutorial/qat-fuse.ipynb#ch0000026vscode-remote?line=17'>18</a>                 is_freeze,
     <a href='vscode-notebook-cell://ssh-remote%2Bxin/media/pc/data/4tb/lxw/work/torch-quantization/docs/tutorial/qat-fuse.ipynb#ch0000026vscode-remote?line=18'>19</a>                 is_quantized_acc,
     <a href='vscode-notebook-cell://ssh-remote%2Bxin/media/pc/data/4tb/lxw/work/torch-quantization/docs/tutorial/qat-fuse.ipynb#ch0000026vscode-remote?line=19'>20</a>                 need_qconfig,
     <a href='vscode-notebook-cell://ssh-remote%2Bxin/media/pc/data/4tb/lxw/work/torch-quantization/docs/tutorial/qat-fuse.ipynb#ch0000026vscode-remote?line=20'>21</a>                 param_group,
     <a href='vscode-notebook-cell://ssh-remote%2Bxin/media/pc/data/4tb/lxw/work/torch-quantization/docs/tutorial/qat-fuse.ipynb#ch0000026vscode-remote?line=21'>22</a>                 ylim)

TypeError: Config.__new__() missing 2 required positional arguments: 'backend' and 'ylim'
The Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. View Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details.
"net",
                     "device",
                     "train_iter",
                     "test_iter",
                     "loss",
                     "trainer",
                     "num_epochs",
                     "logger",
                     "need_qconfig",
                     "is_freeze",
                     "is_quantized_acc",
                     "backend",
                     "ylim"
# Positional arguments that quantize_qat forwards to CV.train_fine_tuning
# after the model. Index 5 is is_freeze and index 6 is is_quantized_acc;
# the later strategy cells toggle those two entries.
args = [train_iter,
        test_iter,
        learning_rate,
        num_epochs,
        device,
        is_freeze,
        is_quantized_acc,
        need_qconfig,
        param_group,
        ylim]
# Baseline QAT run without module fusion. It uses the same `args` as every
# other run so the runs are comparable (the Config namedtuple has a
# different field layout).
qat_model = create_float_model(num_classes, float_model_path)
quantized_model = quantize_qat(qat_model, CV.train_fine_tuning, args)
# Same run, with the model's fuse_model() applied first.
qat_model = create_float_model(num_classes, float_model_path)
qat_model.fuse_model()  # add fusion
quantized_model = quantize_qat(qat_model, CV.train_fine_tuning, args)

冻结前几次训练的量化器以及观测器:

# Turn is_freeze on (args[5]) and is_quantized_acc off (args[6]).
args[5], args[6] = True, False
qat_model = create_float_model(num_classes, float_model_path)
quantized_model = quantize_qat(qat_model, CV.train_fine_tuning, args)

输出量化精度:

# Turn is_freeze off (args[5]) and is_quantized_acc on (args[6]).
args[5], args[6] = False, True
qat_model = create_float_model(num_classes, float_model_path)
quantized_model = quantize_qat(qat_model, CV.train_fine_tuning, args)

冻结前几次训练的量化器以及观测器，并输出量化精度:

# Turn both is_freeze (args[5]) and is_quantized_acc (args[6]) on.
args[5], args[6] = True, True
qat_model = create_float_model(num_classes, float_model_path)
quantized_model = quantize_qat(qat_model, CV.train_fine_tuning, args)
# Bare expression: only shows the class when run in a notebook.
torch.nn.quantized.FloatFunctional