QAT 小范例(eager)#
QAT 在 训练过程中 模拟量化的效果,可以获得更高的 accuracy。在训练过程中,所有的计算都是在浮点上进行的,使用 fake_quant 模块通过夹紧和舍入的方式对量化效果进行建模,模拟 INT8 的效果。模型转换后,权值和激活被量化,激活在可能的情况下被融合到前一层。它通常与 CNN 一起使用,与 PTQ 相比具有更高的 accuracy。
import torch
from torch import nn, Tensor
class M(nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(1, 1, 1)
self.bn = torch.nn.BatchNorm2d(1)
self.relu = torch.nn.ReLU()
def _forward_impl(self, x: Tensor) -> Tensor:
'''提供便捷函数'''
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
def forward(self, x: Tensor) -> Tensor:
x= self._forward_impl(x)
return x
模型必须设置为训练模式,以便 QAT 可用:
model_fp32 = M()
model_fp32.train()
# 添加量化配置(与 PTQ 相同相似)
model_fp32.qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
融合 QAT 模块(eager)#
QAT 的模块融合与 PTQ 相同相似:
from torch.ao.quantization import fuse_modules_qat
model_fp32_fused = fuse_modules_qat(model_fp32,
[['conv', 'bn', 'relu']])
准备 QAT 模型(eager)#
这将在模型中插入观测者和伪量化模块,它们将在校准期间观测权重和激活的张量。
model_fp32_prepared = torch.quantization.prepare_qat(model_fp32_fused)
/media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/torch/ao/quantization/observer.py:214: UserWarning: Please use quant_min and quant_max to specify the range for observers. reduce_range will be deprecated in a future release of PyTorch.
warnings.warn(
训练 QAT 模型(eager)#
# 下文会编写实际的例子,此处没有显示
training_loop(model_fp32_prepared)
将观测到的模型转换为量化模型。需要:
量化权重,计算和存储用于每个激活张量的尺度(scale)和偏差(bias)值,
在适当的地方融合模块,并用量化实现替换关键算子。
model_fp32_prepared.eval()
model_int8 = torch.quantization.convert(model_fp32_prepared)
/media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/torch/ao/quantization/utils.py:317: UserWarning: must run observer before calling calculate_qparams. Returning default values.
warnings.warn(
运行模型,相关的计算将在 torch.qint8
中发生。
res = model_int8(input_fp32)