PTQ 与 QAT 量化实践(PyTorch)#

  • 目标:PyTorch 浮点模型快速转换为 PTQ(静态)和 QAT。

  • 读者:会使用 PyTorch 实现卷积神经网络模型开发的算法工程师。

模型量化的动机

  • 更少的存储开销和带宽需求。即使用更少的比特数存储数据,有效减少应用对存储资源的依赖,但现代系统往往拥有相对丰富的存储资源,这一点已经不算是采用量化的主要动机;

  • 更快的计算速度。即对大多数处理器而言,整型运算的速度一般要比浮点运算更快一些(但不总是);

  • 更低的能耗与占用面积:FP32 乘法运算的能耗是 INT8 乘法运算能耗的 18.5 倍,芯片占用面积则是 INT8 的 27.3 倍,而对于芯片设计和 FPGA 设计而言,更少的资源占用意味着相同数量的单元下可以设计出更多的计算单元;而更少的能耗意味着更少的发热,和更长久的续航。

  • 尚可接受的精度损失。即量化相当于对模型权重引入噪声,所幸 CNN 本身对噪声不敏感(在模型训练过程中,模拟量化所引入的权重加噪还有利于防止过拟合),在合适的比特数下量化后的模型并不会带来很严重的精度损失。按照 GluonCV 提供的报告,经过 INT8 量化之后,ResNet50_v1 和 MobileNet1.0_v1 在 ILSVRC2012 数据集上的准确率仅分别从 77.36%、73.28% 下降为 76.86%、72.85%。

  • 支持 INT8 是一个大的趋势。即无论是移动端还是服务器端,都可以看到新的计算设备正不断迎合量化技术。比如 NPU/APU/AIPU 等基本都是支持 INT8(甚至更低精度的 INT4)计算的,并且有相当可观的 TOPs,而 Mali GPU 开始引入 INT8 dot 支持,Nvidia 也不例外。除此之外,当前很多创业公司新发布的边缘端芯片几乎都支持 INT8 类型。

量化(包括 PTQ (静态)和 QAT)的整个流程如下:

模块融合阶段

将已有模型 float_model 改造为可量化模型 quantizable_model,需要做如下工作:

  1. 算子替换:替换 float_model 的部分算子,比如 torch.add() 替换为 FloatFunctional.add()torch.cat() 替换为 FloatFunctional.cat()

  2. 算子融合:使用 fuse_modules()(用于静态 PTQ) 或者 fuse_modules_qat() (用于 QAT)融合如下模块序列:

    • conv, bn

    • conv, bn, relu

    • conv, relu

    • linear, bn

    • linear, relu

  3. 在模型 float_model 的开头和结尾分别插入 QuantStubDeQuantStub

  4. torch.nn.ReLU6 (如果存在的话)替换为 torch.nn.ReLU

QAT 配置和训练阶段

将可量化模型 quantizable_model 转换为 QAT 模型 qat_model

  1. quantizable_model.qconfig 赋值为 QConfig 类;

  2. 使用 prepare_qat() 函数,将其转换为 QAT 模型;

  3. 像普通的浮点模型一样训练 QAT 模型;

  4. torch.jit.save() 函数保存训练好的 QAT 模型。

模块融合#

可量化模型与浮点模型的算子总会存在一些差异,为了提供更加通用的接口,需要做如下工作。

算子替换#

FloatFunctional 算子比普通的 torch. 的运算多了后处理操作,比如:

import torch
from torch import Tensor

class FloatFunctional(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.activation_post_process = torch.nn.Identity()

    def forward(self, x):
        raise RuntimeError("FloatFunctional 不打算使用 `forward`。请使用下面的操作")

    def add(self, x: Tensor, y: Tensor) -> Tensor:
        """等价于 ``torch.add(Tensor, Tensor)`` 运算"""
        r = torch.add(x, y)
        r = self.activation_post_process(r)
        return r
    ...

由于 self.activation_post_process = torch.nn.Identity() 是自映射,所以 add() 等价于 torch.add()

小技巧

猜测 FloatFunctional 算子提供了自定义算子的官方接口。即只需要对 self.activation_post_process 赋值即可添加算子的后处理工作。

算子融合#

_fuse_modules() 提供了 fuse_modules()(用于 PTQ 静态量化) 和 fuse_modules_qat() (用于 QAT)的统一接口。

from torch.ao.quantization import fuse_modules_qat, fuse_modules


def _fuse_modules(
    model: nn.Module, 
    modules_to_fuse: list[str] | list[list[str]], 
    is_qat: bool | None, 
    **kwargs: Any
):
    if is_qat is None:
        is_qat = model.training
    method = fuse_modules_qat if is_qat else  fuse_modules
    return method(model, modules_to_fuse, **kwargs)

下面介绍几个例子。

示例#

为了说明模块融合的细节,举例如下。

改造 ResNet#

浮点模块

ResNet 的 BasicBlock 中需要被改造的部分:

  1. 24-2628-2932 需要被 _fuse_modules() 函数融合;

  2. 34-35 被替换为 FloatFunctionaladd_relu() 函数。

 1from torch import nn
 2
 3
 4class BasicBlock(nn.Module):
 5    expansion: int = 1
 6
 7    def __init__(
 8        self,
 9        inplanes: int,
10        planes: int,
11        stride: int = 1,
12        downsample: Optional[nn.Module] = None,
13        groups: int = 1,
14        base_width: int = 64,
15        dilation: int = 1,
16        norm_layer: Optional[Callable[..., nn.Module]] = None,
17        ) -> None:
18        super().__init__()
19        ... # 此处省略
20
21    def forward(self, x: Tensor) -> Tensor:
22        identity = x
23
24        out = self.conv1(x)
25        out = self.bn1(out)
26        out = self.relu(out)
27
28        out = self.conv2(out)
29        out = self.bn2(out)
30
31        if self.downsample is not None:
32            identity = self.downsample(x)
33
34        out += identity
35        out = self.relu(out)
36        return out

可量化模块

改造的部分:

  1. 722 实现算子替换工作;

  2. 12-1416-172026-2729-30 实现算子融合工作。

 1from torchvision.models.resnet import BasicBlock
 2
 3
 4class QuantizableBasicBlock(BasicBlock):
 5    def __init__(self, *args: Any, **kwargs: Any) -> None:
 6        super().__init__(*args, **kwargs)
 7        self.add_relu = torch.nn.quantized.FloatFunctional()
 8
 9    def forward(self, x: Tensor) -> Tensor:
10        identity = x
11
12        out = self.conv1(x)
13        out = self.bn1(out)
14        out = self.relu(out)
15
16        out = self.conv2(out)
17        out = self.bn2(out)
18
19        if self.downsample is not None:
20            identity = self.downsample(x)
21
22        out = self.add_relu.add_relu(out, identity)
23        return out
24
25    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
26        _fuse_modules(self, [["conv1", "bn1", "relu"],
27                            ["conv2", "bn2"]], is_qat, inplace=True)
28        if self.downsample:
29            _fuse_modules(self.downsample,
30                        ["0", "1"], is_qat, inplace=True)

量化配置#

量化训练#

实战 MobileNetv2#