PTQ 与 QAT 量化实践(PyTorch)
导航
PTQ 与 QAT 量化实践(PyTorch)#
目标:PyTorch 浮点模型快速转换为 PTQ(静态)和 QAT。
读者:会使用 PyTorch 实现卷积神经网络模型开发的算法工程师。
模型量化的动机
更少的存储开销和带宽需求。即使用更少的比特数存储数据,有效减少应用对存储资源的依赖,但现代系统往往拥有相对丰富的存储资源,这一点已经不算是采用量化的主要动机;
更快的计算速度。即对大多数处理器而言,整型运算的速度一般要比浮点运算更快一些(但不总是);
更低的能耗与占用面积:FP32 乘法运算的能耗是 INT8 乘法运算能耗的 18.5 倍,芯片占用面积则是 INT8 的 27.3 倍,而对于芯片设计和 FPGA 设计而言,更少的资源占用意味着相同数量的单元下可以设计出更多的计算单元;而更少的能耗意味着更少的发热,和更长久的续航。
尚可接受的精度损失。即量化相当于对模型权重引入噪声,所幸 CNN 本身对噪声不敏感(在模型训练过程中,模拟量化所引入的权重加噪还有利于防止过拟合),在合适的比特数下量化后的模型并不会带来很严重的精度损失。按照 GluonCV 提供的报告,经过 INT8 量化之后,ResNet50_v1 和 MobileNet1.0_v1 在 ILSVRC2012 数据集上的准确率仅分别从 77.36%、73.28% 下降为 76.86%、72.85%。
支持 INT8 是一个大的趋势。即无论是移动端还是服务器端,都可以看到新的计算设备正不断迎合量化技术。比如 NPU/APU/AIPU 等基本都是支持 INT8(甚至更低精度的 INT4)计算的,并且有相当可观的 TOPs,而 Mali GPU 开始引入 INT8 dot 支持,Nvidia 也不例外。除此之外,当前很多创业公司新发布的边缘端芯片几乎都支持 INT8 类型。
量化(包括 PTQ (静态)和 QAT)的整个流程如下:
模块融合阶段
将已有模型 float_model
改造为可量化模型 quantizable_model
,需要做如下工作:
算子替换:替换
float_model
的部分算子,比如torch.add()
替换为FloatFunctional
.add()
,torch.cat()
替换为FloatFunctional
.cat()
算子融合:使用
fuse_modules()
(用于静态 PTQ) 或者fuse_modules_qat()
(用于 QAT)融合如下模块序列:conv, bn
conv, bn, relu
conv, relu
linear, bn
linear, relu
在模型
float_model
的开头和结尾分别插入QuantStub
和DeQuantStub
将
torch.nn.ReLU6
(如果存在的话)替换为torch.nn.ReLU
QAT 配置和训练阶段
将可量化模型 quantizable_model
转换为 QAT 模型 qat_model
:
quantizable_model.qconfig
赋值为QConfig
类;使用
prepare_qat()
函数,将其转换为 QAT 模型;像普通的浮点模型一样训练 QAT 模型;
torch.jit.save()
函数保存训练好的 QAT 模型。
模块融合#
可量化模型与浮点模型的算子总会存在一些差异,为了提供更加通用的接口,需要做如下工作。
算子替换#
FloatFunctional
算子比普通的 torch.
的运算多了后处理操作,比如:
import torch
from torch import Tensor
class FloatFunctional(torch.nn.Module):
def __init__(self):
super().__init__()
self.activation_post_process = torch.nn.Identity()
def forward(self, x):
raise RuntimeError("FloatFunctional 不打算使用 `forward`。请使用下面的操作")
def add(self, x: Tensor, y: Tensor) -> Tensor:
"""等价于 ``torch.add(Tensor, Tensor)`` 运算"""
r = torch.add(x, y)
r = self.activation_post_process(r)
return r
...
由于 self.activation_post_process = torch.nn.Identity()
是自映射,所以 add()
等价于 torch.add()
。
小技巧
猜测 FloatFunctional
算子提供了自定义算子的官方接口。即只需要对 self.activation_post_process
赋值即可添加算子的后处理工作。
算子融合#
_fuse_modules()
提供了 fuse_modules()
(用于 PTQ 静态量化) 和 fuse_modules_qat()
(用于 QAT)的统一接口。
from torch.ao.quantization import fuse_modules_qat, fuse_modules
def _fuse_modules(
model: nn.Module,
modules_to_fuse: list[str] | list[list[str]],
is_qat: bool | None,
**kwargs: Any
):
if is_qat is None:
is_qat = model.training
method = fuse_modules_qat if is_qat else fuse_modules
return method(model, modules_to_fuse, **kwargs)
下面介绍几个例子。
示例#
为了说明模块融合的细节,举例如下。
改造 ResNet#
浮点模块
ResNet 的 BasicBlock
中需要被改造的部分:
24-26
、28-29
与32
需要被_fuse_modules()
函数融合;34-35
被替换为FloatFunctional
的add_relu()
函数。
1from torch import nn
2
3
4class BasicBlock(nn.Module):
5 expansion: int = 1
6
7 def __init__(
8 self,
9 inplanes: int,
10 planes: int,
11 stride: int = 1,
12 downsample: Optional[nn.Module] = None,
13 groups: int = 1,
14 base_width: int = 64,
15 dilation: int = 1,
16 norm_layer: Optional[Callable[..., nn.Module]] = None,
17 ) -> None:
18 super().__init__()
19 ... # 此处省略
20
21 def forward(self, x: Tensor) -> Tensor:
22 identity = x
23
24 out = self.conv1(x)
25 out = self.bn1(out)
26 out = self.relu(out)
27
28 out = self.conv2(out)
29 out = self.bn2(out)
30
31 if self.downsample is not None:
32 identity = self.downsample(x)
33
34 out += identity
35 out = self.relu(out)
36 return out
可量化模块
改造的部分:
7
和22
实现算子替换工作;12-14
、16-17
、20
、26-27
、29-30
实现算子融合工作。
1from torchvision.models.resnet import BasicBlock
2
3
4class QuantizableBasicBlock(BasicBlock):
5 def __init__(self, *args: Any, **kwargs: Any) -> None:
6 super().__init__(*args, **kwargs)
7 self.add_relu = torch.nn.quantized.FloatFunctional()
8
9 def forward(self, x: Tensor) -> Tensor:
10 identity = x
11
12 out = self.conv1(x)
13 out = self.bn1(out)
14 out = self.relu(out)
15
16 out = self.conv2(out)
17 out = self.bn2(out)
18
19 if self.downsample is not None:
20 identity = self.downsample(x)
21
22 out = self.add_relu.add_relu(out, identity)
23 return out
24
25 def fuse_model(self, is_qat: Optional[bool] = None) -> None:
26 _fuse_modules(self, [["conv1", "bn1", "relu"],
27 ["conv2", "bn2"]], is_qat, inplace=True)
28 if self.downsample:
29 _fuse_modules(self.downsample,
30 ["0", "1"], is_qat, inplace=True)