PTQ 小范例(eager)#

def print_version():
    import torch
    import torchvision

    # 查看核心包的版本
    print(f'torch: {torch.__version__} \n'
          f'torchvision: {torchvision.__version__}')

print_version()
torch: 2.1.0+cu121 
torchvision: 0.16.0+cu121

下面将逐步展开 PTQ 的知识点。

可量化模型(PTQ)#

定义浮点模块

import torch
from torch import nn, Tensor


class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, 3)
        self.relu = torch.nn.ReLU()
        self.conv2 = torch.nn.Conv2d(3, 16, 1)

    def _forward_impl(self, x: Tensor) -> Tensor:
        '''提供便捷函数'''
        x = self.conv(x)
        x = self.relu(x)
        x = self.conv2(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        x= self._forward_impl(x)
        return x

转换可量化模块

将浮点模块 M 转换为可量化模块 QM(量化流程的最关键的一步)。

from torch.ao.quantization import QuantStub, DeQuantStub


class QM(M):
    '''
    Args:
        is_print: 为了测试需求,打印一些信息
    '''
    def __init__(self, is_print: bool=False):
        super().__init__()
        self.is_print = is_print
        self.quant = QuantStub() # 将张量从浮点转换为量化
        self.dequant = DeQuantStub() # 将张量从量化转换为浮点

    def forward(self, x: Tensor) -> Tensor:
        if self.is_print:
            print('原始类型:', x.dtype)
        # 手动指定张量将在量化模型中从浮点模块转换为量化模块的位置
        x = self.quant(x)
        if self.is_print:
            print('量化前的类型:', x.dtype)
        x = self._forward_impl(x)
        if self.is_print:
            print('量化中的类型:',x.dtype)
        # 在量化模型中手动指定张量从量化到浮点的转换位置
        x = self.dequant(x)
        if self.is_print:
            print('量化后的类型:', x.dtype)
        return x

也可以简写为:

from torch.ao.quantization.stubs import QuantWrapper
QuantWrapper(M())
QuantWrapper(
  (quant): QuantStub()
  (dequant): DeQuantStub()
  (module): M(
    (conv): Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1))
    (relu): ReLU()
    (conv2): Conv2d(3, 16, kernel_size=(1, 1), stride=(1, 1))
  )
)

QuantStubDeQuantStub 在量化过程中起到观测者的作用。QuantStub 在训练阶段会将输入数据量化为较低的精度表示,而 DeQuantStub 层在推理阶段会将输入数据从较低的精度表示恢复为原始的高精度数据。这样,模型在推理时可以获得与训练阶段相同的性能。

input_fp32 = torch.randn(4, 1, 4, 4) # 输入的数据

m = QM(is_print=True)
x = m(input_fp32)
原始类型: torch.float32
量化前的类型: torch.float32
量化中的类型: torch.float32
量化后的类型: torch.float32

查看权重的数据类型:

m.conv.weight.dtype
torch.float32

可以看出,此时模块 m 是浮点模块。

要使 PTQ 生效,必须将模型设置为 eval 模式:

# 创建浮点模型实例
model_fp32 = QM(is_print=True)
model_fp32.eval()
QM(
  (conv): Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1))
  (relu): ReLU()
  (conv2): Conv2d(3, 16, kernel_size=(1, 1), stride=(1, 1))
  (quant): QuantStub()
  (dequant): DeQuantStub()
)

查看此时的数据类型:

input_fp32 = torch.randn(4, 1, 4, 4)

x = model_fp32(input_fp32)
print('激活和权重的数据类型分别为:'
      f'{x.dtype}, {model_fp32.conv.weight.dtype}')
原始类型: torch.float32
量化前的类型: torch.float32
量化中的类型: torch.float32
量化后的类型: torch.float32
激活和权重的数据类型分别为:torch.float32, torch.float32

定义观测器(PTQ)#

赋值实例变量 qconfig,其中包含关于要附加哪种观测器的信息:

  • 使用 'fbgemm' 用于带 AVX2 的 x86(没有AVX2,一些运算的实现效率很低);使用 'qnnpack' 用于 ARM CPU(通常出现在移动/嵌入式设备中)。

  • 其他量化配置,如选择对称或非对称量化和 MinMaxL2Norm 校准技术,可以在这里指定。

model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('qnnpack')

融合激活层(PTQ)#

在适用的地方,融合 activation 到前面的层(这需要根据模型架构手动完成)。常见的融合包括 conv + reluconv + batchnorm + relu

model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32,
                                                      [['conv', 'relu']])
                                                    
model_fp32_fused
QM(
  (conv): ConvReLU2d(
    (0): Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
  )
  (relu): Identity()
  (conv2): Conv2d(3, 16, kernel_size=(1, 1), stride=(1, 1))
  (quant): QuantStub()
  (dequant): DeQuantStub()
)

可以看到 model_fp32_fusedConvReLU2d 融合 model_fp32 的两个层 convrelu

启用观测器(PTQ)#

在融合后的模块中启用观测器,用于在校准期间观测激活(activation)张量。

model_fp32_prepared = torch.quantization.prepare(model_fp32_fused)
model_fp32_prepared
QM(
  (conv): ConvReLU2d(
    (0): Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (relu): Identity()
  (conv2): Conv2d(
    3, 16, kernel_size=(1, 1), stride=(1, 1)
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (quant): QuantStub(
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (dequant): DeQuantStub()
)

校准准备好的模型(PTQ)#

校准准备好的模型,以确定量化参数的激活在现实世界的设置,校准具有代表性的数据集。

input_fp32 = torch.randn(4, 1, 4, 4)

x = model_fp32_prepared(input_fp32)
model_fp32_prepared
原始类型: torch.float32
量化前的类型: torch.float32
量化中的类型: torch.float32
量化后的类型: torch.float32
QM(
  (conv): ConvReLU2d(
    (0): Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (activation_post_process): HistogramObserver(min_val=0.0, max_val=1.430925965309143)
  )
  (relu): Identity()
  (conv2): Conv2d(
    3, 16, kernel_size=(1, 1), stride=(1, 1)
    (activation_post_process): HistogramObserver(min_val=-0.8590439558029175, max_val=0.9416270852088928)
  )
  (quant): QuantStub(
    (activation_post_process): HistogramObserver(min_val=-1.7396681308746338, max_val=2.5327765941619873)
  )
  (dequant): DeQuantStub()
)

模型转换(PTQ)#

备注

量化权重,计算和存储每个激活张量要使用的尺度(scale)和偏差(bias)值,并用量化实现替换关键算子。

转换已校准好的模型为量化模型:

model_int8 = torch.quantization.convert(model_fp32_prepared)
model_int8
QM(
  (conv): QuantizedConvReLU2d(1, 3, kernel_size=(3, 3), stride=(1, 1), scale=0.005608734209090471, zero_point=0)
  (relu): Identity()
  (conv2): QuantizedConv2d(3, 16, kernel_size=(1, 1), stride=(1, 1), scale=0.007058007176965475, zero_point=122)
  (quant): Quantize(scale=tensor([0.0167]), zero_point=tensor([104]), dtype=torch.quint8)
  (dequant): DeQuantize()
)

查看权重的数据类型:

model_int8.conv.weight().dtype
torch.qint8

可以看出此时权重的元素大小为 1 字节,而不是 FP32 的 4 字节:

model_int8.conv.weight().element_size()
1

运行模型,相关的计算将在 torch.qint8 中发生。

res = model_int8(input_fp32)
res.dtype
原始类型: torch.float32
量化前的类型: torch.quint8
量化中的类型: torch.quint8
量化后的类型: torch.float32
torch.float32
# import torch.ao.nn.quantized as nnq