{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# General-Purpose Model Quantization\n",
"\n",
"```{rubric} Model comparison\n",
"```\n",
"\n",
"Type|Size (MB)|Accuracy ($\\%$)\n",
":-|:-|:-\n",
"Float|9.188|95.09\n",
"Float (fused)|8.924|95.09\n",
"QAT|2.657|94.86\n",
"\n",
"```{rubric} Static PTQ models with different QConfigs\n",
"```\n",
"\n",
"Accuracy ($\\%$)|Activation observer|Weight observer\n",
":-|:-|:-\n",
"40.51|{data}`~torch.ao.quantization.observer.MinMaxObserver`.`with_args(quant_min=0, quant_max=127)`|{data}`~torch.ao.quantization.observer.MinMaxObserver`.`with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)`\n",
"75.68|{data}`~torch.ao.quantization.observer.HistogramObserver`.`with_args(quant_min=0, quant_max=127)`|{data}`~torch.ao.quantization.observer.PerChannelMinMaxObserver`.`with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric)`\n",
"\n",
"Load the required libraries:"
]
},
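{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two PTQ rows differ only in their observers. As a hedged sketch, both QConfigs can be written out explicitly with `QConfig` and `with_args`. In recent PyTorch releases the first row corresponds to `default_qconfig`, and the second is close to `get_default_qconfig('fbgemm')`, which expresses the same 0 to 127 activation range as `HistogramObserver.with_args(reduce_range=True)`. The names `qconfig_minmax` and `qconfig_hist` are illustrative only. Each field of a `QConfig` is an observer factory: calling it creates a fresh observer whose `calculate_qparams()` returns the scale and zero point."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.ao.quantization import QConfig\n",
"from torch.ao.quantization.observer import (\n",
"    HistogramObserver, MinMaxObserver, PerChannelMinMaxObserver)\n",
"\n",
"# Row 1: per-tensor min/max observers for activations and weights\n",
"qconfig_minmax = QConfig(\n",
"    activation=MinMaxObserver.with_args(quant_min=0, quant_max=127),\n",
"    weight=MinMaxObserver.with_args(dtype=torch.qint8,\n",
"                                    qscheme=torch.per_tensor_symmetric))\n",
"\n",
"# Row 2: histogram-based activation observer, per-channel symmetric weights\n",
"qconfig_hist = QConfig(\n",
"    activation=HistogramObserver.with_args(quant_min=0, quant_max=127),\n",
"    weight=PerChannelMinMaxObserver.with_args(\n",
"        dtype=torch.qint8, qscheme=torch.per_channel_symmetric))\n",
"\n",
"# Instantiate an activation observer and feed it one sample tensor\n",
"act_obs = qconfig_minmax.activation()\n",
"act_obs(torch.tensor([0.0, 1.0, 2.0]))  # records min=0, max=2\n",
"scale, zero_point = act_obs.calculate_qparams()\n",
"# Affine quint8 with range [0, 127]: scale = (2 - 0) / 127, zero_point = 0\n",
"print(scale.item(), zero_point.item())"
]
},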
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from torch import nn, jit\n",
"from torch.ao.quantization.qconfig import default_qconfig\n",
"from torch.ao.quantization.qconfig import get_default_qat_qconfig, get_default_qconfig\n",
"from torch.ao.quantization.quantize import prepare, convert, prepare_qat\n",
"import torch\n",
"\n",
"# Suppress deprecation and quantization warnings\n",
"import warnings\n",
"warnings.filterwarnings(\n",
" action='ignore',\n",
" category=DeprecationWarning,\n",
" module='.*'\n",
")\n",
"warnings.filterwarnings(\n",
" action='ignore',\n",
" module='torch.ao.quantization'\n",
")\n",
"# Load the local custom modules\n",
"from mod import load_mod\n",
"load_mod()"
]
},
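{
"cell_type": "markdown",
"metadata": {},
"source": [
"The imported `prepare`, `convert` and `prepare_qat` implement eager-mode quantization. Before applying them to MobileNetV2, here is a minimal sketch of both flows on a hypothetical toy network `TinyNet` (not part of this project). Static PTQ fuses Conv+ReLU, inserts observers, calibrates on a few batches and converts to int8 modules; QAT inserts fake-quantization modules, trains briefly, then converts. Random tensors stand in for real data, and the backend choice (`fbgemm` if available, otherwise `qnnpack`) is an assumption about your machine; note that setting `torch.backends.quantized.engine` is process-wide."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from torch.ao.quantization import (DeQuantStub, QuantStub, convert,\n",
"                                   fuse_modules, get_default_qat_qconfig,\n",
"                                   get_default_qconfig, prepare, prepare_qat)\n",
"\n",
"\n",
"class TinyNet(nn.Module):\n",
"    '''Hypothetical toy model: quant -> conv -> relu -> dequant.'''\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.quant = QuantStub()      # float -> quantized boundary\n",
"        self.conv = nn.Conv2d(3, 8, 3)\n",
"        self.relu = nn.ReLU()\n",
"        self.dequant = DeQuantStub()  # quantized -> float boundary\n",
"\n",
"    def forward(self, x):\n",
"        return self.dequant(self.relu(self.conv(self.quant(x))))\n",
"\n",
"\n",
"# Pick a quantized backend that this PyTorch build supports\n",
"engines = torch.backends.quantized.supported_engines\n",
"engine = 'fbgemm' if 'fbgemm' in engines else 'qnnpack'\n",
"torch.backends.quantized.engine = engine\n",
"\n",
"# Static post-training quantization (PTQ)\n",
"ptq_model = TinyNet().eval()\n",
"ptq_model = fuse_modules(ptq_model, [['conv', 'relu']])  # -> ConvReLU2d\n",
"ptq_model.qconfig = get_default_qconfig(engine)\n",
"ptq_model = prepare(ptq_model)           # insert observers\n",
"with torch.no_grad():\n",
"    for _ in range(8):                   # calibration passes\n",
"        ptq_model(torch.randn(4, 3, 16, 16))\n",
"ptq_model = convert(ptq_model)           # swap in int8 modules\n",
"ptq_out = ptq_model(torch.randn(1, 3, 16, 16))\n",
"\n",
"# Quantization-aware training (QAT)\n",
"qat_model = TinyNet().train()            # prepare_qat expects train mode\n",
"qat_model.qconfig = get_default_qat_qconfig(engine)\n",
"qat_model = prepare_qat(qat_model)       # insert fake-quant modules\n",
"optimizer = torch.optim.SGD(qat_model.parameters(), lr=1e-3)\n",
"for _ in range(3):                       # a few dummy training steps\n",
"    optimizer.zero_grad()\n",
"    qat_model(torch.randn(4, 3, 16, 16)).mean().backward()\n",
"    optimizer.step()\n",
"qat_model = convert(qat_model.eval())\n",
"qat_out = qat_model(torch.randn(1, 3, 16, 16))\n",
"print(ptq_out.shape, qat_out.shape)"
]
},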
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Set the model save paths and hyperparameters:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"saved_model_dir = 'models/'\n",
"float_model_file = 'mobilenet_pretrained_float.pth'\n",
"scripted_float_model_file = 'mobilenet_float_scripted.pth'\n",
"scripted_ptq_model_file = 'mobilenet_ptq_scripted.pth'\n",
"scripted_quantized_model_file = 'mobilenet_quantization_scripted_quantized.pth'\n",
"scripted_qat_model_file = 'mobilenet_qat_scripted_quantized.pth'\n",
"# Hyperparameters\n",
"learning_rate = 5e-5\n",
"num_epochs = 30\n",
"batch_size = 16\n",
"num_classes = 10\n",
"# train_batch_size = 30\n",
"# eval_batch_size = 50\n",
"\n",
"# Loss function used for training and evaluation\n",
"criterion = nn.CrossEntropyLoss()"
]
},
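{
"cell_type": "markdown",
"metadata": {},
"source": [
"To show how `learning_rate`, `batch_size` and `criterion` fit together, here is a hedged sketch of a single fine-tuning step. The training loop actually used later in this notebook is not shown here; the Adam optimizer, the tiny linear stand-in model and the `toy_` names are assumptions for illustration only, chosen so the cell does not overwrite the notebook's own variables."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"\n",
"torch.manual_seed(0)\n",
"# Hypothetical stand-in network: 8 input features -> 10 classes\n",
"toy_model = nn.Linear(8, 10)\n",
"toy_criterion = nn.CrossEntropyLoss()\n",
"# Optimizer choice is an assumption; lr mirrors `learning_rate` above\n",
"toy_optimizer = torch.optim.Adam(toy_model.parameters(), lr=5e-5)\n",
"\n",
"\n",
"def train_step(X, y):\n",
"    '''One optimization step on a mini-batch; returns the loss value.'''\n",
"    toy_model.train()\n",
"    toy_optimizer.zero_grad()\n",
"    loss = toy_criterion(toy_model(X), y)\n",
"    loss.backward()\n",
"    toy_optimizer.step()\n",
"    return loss.item()\n",
"\n",
"\n",
"# One fake mini-batch of 16 samples (mirrors `batch_size` above)\n",
"X = torch.randn(16, 8)\n",
"y = torch.randint(0, 10, (16,))\n",
"losses = [train_step(X, y) for _ in range(20)]\n",
"print(f'loss: {losses[0]:.4f} -> {losses[-1]:.4f}')"
]
},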
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Helper functions\n",
"\n",
"Next, we set up a few [helper functions](https://github.com/pytorch/examples/blob/master/imagenet/main.py) for evaluating the model."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from helper import evaluate, print_size_of_model, load_model\n",
"\n",
"\n",
"def print_info(model, model_type='Float model'):\n",
"    '''Print the model size and its top-1 accuracy on the test set.'''\n",
"    print_size_of_model(model)\n",
"    top1, top5 = evaluate(model, criterion, test_iter)\n",
"    print(f'\\n{model_type}:\\n\\t'\n",
"          f'accuracy on {num_eval} test images: {top1.avg:2.5f}')"
]
},
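{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `helper` module itself is not shown here. As a rough sketch of what such helpers typically compute, modeled on the top-k `accuracy` function in the PyTorch ImageNet example and on the common practice of measuring a serialized `state_dict`, the hypothetical functions `topk_accuracy` and `model_size_mb` below return top-k accuracy in percent and model size in MB. Whether `helper` does exactly this (for example, 1 MB = 10**6 bytes) is an assumption."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import io\n",
"\n",
"import torch\n",
"from torch import nn\n",
"\n",
"\n",
"def topk_accuracy(output, target, topk=(1,)):\n",
"    '''Percentage of samples whose label is among the k highest logits.'''\n",
"    maxk = max(topk)\n",
"    _, pred = output.topk(maxk, dim=1)      # (N, maxk) class indices\n",
"    correct = pred.eq(target.view(-1, 1))   # (N, maxk) boolean matches\n",
"    return [correct[:, :k].any(dim=1).float().mean().item() * 100\n",
"            for k in topk]\n",
"\n",
"\n",
"def model_size_mb(model):\n",
"    '''Size of the serialized state_dict in MB (1 MB = 10**6 bytes).'''\n",
"    buffer = io.BytesIO()\n",
"    torch.save(model.state_dict(), buffer)\n",
"    return buffer.getbuffer().nbytes / 1e6\n",
"\n",
"\n",
"logits = torch.tensor([[0.1, 0.9, 0.0],\n",
"                       [0.8, 0.05, 0.15]])\n",
"labels = torch.tensor([1, 2])\n",
"print(topk_accuracy(logits, labels, topk=(1, 2)))  # [50.0, 100.0]\n",
"print(f'{model_size_mb(nn.Linear(100, 10)):.4f} MB')"
]
},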
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define the dataset and data loaders\n",
"\n",
"As the last major setup step, we define data loaders for the training and test sets."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from xinet import CV\n",
"\n",
"# Resize CIFAR-10 images to 224 so they match the ImageNet input resolution\n",
"train_iter, test_iter = CV.load_data_cifar10(batch_size=batch_size,\n",
" resize=224)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Check the number of batches in the training and test sets:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train / test batches: 3125 625\n"
]
}
],
"source": [
"print('Train / test batches:',\n",
"      len(train_iter), len(test_iter))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Get the sizes of the training and test sets:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(50000, 10000)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"num_train = sum(len(ys) for _, ys in train_iter)\n",
"num_eval = sum(len(ys) for _, ys in test_iter)\n",
"num_train, num_eval"
]
},
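{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell above iterates over every batch (including the resize to 224) just to count samples. If `train_iter` and `test_iter` are ordinary `torch.utils.data.DataLoader` objects (an assumption about `xinet`), `len(loader.dataset)` gives the same count without iterating, and `len(loader)` is the number of batches, $\\lceil N / B \\rceil$ for $N$ samples and batch size $B$ when `drop_last=False`. With $B = 16$ this gives 50000 / 16 = 3125 and 10000 / 16 = 625 for CIFAR-10, matching the batch counts printed earlier. A small self-contained check on a hypothetical 100-sample dataset:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"import torch\n",
"from torch.utils.data import DataLoader, TensorDataset\n",
"\n",
"# Hypothetical dataset: 100 samples, 3 features, 10 classes\n",
"dataset = TensorDataset(torch.randn(100, 3), torch.randint(0, 10, (100,)))\n",
"loader = DataLoader(dataset, batch_size=16)  # drop_last=False by default\n",
"\n",
"n_by_iteration = sum(len(ys) for _, ys in loader)  # what the cell above does\n",
"n_direct = len(loader.dataset)                     # no iteration needed\n",
"n_batches = len(loader)                            # ceil(100 / 16) = 7\n",
"print(n_by_iteration, n_direct, n_batches)         # 100 100 7\n",
"assert n_batches == math.ceil(n_direct / 16)"
]
},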
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Fine-tune the float model\n",
"\n",
"Configure the float model:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"from torchvision.models.quantization import mobilenet_v2\n",
"\n",
"# Build the model\n",
"def create_model(quantize=False,\n",
" num_classes=10,\n",
" pretrained=False):\n",
" float_model = mobilenet_v2(pretrained=pretrained,\n",
" quantize=quantize)\n",
" # Replace the final layer so the output size matches ``num_classes``\n",
" float_model.classifier[1] = nn.Linear(float_model.last_channel,\n",
" num_classes)\n",
" return float_model"
]
},
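{
"cell_type": "markdown",
"metadata": {},
"source": [
"`torchvision.models.quantization.mobilenet_v2` returns a quantization-ready MobileNetV2 that also provides a `fuse_model()` method for folding Conv+BN(+ReLU) layers. The 'Float (fused)' row in the comparison table plausibly comes from such a step, though that is an inference rather than something this notebook shows. The sketch below builds the model without pretrained weights (so nothing is downloaded), swaps the classifier as `create_model` does, fuses it in eval mode and checks that no `BatchNorm2d` layers remain. The names `demo` and `count_bn` are illustrative only."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from torchvision.models.quantization import mobilenet_v2\n",
"\n",
"# Randomly initialized, quantization-ready MobileNetV2 (no download)\n",
"demo = mobilenet_v2(quantize=False)\n",
"demo.classifier[1] = nn.Linear(demo.last_channel, 10)  # same head swap as create_model\n",
"demo.eval()  # eval mode, so fusion folds BN statistics into the conv weights\n",
"\n",
"\n",
"def count_bn(model):\n",
"    '''Number of BatchNorm2d layers left in the model.'''\n",
"    return sum(isinstance(m, nn.BatchNorm2d) for m in model.modules())\n",
"\n",
"\n",
"bn_before = count_bn(demo)\n",
"demo.fuse_model()  # in-place Conv+BN(+ReLU) fusion\n",
"bn_after = count_bn(demo)\n",
"\n",
"with torch.no_grad():\n",
"    logits = demo(torch.randn(2, 3, 224, 224))\n",
"print(bn_before, bn_after, logits.shape)"
]
},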
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Create the model:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"float_model = create_model(pretrained=True,\n",
" quantize=False,\n",
" num_classes=num_classes)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Fine-tune the float model:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss 0.012, train acc 0.996, test acc 0.951\n",
"338.8 examples/sec on cuda:0\n"
]
},
{
"data": {
"image/svg+xml": "\n\n\n",
"text/plain": [
"