{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# QAT（量化感知训练）的不同训练策略\n",
"\n",
"载入库:"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn, jit\n",
"from torch.ao.quantization import quantize_qat\n",
"from torchvision.models.quantization import mobilenet_v2\n",
"\n",
"# 载入自定义模块\n",
"from mod import torchq\n",
"from torchq.xinet import CV\n",
"from torchq.helper import evaluate, print_size_of_model, load_model\n",
"\n",
"\n",
"def create_model(num_classes=10,\n",
"                 quantize=False,\n",
"                 pretrained=False):\n",
"    '''Build torchvision's quantizable MobileNetV2 with a resized classifier head.\n",
"\n",
"    Args:\n",
"        num_classes: output width of the replacement final ``nn.Linear``.\n",
"        quantize: forwarded unchanged to ``mobilenet_v2``.\n",
"        pretrained: forwarded unchanged to ``mobilenet_v2``.\n",
"\n",
"    Returns:\n",
"        The model with ``classifier[1]`` replaced by a freshly initialised\n",
"        ``nn.Linear(last_channel, num_classes)``.\n",
"\n",
"    NOTE(review): with ``quantize=True`` torchvision returns an already\n",
"    converted (quantized) model, so a float ``nn.Linear`` head is probably\n",
"    incompatible with its quantized activations -- confirm before using it.\n",
"    '''\n",
"    backbone = mobilenet_v2(quantize=quantize, pretrained=pretrained)\n",
"    # Swap the default head so the number of logits matches ``num_classes``.\n",
"    head = nn.Linear(backbone.last_channel, num_classes)\n",
"    backbone.classifier[1] = head\n",
"    return backbone\n",
"\n",
"\n",
"def create_float_model(num_classes,\n",
"                       model_path):\n",
"    '''Create the float (``quantize=False``) model and pass it through ``load_model``.\n",
"\n",
"    ``model_path`` is forwarded to ``load_model`` together with the new model;\n",
"    presumably it points at saved float weights -- confirm in torchq.helper.\n",
"    '''\n",
"    fresh = create_model(num_classes=num_classes, quantize=False)\n",
"    return load_model(fresh, model_path)\n"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"|===========================================================================|\n",
"| PyTorch CUDA memory summary, device ID 0 |\n",
"|---------------------------------------------------------------------------|\n",
"| CUDA OOMs: 0 | cudaMalloc retries: 0 |\n",
"|===========================================================================|\n",
"| Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed |\n",
"|---------------------------------------------------------------------------|\n",
"| Allocated memory | 0 B | 0 B | 0 B | 0 B |\n",
"| from large pool | 0 B | 0 B | 0 B | 0 B |\n",
"| from small pool | 0 B | 0 B | 0 B | 0 B |\n",
"|---------------------------------------------------------------------------|\n",
"| Active memory | 0 B | 0 B | 0 B | 0 B |\n",
"| from large pool | 0 B | 0 B | 0 B | 0 B |\n",
"| from small pool | 0 B | 0 B | 0 B | 0 B |\n",
"|---------------------------------------------------------------------------|\n",
"| GPU reserved memory | 0 B | 0 B | 0 B | 0 B |\n",
"| from large pool | 0 B | 0 B | 0 B | 0 B |\n",
"| from small pool | 0 B | 0 B | 0 B | 0 B |\n",
"|---------------------------------------------------------------------------|\n",
"| Non-releasable memory | 0 B | 0 B | 0 B | 0 B |\n",
"| from large pool | 0 B | 0 B | 0 B | 0 B |\n",
"| from small pool | 0 B | 0 B | 0 B | 0 B |\n",
"|---------------------------------------------------------------------------|\n",
"| Allocations | 0 | 0 | 0 | 0 |\n",
"| from large pool | 0 | 0 | 0 | 0 |\n",
"| from small pool | 0 | 0 | 0 | 0 |\n",
"|---------------------------------------------------------------------------|\n",
"| Active allocs | 0 | 0 | 0 | 0 |\n",
"| from large pool | 0 | 0 | 0 | 0 |\n",
"| from small pool | 0 | 0 | 0 | 0 |\n",
"|---------------------------------------------------------------------------|\n",
"| GPU reserved segments | 0 | 0 | 0 | 0 |\n",
"| from large pool | 0 | 0 | 0 | 0 |\n",
"| from small pool | 0 | 0 | 0 | 0 |\n",
"|---------------------------------------------------------------------------|\n",
"| Non-releasable allocs | 0 | 0 | 0 | 0 |\n",
"| from large pool | 0 | 0 | 0 | 0 |\n",
"| from small pool | 0 | 0 | 0 | 0 |\n",
"|---------------------------------------------------------------------------|\n",
"| Oversize allocations | 0 | 0 | 0 | 0 |\n",
"|---------------------------------------------------------------------------|\n",
"| Oversize GPU segments | 0 | 0 | 0 | 0 |\n",
"|===========================================================================|\n",
"\n"
]
}
],
"source": [
"# Release unoccupied cached blocks held by the CUDA caching allocator, then\n",
"# print the allocator statistics for the current CUDA device.\n",
"# NOTE(review): the saved output reports device ID 0, while training below\n",
"# ran on cuda:1; pass ``device`` explicitly if that GPU is the one of interest.\n",
"torch.cuda.empty_cache()\n",
"print(torch.cuda.memory_summary(device=None))"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [],
"source": [
"# Warning filters: hide DeprecationWarning from every module, and hide all\n",
"# warnings attributed to modules whose name matches (as a regex prefix)\n",
"# 'torch.ao.quantization'. Both filters are inserted at the front of the list.\n",
"import warnings\n",
"warnings.filterwarnings('ignore', category=DeprecationWarning, module='.*')\n",
"warnings.filterwarnings('ignore', module='torch.ao.quantization')\n",
"\n",
"\n",
"def print_info(model, model_type, num_eval, criterion, data_iter=None):\n",
"    '''Print a model's size and its top-1 accuracy on an evaluation set.\n",
"\n",
"    Args:\n",
"        model: model to report on.\n",
"        model_type: label shown in the printed header.\n",
"        num_eval: number of evaluated images; only used in the message.\n",
"        criterion: loss function forwarded to ``evaluate``.\n",
"        data_iter: evaluation batches forwarded to ``evaluate``. Defaults to\n",
"            the notebook-global ``test_iter`` (defined in a later cell) for\n",
"            backward compatibility; pass it explicitly to avoid that hidden\n",
"            dependency on kernel state.\n",
"    '''\n",
"    print_size_of_model(model)\n",
"    if data_iter is None:\n",
"        data_iter = test_iter\n",
"    top1, _top5 = evaluate(model, criterion, data_iter)  # top-5 is not reported\n",
"    print(f'\\n{model_type}:\\n\\t'\n",
"          f'在 {num_eval} 张图片上评估 accuracy 为: {top1.avg:2.5f}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"超参数设置:"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
"# Paths\n",
"saved_model_dir = 'models/'\n",
"float_model_file = 'mobilenet_pretrained_float.pth'\n",
"float_model_path = f'{saved_model_dir}{float_model_file}'\n",
"# (currently unused) scripted_qat_model_file = 'mobilenet_qat_scripted_quantized.pth'\n",
"\n",
"# Hyperparameters\n",
"batch_size = 16\n",
"num_classes = 10  # CIFAR-10 classes\n",
"num_epochs = 50\n",
"learning_rate = 5e-5\n",
"ylim = [0.8, 1]  # presumably the y-axis range for the training plots"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"加载数据集:"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Files already downloaded and verified\n",
"Files already downloaded and verified\n"
]
}
],
"source": [
"# Resize CIFAR-10 images to 224 so they match the ImageNet input resolution.\n",
"train_iter, test_iter = CV.load_data_cifar10(resize=224,\n",
"                                             batch_size=batch_size)\n",
"# NOTE(review): counting this way makes a full pass over the test loader\n",
"# (loading and resizing every image); if test_iter is a DataLoader without\n",
"# drop_last, len(test_iter.dataset) would give the same count far faster.\n",
"num_eval = sum(len(labels) for _images, labels in test_iter)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"打印浮点模型信息:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"模型大小:9.187789 MB\n",
"\n",
"浮点模型:\n",
"\t在 10000 张图片上评估 accuracy 为: 94.91000\n"
]
}
],
"source": [
"# Per-sample losses; presumably ``evaluate`` aggregates them itself.\n",
"criterion = nn.CrossEntropyLoss(reduction='none')\n",
"model_type = '浮点模型'\n",
"float_model = create_float_model(num_classes, float_model_path)\n",
"print_info(float_model, model_type, num_eval=num_eval, criterion=criterion)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"普通策略:"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss 0.013, train acc 0.996, test acc 0.950\n",
"352.4 examples/sec on cuda:1\n"
]
},
{
"data": {
"image/svg+xml": "\n\n\n",
"text/plain": [
"