QAT 小范例(eager)#

QAT 在 训练过程中 模拟量化的效果,可以获得更高的 accuracy。在训练过程中,所有的计算都是在浮点上进行的,使用 fake_quant 模块通过夹紧和舍入的方式对量化效果进行建模,模拟 INT8 的效果。模型转换后,权值和激活被量化,激活在可能的情况下被融合到前一层。它通常与 CNN 一起使用,与 PTQ 相比具有更高的 accuracy。

import torch
from torch import nn, Tensor


class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.bn = torch.nn.BatchNorm2d(1)
        self.relu = torch.nn.ReLU()

    def _forward_impl(self, x: Tensor) -> Tensor:
        '''提供便捷函数'''
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        x= self._forward_impl(x)
        return x

模型必须设置为训练模式,以便 QAT 可用:

model_fp32 = M()
model_fp32.train()
# 添加量化配置(与 PTQ 相同相似)
model_fp32.qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')

融合 QAT 模块(eager)#

QAT 的模块融合与 PTQ 相同相似:

from torch.ao.quantization import fuse_modules_qat

model_fp32_fused = fuse_modules_qat(model_fp32,
                                    [['conv', 'bn', 'relu']])

准备 QAT 模型(eager)#

这将在模型中插入观测者和伪量化模块,它们将在校准期间观测权重和激活的张量。

model_fp32_prepared = torch.quantization.prepare_qat(model_fp32_fused)
/media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/torch/ao/quantization/observer.py:214: UserWarning: Please use quant_min and quant_max to specify the range for observers.                     reduce_range will be deprecated in a future release of PyTorch.
  warnings.warn(

训练 QAT 模型(eager)#

# 下文会编写实际的例子,此处没有显示
training_loop(model_fp32_prepared)

将观测到的模型转换为量化模型。需要:

  • 量化权重,计算和存储用于每个激活张量的尺度(scale)和偏差(bias)值,

  • 在适当的地方融合模块,并用量化实现替换关键算子。

model_fp32_prepared.eval()
model_int8 = torch.quantization.convert(model_fp32_prepared)
/media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/torch/ao/quantization/utils.py:317: UserWarning: must run observer before calling calculate_qparams. Returning default values.
  warnings.warn(

运行模型,相关的计算将在 torch.qint8 中发生。

res = model_int8(input_fp32)