Quick Start#

References:

  1. Practical Quantization

  2. FX Graph Mode Post Training Static Quantization

This tutorial walks through the steps of post training static quantization in graph mode based on torch.fx. The advantage of FX Graph Mode Quantization is that quantization is performed fully automatically on the model, although some effort may be needed to make the model compatible with FX Graph Mode Quantization (i.e. symbolically traceable with torch.fx); a separate tutorial will show how to make the part of the model we want to quantize compatible with FX Graph Mode Quantization. There is also an FX Graph Mode Post Training Dynamic Quantization tutorial. The FX Graph Mode API looks like this:

import torch
from torch.quantization import get_default_qconfig
# Note that this is temporary,
# we'll expose these functions to torch.quantization after the official release
from torch.quantization.quantize_fx import prepare_fx, convert_fx
float_model.eval()
qconfig = get_default_qconfig("fbgemm")
qconfig_dict = {"": qconfig}
def calibrate(model, data_loader):
    model.eval()
    with torch.no_grad():
        for image, target in data_loader:
            model(image)
prepared_model = prepare_fx(float_model, qconfig_dict)  # fuse modules and insert observers
calibrate(prepared_model, valset)  # run calibration on sample data
quantized_model = convert_fx(prepared_model)  # convert the calibrated model to a quantized model

Motivation for FX Graph Mode Quantization#

PyTorch currently has eager mode quantization: Static Quantization with Eager Mode in PyTorch.

As we can see, that process involves multiple manual steps, including:

  • Explicitly quantizing and dequantizing activations, which is time consuming when floating point and quantized operations are mixed in a model.

  • Explicitly fusing modules, which requires manually identifying the sequences of convolutions, batch norms, relus, and other fusion patterns.

  • Special handling is needed for PyTorch tensor operations (like add, concat, etc.).

  • Functionals do not have first class support (functional.conv2d and functional.linear are not quantized).

Most of these required changes come from the underlying limitations of eager mode quantization. Eager mode works at the module level because it cannot inspect the code that is actually run (in the forward function); quantization is achieved by module swapping, and in eager mode we do not know how the modules are used inside the forward function. It therefore requires users to insert QuantStub and DeQuantStub manually to mark the points where they want to quantize or dequantize. In graph mode, we can inspect the actual code executed in the forward function (e.g. aten function calls), and quantization is achieved by module and graph manipulations. Since graph mode has full visibility into the code that is run, it can automatically figure out which modules to fuse and where to insert observer calls, quantize/dequantize functions, etc., so the whole quantization process can be automated.
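
For contrast, here is a minimal sketch of what eager mode asks of the user: the toy module below (hypothetical, not part of this tutorial's model) has to carry explicit QuantStub/DeQuantStub markers in its forward:

import torch
from torch import nn
from torch.quantization import QuantStub, DeQuantStub

class ToyModel(nn.Module):
    """Hypothetical module showing the manual markers eager mode requires."""
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()      # marks where activations get quantized
        self.conv = nn.Conv2d(3, 8, 3)
        self.relu = nn.ReLU()
        self.dequant = DeQuantStub()  # marks where activations get dequantized
    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.conv(x))
        return self.dequant(x)

In FX graph mode, none of these markers are needed; prepare_fx and convert_fx insert the equivalent quantize/dequantize operations automatically.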

The advantages of FX Graph Mode Quantization are:

  • A simple quantization flow with minimal manual steps

  • It unlocks the possibility of higher level optimizations, such as automatic precision selection

Define Helper Functions and Prepare the Dataset#

We start with the necessary imports, define some helper functions, and prepare the data. These steps are identical to Static Quantization with Eager Mode in PyTorch.

To run the code in this tutorial on the full ImageNet dataset, first download ImageNet by following the instructions in ImageNet Data, then unzip the downloaded file into the data_path folder.

Download the torchvision resnet18 model and rename it to models/resnet18_pretrained_float.pth.

from torch_book.data import ImageNet


root = "/media/pc/data/4tb/lxw/datasets/ILSVRC"
saved_model_dir = 'models/'

dataset = ImageNet(root)
trainset = dataset.loader(batch_size=30, split="train")
valset = dataset.loader(batch_size=50, split="val")
import copy
from torchvision import models

model_name = "resnet18"
float_model = getattr(models, model_name)(pretrained=True)
float_model.eval()

# deepcopy the model since we need to keep the original model around
model_to_quantize = copy.deepcopy(float_model)
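
torch_book.data.ImageNet is a project-specific convenience wrapper. If it is not available, an equivalent pair of loaders can be built directly with torchvision; the sketch below assumes ImageNet has already been downloaded and unpacked under root as described above:

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def imagenet_loader(root, split, batch_size):
    # standard ImageNet evaluation preprocessing
    tfm = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    ds = datasets.ImageNet(root, split=split, transform=tfm)
    return DataLoader(ds, batch_size=batch_size, shuffle=(split == "train"))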

Set the Model to Eval Mode#

For post training quantization, the model needs to be set to eval mode.

model_to_quantize.eval();

Specify How to Quantize the Model with qconfig_dict#

qconfig_dict = {"" : default_qconfig}

We use the same qconfig as in eager mode quantization; a qconfig is just a named tuple of the observers used for activations and weights. qconfig_dict is a dictionary with the following configuration:

qconfig = {
    " : qconfig_global,
    "sub" : qconfig_sub,
    "sub.fc" : qconfig_fc,
    "sub.conv": None
}
qconfig_dict = {
    # qconfig? means either a valid qconfig or None
    # optional, global config
    "": qconfig?,
    # optional, used for module and function types
    # could also be split into module_types and function_types if we prefer
    "object_type": [
        (torch.nn.Conv2d, qconfig?),
        (torch.nn.functional.add, qconfig?),
        ...,
    ],
    # optional, used for module names
    "module_name": [
        ("foo.bar", qconfig?)
        ...,
    ],
    # optional, matched in order, first match takes precedence
    "module_name_regex": [
        ("foo.*bar.*conv[0-9]+", qconfig?)
        ...,
    ],
    # priority (in increasing order): global, object_type, module_name_regex, module_name
    # qconfig == None means fusion and quantization should be skipped for anything
    # matching the rule

    # **api subject to change**
    # optional: specify the path for standalone modules
    # These modules are symbolically traced and quantized as one unit
    # so that the call to the submodule appears as one call_module
    # node in the forward graph of the GraphModule
    "standalone_module_name": [
        "submodule.standalone"
    ],
    "standalone_module_class": [
        StandaloneModuleClass
    ]
}
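
Since a qconfig is just a named tuple of observer constructors, one can also be assembled by hand. The sketch below uses observers that ship with torch.quantization; the particular observer choices are illustrative, not a recommendation:

import torch
from torch.quantization import QConfig, MinMaxObserver, PerChannelMinMaxObserver

custom_qconfig = QConfig(
    # observer for activations: per-tensor, unsigned 8-bit
    activation=MinMaxObserver.with_args(dtype=torch.quint8),
    # observer for weights: per-channel, signed 8-bit, symmetric
    weight=PerChannelMinMaxObserver.with_args(
        dtype=torch.qint8, qscheme=torch.per_channel_symmetric),
)
custom_qconfig_dict = {"": custom_qconfig}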

Utility functions related to qconfig can be found in the qconfig file:

from torch.quantization import get_default_qconfig

qconfig = get_default_qconfig("fbgemm")
qconfig_dict = {"": qconfig}

Prepare the Model for Post Training Static Quantization#

import warnings
from torch.quantization.quantize_fx import prepare_fx

warnings.filterwarnings('ignore')

prepared_model = prepare_fx(model_to_quantize, qconfig_dict)

prepare_fx folds the BatchNorm modules into the preceding Conv2d modules and inserts observers at the appropriate places in the model.
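
A quick way to see how much instrumentation was added is to count the inserted observer modules (a sketch; the observers show up as activation_post_process_* submodules, as the graph below confirms):

n_observers = sum(1 for name, _ in prepared_model.named_modules()
                  if "activation_post_process" in name)
print(f"number of inserted observers: {n_observers}")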

print(prepared_model.graph)
graph():
    %x : torch.Tensor [#users=1] = placeholder[target=x]
    %activation_post_process_0 : [#users=1] = call_module[target=activation_post_process_0](args = (%x,), kwargs = {})
    %conv1 : [#users=1] = call_module[target=conv1](args = (%activation_post_process_0,), kwargs = {})
    %activation_post_process_1 : [#users=1] = call_module[target=activation_post_process_1](args = (%conv1,), kwargs = {})
    %maxpool : [#users=1] = call_module[target=maxpool](args = (%activation_post_process_1,), kwargs = {})
    %activation_post_process_2 : [#users=2] = call_module[target=activation_post_process_2](args = (%maxpool,), kwargs = {})
    %layer1_0_conv1 : [#users=1] = call_module[target=layer1.0.conv1](args = (%activation_post_process_2,), kwargs = {})
    %activation_post_process_3 : [#users=1] = call_module[target=activation_post_process_3](args = (%layer1_0_conv1,), kwargs = {})
    %layer1_0_conv2 : [#users=1] = call_module[target=layer1.0.conv2](args = (%activation_post_process_3,), kwargs = {})
    %activation_post_process_4 : [#users=1] = call_module[target=activation_post_process_4](args = (%layer1_0_conv2,), kwargs = {})
    %add : [#users=1] = call_function[target=operator.add](args = (%activation_post_process_4, %activation_post_process_2), kwargs = {})
    %layer1_0_relu_1 : [#users=1] = call_module[target=layer1.0.relu](args = (%add,), kwargs = {})
    %activation_post_process_5 : [#users=2] = call_module[target=activation_post_process_5](args = (%layer1_0_relu_1,), kwargs = {})
    %layer1_1_conv1 : [#users=1] = call_module[target=layer1.1.conv1](args = (%activation_post_process_5,), kwargs = {})
    %activation_post_process_6 : [#users=1] = call_module[target=activation_post_process_6](args = (%layer1_1_conv1,), kwargs = {})
    %layer1_1_conv2 : [#users=1] = call_module[target=layer1.1.conv2](args = (%activation_post_process_6,), kwargs = {})
    %activation_post_process_7 : [#users=1] = call_module[target=activation_post_process_7](args = (%layer1_1_conv2,), kwargs = {})
    %add_1 : [#users=1] = call_function[target=operator.add](args = (%activation_post_process_7, %activation_post_process_5), kwargs = {})
    %layer1_1_relu_1 : [#users=1] = call_module[target=layer1.1.relu](args = (%add_1,), kwargs = {})
    %activation_post_process_8 : [#users=2] = call_module[target=activation_post_process_8](args = (%layer1_1_relu_1,), kwargs = {})
    %layer2_0_conv1 : [#users=1] = call_module[target=layer2.0.conv1](args = (%activation_post_process_8,), kwargs = {})
    %activation_post_process_9 : [#users=1] = call_module[target=activation_post_process_9](args = (%layer2_0_conv1,), kwargs = {})
    %layer2_0_conv2 : [#users=1] = call_module[target=layer2.0.conv2](args = (%activation_post_process_9,), kwargs = {})
    %activation_post_process_10 : [#users=1] = call_module[target=activation_post_process_10](args = (%layer2_0_conv2,), kwargs = {})
    %layer2_0_downsample_0 : [#users=1] = call_module[target=layer2.0.downsample.0](args = (%activation_post_process_8,), kwargs = {})
    %activation_post_process_11 : [#users=1] = call_module[target=activation_post_process_11](args = (%layer2_0_downsample_0,), kwargs = {})
    %add_2 : [#users=1] = call_function[target=operator.add](args = (%activation_post_process_10, %activation_post_process_11), kwargs = {})
    %layer2_0_relu_1 : [#users=1] = call_module[target=layer2.0.relu](args = (%add_2,), kwargs = {})
    %activation_post_process_12 : [#users=2] = call_module[target=activation_post_process_12](args = (%layer2_0_relu_1,), kwargs = {})
    %layer2_1_conv1 : [#users=1] = call_module[target=layer2.1.conv1](args = (%activation_post_process_12,), kwargs = {})
    %activation_post_process_13 : [#users=1] = call_module[target=activation_post_process_13](args = (%layer2_1_conv1,), kwargs = {})
    %layer2_1_conv2 : [#users=1] = call_module[target=layer2.1.conv2](args = (%activation_post_process_13,), kwargs = {})
    %activation_post_process_14 : [#users=1] = call_module[target=activation_post_process_14](args = (%layer2_1_conv2,), kwargs = {})
    %add_3 : [#users=1] = call_function[target=operator.add](args = (%activation_post_process_14, %activation_post_process_12), kwargs = {})
    %layer2_1_relu_1 : [#users=1] = call_module[target=layer2.1.relu](args = (%add_3,), kwargs = {})
    %activation_post_process_15 : [#users=2] = call_module[target=activation_post_process_15](args = (%layer2_1_relu_1,), kwargs = {})
    %layer3_0_conv1 : [#users=1] = call_module[target=layer3.0.conv1](args = (%activation_post_process_15,), kwargs = {})
    %activation_post_process_16 : [#users=1] = call_module[target=activation_post_process_16](args = (%layer3_0_conv1,), kwargs = {})
    %layer3_0_conv2 : [#users=1] = call_module[target=layer3.0.conv2](args = (%activation_post_process_16,), kwargs = {})
    %activation_post_process_17 : [#users=1] = call_module[target=activation_post_process_17](args = (%layer3_0_conv2,), kwargs = {})
    %layer3_0_downsample_0 : [#users=1] = call_module[target=layer3.0.downsample.0](args = (%activation_post_process_15,), kwargs = {})
    %activation_post_process_18 : [#users=1] = call_module[target=activation_post_process_18](args = (%layer3_0_downsample_0,), kwargs = {})
    %add_4 : [#users=1] = call_function[target=operator.add](args = (%activation_post_process_17, %activation_post_process_18), kwargs = {})
    %layer3_0_relu_1 : [#users=1] = call_module[target=layer3.0.relu](args = (%add_4,), kwargs = {})
    %activation_post_process_19 : [#users=2] = call_module[target=activation_post_process_19](args = (%layer3_0_relu_1,), kwargs = {})
    %layer3_1_conv1 : [#users=1] = call_module[target=layer3.1.conv1](args = (%activation_post_process_19,), kwargs = {})
    %activation_post_process_20 : [#users=1] = call_module[target=activation_post_process_20](args = (%layer3_1_conv1,), kwargs = {})
    %layer3_1_conv2 : [#users=1] = call_module[target=layer3.1.conv2](args = (%activation_post_process_20,), kwargs = {})
    %activation_post_process_21 : [#users=1] = call_module[target=activation_post_process_21](args = (%layer3_1_conv2,), kwargs = {})
    %add_5 : [#users=1] = call_function[target=operator.add](args = (%activation_post_process_21, %activation_post_process_19), kwargs = {})
    %layer3_1_relu_1 : [#users=1] = call_module[target=layer3.1.relu](args = (%add_5,), kwargs = {})
    %activation_post_process_22 : [#users=2] = call_module[target=activation_post_process_22](args = (%layer3_1_relu_1,), kwargs = {})
    %layer4_0_conv1 : [#users=1] = call_module[target=layer4.0.conv1](args = (%activation_post_process_22,), kwargs = {})
    %activation_post_process_23 : [#users=1] = call_module[target=activation_post_process_23](args = (%layer4_0_conv1,), kwargs = {})
    %layer4_0_conv2 : [#users=1] = call_module[target=layer4.0.conv2](args = (%activation_post_process_23,), kwargs = {})
    %activation_post_process_24 : [#users=1] = call_module[target=activation_post_process_24](args = (%layer4_0_conv2,), kwargs = {})
    %layer4_0_downsample_0 : [#users=1] = call_module[target=layer4.0.downsample.0](args = (%activation_post_process_22,), kwargs = {})
    %activation_post_process_25 : [#users=1] = call_module[target=activation_post_process_25](args = (%layer4_0_downsample_0,), kwargs = {})
    %add_6 : [#users=1] = call_function[target=operator.add](args = (%activation_post_process_24, %activation_post_process_25), kwargs = {})
    %layer4_0_relu_1 : [#users=1] = call_module[target=layer4.0.relu](args = (%add_6,), kwargs = {})
    %activation_post_process_26 : [#users=2] = call_module[target=activation_post_process_26](args = (%layer4_0_relu_1,), kwargs = {})
    %layer4_1_conv1 : [#users=1] = call_module[target=layer4.1.conv1](args = (%activation_post_process_26,), kwargs = {})
    %activation_post_process_27 : [#users=1] = call_module[target=activation_post_process_27](args = (%layer4_1_conv1,), kwargs = {})
    %layer4_1_conv2 : [#users=1] = call_module[target=layer4.1.conv2](args = (%activation_post_process_27,), kwargs = {})
    %activation_post_process_28 : [#users=1] = call_module[target=activation_post_process_28](args = (%layer4_1_conv2,), kwargs = {})
    %add_7 : [#users=1] = call_function[target=operator.add](args = (%activation_post_process_28, %activation_post_process_26), kwargs = {})
    %layer4_1_relu_1 : [#users=1] = call_module[target=layer4.1.relu](args = (%add_7,), kwargs = {})
    %activation_post_process_29 : [#users=1] = call_module[target=activation_post_process_29](args = (%layer4_1_relu_1,), kwargs = {})
    %avgpool : [#users=1] = call_module[target=avgpool](args = (%activation_post_process_29,), kwargs = {})
    %activation_post_process_30 : [#users=1] = call_module[target=activation_post_process_30](args = (%avgpool,), kwargs = {})
    %flatten : [#users=1] = call_function[target=torch.flatten](args = (%activation_post_process_30, 1), kwargs = {})
    %activation_post_process_31 : [#users=1] = call_module[target=activation_post_process_31](args = (%flatten,), kwargs = {})
    %fc : [#users=1] = call_module[target=fc](args = (%activation_post_process_31,), kwargs = {})
    %activation_post_process_32 : [#users=1] = call_module[target=activation_post_process_32](args = (%fc,), kwargs = {})
    return activation_post_process_32

Calibration#

After the observers are inserted into the model, we run the calibration function. Calibration runs a representative workload (for example, samples from the training dataset) through the model so that the observers can record the statistics of the tensors flowing through it; these statistics are later used to compute the quantization parameters.

import torch

def calibrate(model, data_loader, samples=500):
    model.eval()
    with torch.no_grad():
        k = 0
        for image, _ in data_loader:
            if k > samples:
                break
            model(image)
            k += len(image)

calibrate(prepared_model, trainset)  # run calibration on sample data
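
After calibration the observers hold the recorded statistics. As a sanity check, we can peek at the first activation observer (a sketch; the module name is taken from the graph printed above, and min_val/max_val are the buffers the default observers record):

obs = prepared_model.activation_post_process_0
print(obs.min_val, obs.max_val)  # value range seen during calibration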

Convert the Model to a Quantized Model#

convert_fx takes the calibrated model and produces a quantized model.

from torch.quantization.quantize_fx import convert_fx

quantized_model = convert_fx(prepared_model)
print(quantized_model)
GraphModule(
  (conv1): QuantizedConvReLU2d(3, 64, kernel_size=(7, 7), stride=(2, 2), scale=0.03267836198210716, zero_point=0, padding=(3, 3))
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Module(
    (0): Module(
      (conv1): QuantizedConvReLU2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.019193191081285477, zero_point=0, padding=(1, 1))
      (conv2): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.051562923938035965, zero_point=75, padding=(1, 1))
    )
    (1): Module(
      (conv1): QuantizedConvReLU2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.019093887880444527, zero_point=0, padding=(1, 1))
      (conv2): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.06979087740182877, zero_point=78, padding=(1, 1))
    )
  )
  (layer2): Module(
    (0): Module(
      (conv1): QuantizedConvReLU2d(64, 128, kernel_size=(3, 3), stride=(2, 2), scale=0.01557458657771349, zero_point=0, padding=(1, 1))
      (conv2): QuantizedConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.050476107746362686, zero_point=68, padding=(1, 1))
      (downsample): Module(
        (0): QuantizedConv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), scale=0.039443813264369965, zero_point=60)
      )
    )
    (1): Module(
      (conv1): QuantizedConvReLU2d(128, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.016193654388189316, zero_point=0, padding=(1, 1))
      (conv2): QuantizedConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.05214320868253708, zero_point=68, padding=(1, 1))
    )
  )
  (layer3): Module(
    (0): Module(
      (conv1): QuantizedConvReLU2d(128, 256, kernel_size=(3, 3), stride=(2, 2), scale=0.018163194879889488, zero_point=0, padding=(1, 1))
      (conv2): QuantizedConv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), scale=0.05316956341266632, zero_point=51, padding=(1, 1))
      (downsample): Module(
        (0): QuantizedConv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), scale=0.01836947724223137, zero_point=107)
      )
    )
    (1): Module(
      (conv1): QuantizedConvReLU2d(256, 256, kernel_size=(3, 3), stride=(1, 1), scale=0.013543782755732536, zero_point=0, padding=(1, 1))
      (conv2): QuantizedConv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), scale=0.048523254692554474, zero_point=71, padding=(1, 1))
    )
  )
  (layer4): Module(
    (0): Module(
      (conv1): QuantizedConvReLU2d(256, 512, kernel_size=(3, 3), stride=(2, 2), scale=0.014485283754765987, zero_point=0, padding=(1, 1))
      (conv2): QuantizedConv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), scale=0.0517863854765892, zero_point=64, padding=(1, 1))
      (downsample): Module(
        (0): QuantizedConv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), scale=0.04331441596150398, zero_point=58)
      )
    )
    (1): Module(
      (conv1): QuantizedConvReLU2d(512, 512, kernel_size=(3, 3), stride=(1, 1), scale=0.021167000755667686, zero_point=0, padding=(1, 1))
      (conv2): QuantizedConv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), scale=0.22766999900341034, zero_point=45, padding=(1, 1))
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): QuantizedLinear(in_features=512, out_features=1000, scale=0.27226582169532776, zero_point=35, qscheme=torch.per_channel_affine)
)



def forward(self, x : torch.Tensor):
    conv1_input_scale_0 = self.conv1_input_scale_0
    conv1_input_zero_point_0 = self.conv1_input_zero_point_0
    quantize_per_tensor = torch.quantize_per_tensor(x, conv1_input_scale_0, conv1_input_zero_point_0, torch.quint8);  x = conv1_input_scale_0 = conv1_input_zero_point_0 = None
    conv1 = self.conv1(quantize_per_tensor);  quantize_per_tensor = None
    maxpool = self.maxpool(conv1);  conv1 = None
    layer1_0_conv1 = getattr(self.layer1, "0").conv1(maxpool)
    layer1_0_conv2 = getattr(self.layer1, "0").conv2(layer1_0_conv1);  layer1_0_conv1 = None
    layer1_0_relu_scale_0 = self.layer1_0_relu_scale_0
    layer1_0_relu_zero_point_0 = self.layer1_0_relu_zero_point_0
    add_relu = torch.ops.quantized.add_relu(layer1_0_conv2, maxpool, layer1_0_relu_scale_0, layer1_0_relu_zero_point_0);  layer1_0_conv2 = maxpool = layer1_0_relu_scale_0 = layer1_0_relu_zero_point_0 = None
    layer1_1_conv1 = getattr(self.layer1, "1").conv1(add_relu)
    layer1_1_conv2 = getattr(self.layer1, "1").conv2(layer1_1_conv1);  layer1_1_conv1 = None
    layer1_1_relu_scale_0 = self.layer1_1_relu_scale_0
    layer1_1_relu_zero_point_0 = self.layer1_1_relu_zero_point_0
    add_relu_1 = torch.ops.quantized.add_relu(layer1_1_conv2, add_relu, layer1_1_relu_scale_0, layer1_1_relu_zero_point_0);  layer1_1_conv2 = add_relu = layer1_1_relu_scale_0 = layer1_1_relu_zero_point_0 = None
    layer2_0_conv1 = getattr(self.layer2, "0").conv1(add_relu_1)
    layer2_0_conv2 = getattr(self.layer2, "0").conv2(layer2_0_conv1);  layer2_0_conv1 = None
    layer2_0_downsample_0 = getattr(getattr(self.layer2, "0").downsample, "0")(add_relu_1);  add_relu_1 = None
    layer2_0_relu_scale_0 = self.layer2_0_relu_scale_0
    layer2_0_relu_zero_point_0 = self.layer2_0_relu_zero_point_0
    add_relu_2 = torch.ops.quantized.add_relu(layer2_0_conv2, layer2_0_downsample_0, layer2_0_relu_scale_0, layer2_0_relu_zero_point_0);  layer2_0_conv2 = layer2_0_downsample_0 = layer2_0_relu_scale_0 = layer2_0_relu_zero_point_0 = None
    layer2_1_conv1 = getattr(self.layer2, "1").conv1(add_relu_2)
    layer2_1_conv2 = getattr(self.layer2, "1").conv2(layer2_1_conv1);  layer2_1_conv1 = None
    layer2_1_relu_scale_0 = self.layer2_1_relu_scale_0
    layer2_1_relu_zero_point_0 = self.layer2_1_relu_zero_point_0
    add_relu_3 = torch.ops.quantized.add_relu(layer2_1_conv2, add_relu_2, layer2_1_relu_scale_0, layer2_1_relu_zero_point_0);  layer2_1_conv2 = add_relu_2 = layer2_1_relu_scale_0 = layer2_1_relu_zero_point_0 = None
    layer3_0_conv1 = getattr(self.layer3, "0").conv1(add_relu_3)
    layer3_0_conv2 = getattr(self.layer3, "0").conv2(layer3_0_conv1);  layer3_0_conv1 = None
    layer3_0_downsample_0 = getattr(getattr(self.layer3, "0").downsample, "0")(add_relu_3);  add_relu_3 = None
    layer3_0_relu_scale_0 = self.layer3_0_relu_scale_0
    layer3_0_relu_zero_point_0 = self.layer3_0_relu_zero_point_0
    add_relu_4 = torch.ops.quantized.add_relu(layer3_0_conv2, layer3_0_downsample_0, layer3_0_relu_scale_0, layer3_0_relu_zero_point_0);  layer3_0_conv2 = layer3_0_downsample_0 = layer3_0_relu_scale_0 = layer3_0_relu_zero_point_0 = None
    layer3_1_conv1 = getattr(self.layer3, "1").conv1(add_relu_4)
    layer3_1_conv2 = getattr(self.layer3, "1").conv2(layer3_1_conv1);  layer3_1_conv1 = None
    layer3_1_relu_scale_0 = self.layer3_1_relu_scale_0
    layer3_1_relu_zero_point_0 = self.layer3_1_relu_zero_point_0
    add_relu_5 = torch.ops.quantized.add_relu(layer3_1_conv2, add_relu_4, layer3_1_relu_scale_0, layer3_1_relu_zero_point_0);  layer3_1_conv2 = add_relu_4 = layer3_1_relu_scale_0 = layer3_1_relu_zero_point_0 = None
    layer4_0_conv1 = getattr(self.layer4, "0").conv1(add_relu_5)
    layer4_0_conv2 = getattr(self.layer4, "0").conv2(layer4_0_conv1);  layer4_0_conv1 = None
    layer4_0_downsample_0 = getattr(getattr(self.layer4, "0").downsample, "0")(add_relu_5);  add_relu_5 = None
    layer4_0_relu_scale_0 = self.layer4_0_relu_scale_0
    layer4_0_relu_zero_point_0 = self.layer4_0_relu_zero_point_0
    add_relu_6 = torch.ops.quantized.add_relu(layer4_0_conv2, layer4_0_downsample_0, layer4_0_relu_scale_0, layer4_0_relu_zero_point_0);  layer4_0_conv2 = layer4_0_downsample_0 = layer4_0_relu_scale_0 = layer4_0_relu_zero_point_0 = None
    layer4_1_conv1 = getattr(self.layer4, "1").conv1(add_relu_6)
    layer4_1_conv2 = getattr(self.layer4, "1").conv2(layer4_1_conv1);  layer4_1_conv1 = None
    layer4_1_relu_scale_0 = self.layer4_1_relu_scale_0
    layer4_1_relu_zero_point_0 = self.layer4_1_relu_zero_point_0
    add_relu_7 = torch.ops.quantized.add_relu(layer4_1_conv2, add_relu_6, layer4_1_relu_scale_0, layer4_1_relu_zero_point_0);  layer4_1_conv2 = add_relu_6 = layer4_1_relu_scale_0 = layer4_1_relu_zero_point_0 = None
    avgpool = self.avgpool(add_relu_7);  add_relu_7 = None
    flatten = torch.flatten(avgpool, 1);  avgpool = None
    fc = self.fc(flatten);  flatten = None
    dequantize_14 = fc.dequantize();  fc = None
    return dequantize_14
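
As a quick sanity check (a sketch), the quantized model still takes ordinary float tensors and returns float logits, since the quantize_per_tensor and dequantize calls now live inside the graph:

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    out = quantized_model(x)
print(out.shape)  # torch.Size([1, 1000]), a regular float tensor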
    

Evaluation#

We can now print the size and accuracy of the quantized model.

from torch_book.contrib.helper import evaluate, print_size_of_model
from torch import nn


criterion = nn.CrossEntropyLoss()

print("Size of model before quantization")
print_size_of_model(float_model)
print("Size of model after quantization")
print_size_of_model(quantized_model)
top1, top5 = evaluate(quantized_model, criterion, valset)
Size of model before quantization
Model size (MB): 46.873073 MB
Size of model after quantization
Model size (MB): 11.853109 MB
print(f"[before serilaization] Evaluation accuracy on test dataset: {top1.avg:2.2f}, {top5.avg:2.2f}")
[before serilaization] Evaluation accuracy on test dataset: 69.37, 88.89
fx_graph_mode_model_file_path = saved_model_dir + f"{model_name}_fx_graph_mode_quantized.pth"

# this does not run due to some errors loading the convrelu module:
# ModuleAttributeError: 'ConvReLU2d' object has no attribute '_modules'
# save the whole model directly
# torch.save(quantized_model, fx_graph_mode_model_file_path)
# loaded_quantized_model = torch.load(fx_graph_mode_model_file_path)

# save with state_dict
# torch.save(quantized_model.state_dict(), fx_graph_mode_model_file_path)
# import copy
# model_to_quantize = copy.deepcopy(float_model)
# prepared_model = prepare_fx(model_to_quantize, {"": qconfig})
# loaded_quantized_model = convert_fx(prepared_model)
# loaded_quantized_model.load_state_dict(torch.load(fx_graph_mode_model_file_path))

# save with script
torch.jit.save(torch.jit.script(quantized_model), fx_graph_mode_model_file_path)
loaded_quantized_model = torch.jit.load(fx_graph_mode_model_file_path)

top1, top5 = evaluate(loaded_quantized_model, criterion, valset)
print(f"[after serialization/deserialization] Evaluation accuracy on test dataset: {top1.avg:2.2f}, {top5.avg:2.2f}")
[after serialization/deserialization] Evaluation accuracy on test dataset: 69.37, 88.89

If you want better accuracy or performance, try changing the qconfig_dict.
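
For example, using the qconfig_dict keys described earlier, a layer that quantizes poorly can be left in floating point (the choice of conv1 here is purely illustrative):

qconfig_dict = {
    "": qconfig,  # global default
    "module_name": [
        ("conv1", None),  # skip fusion and quantization for this module
    ],
}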

Debugging the Quantized Model#

We can also print the weights of the quantized and non-quantized conv to see the difference. First call fuse explicitly to fuse the conv and bn in the model; note that fuse_fx only works in eval mode.

from torch.quantization.quantize_fx import fuse_fx

fused = fuse_fx(float_model)

conv1_weight_after_fuse = fused.conv1[0].weight[0]
conv1_weight_after_quant = quantized_model.conv1.weight().dequantize()[0]

print(torch.max(abs(conv1_weight_after_fuse - conv1_weight_after_quant)))
tensor(0.0007, grad_fn=<MaxBackward1>)
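
To put that difference on a scale, the signal-to-quantization-noise ratio can be computed by hand (a small helper of our own, not a PyTorch API):

def sqnr(x, y):
    """Signal-to-quantization-noise ratio in dB; higher means closer."""
    return 20 * torch.log10(torch.norm(x) / torch.norm(x - y))

print(sqnr(conv1_weight_after_fuse, conv1_weight_after_quant))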

Comparison with the Baseline Float Model and Eager Mode Quantization#

scripted_float_model_file = "resnet18_scripted.pth"

print("Size of baseline model")
print_size_of_model(float_model)

top1, top5 = evaluate(float_model, criterion, valset)
print("Baseline Float Model Evaluation accuracy: %2.2f, %2.2f"%(top1.avg, top5.avg))
torch.jit.save(torch.jit.script(float_model), saved_model_dir + scripted_float_model_file)
Size of baseline model
Model size (MB): 46.874273 MB
Baseline Float Model Evaluation accuracy: 69.76, 89.08

In this section we compare the model quantized with FX graph mode quantization with the model quantized in eager mode. FX graph mode and eager mode produce very similar quantized models, so we expect the accuracy and speedup to be similar as well.

print("Size of Fx graph mode quantized model")
print_size_of_model(quantized_model)
top1, top5 = evaluate(quantized_model, criterion, valset)
print("FX graph mode quantized model Evaluation accuracy on test dataset: %2.2f, %2.2f"%(top1.avg, top5.avg))

from torchvision.models.quantization.resnet import resnet18
eager_quantized_model = resnet18(pretrained=True, quantize=True).eval()
print("Size of eager mode quantized model")
eager_quantized_model = torch.jit.script(eager_quantized_model)
print_size_of_model(eager_quantized_model)
top1, top5 = evaluate(eager_quantized_model, criterion, valset)
print("eager mode quantized model Evaluation accuracy on test dataset: %2.2f, %2.2f"%(top1.avg, top5.avg))
eager_mode_model_file = "resnet18_eager_mode_quantized.pth"
torch.jit.save(eager_quantized_model, saved_model_dir + eager_mode_model_file)
Size of Fx graph mode quantized model
Model size (MB): 11.855297 MB
FX graph mode quantized model Evaluation accuracy on test dataset: 69.37, 88.89
Size of eager mode quantized model
Model size (MB): 11.850395 MB
eager mode quantized model Evaluation accuracy on test dataset: 69.50, 88.88

We can see that the model size and accuracy of the FX graph mode and eager mode quantized models are very similar.

Running the models in AIBench (with single threading) gives the following results:

Scripted Float Model:
Self CPU time total: 192.48ms

Scripted Eager Mode Quantized Model:
Self CPU time total: 50.76ms

Scripted FX Graph Mode Quantized Model:
Self CPU time total: 50.63ms

As we can see, for resnet18 the FX graph mode and eager mode quantized models achieve a similar speedup over the floating point model, roughly 2-4x faster. The actual speedup over the floating point model may vary with the model, device, build, input batch size, threading, and so on.
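
To reproduce a rough version of this comparison on your own machine, here is a minimal single-threaded timing sketch (the numbers will differ from AIBench):

import time

torch.set_num_threads(1)  # mimic the single-threaded benchmark setting
x = torch.randn(1, 3, 224, 224)

def bench_ms(model, iters=50):
    # average latency in milliseconds over `iters` forward passes
    with torch.no_grad():
        model(x)  # warmup
        start = time.perf_counter()
        for _ in range(iters):
            model(x)
    return (time.perf_counter() - start) / iters * 1000

print(f"float model:     {bench_ms(float_model):.2f} ms / image")
print(f"quantized model: {bench_ms(quantized_model):.2f} ms / image")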