FX Graph Mode Post-Training Static Quantization

This tutorial walks through the steps of post-training static quantization in graph mode, based on torch.fx.

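The advantage of FX graph mode quantization is that quantization can be performed fully automatically on the model, although some effort may first be needed to make the model compatible with FX graph mode quantization (that is, symbolically traceable with torch.fx).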

Overall workflow

import torch
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization import QConfigMapping
# float_model, data_loader and data_loader_test are defined in the data preparation section below
float_model.eval()
# The old 'fbgemm' is still available, but 'x86' is the recommended default.
qconfig = get_default_qconfig("x86")
qconfig_mapping = QConfigMapping().set_global(qconfig)
def calibrate(model, data_loader):
    model.eval()
    with torch.no_grad():
        for image, target in data_loader:
            model(image)
example_inputs = (next(iter(data_loader))[0],)  # a sample input, passed as a tuple
prepared_model = prepare_fx(float_model, qconfig_mapping, example_inputs)  # fuse modules and insert observers
calibrate(prepared_model, data_loader_test)  # calibrate on sample data
quantized_model = convert_fx(prepared_model)  # convert the calibrated model to a quantized model
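The snippet above relies on float_model and data loaders that are only defined later. As a self-contained miniature of the same prepare, calibrate, convert flow, here is a sketch under stated assumptions: TinyNet and the random calibration batches are hypothetical stand-ins for a real model and dataset, and a CPU build of PyTorch with the x86 (or fbgemm) quantized engine is assumed.

```python
from copy import deepcopy

import torch
from torch import nn
from torch.ao.quantization import QConfigMapping, get_default_qconfig
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx


class TinyNet(nn.Module):
    """Hypothetical toy model: Conv-BN-ReLU followed by pooling and a classifier."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(16, num_classes)

    def forward(self, x):
        x = self.relu(self.bn(self.conv(x)))
        x = torch.flatten(self.pool(x), 1)
        return self.fc(x)


torch.manual_seed(0)
float_model = TinyNet().eval()  # post-training quantization needs eval mode
qconfig_mapping = QConfigMapping().set_global(get_default_qconfig("x86"))

# Random batches stand in for a real calibration dataset (assumption).
calibration_batches = [torch.randn(8, 3, 32, 32) for _ in range(4)]
example_inputs = (calibration_batches[0],)  # prepare_fx expects a tuple

# Prepare a copy so float_model stays untouched for the comparison below.
prepared = prepare_fx(deepcopy(float_model), qconfig_mapping, example_inputs)
with torch.no_grad():
    for batch in calibration_batches:  # observers record activation ranges
        prepared(batch)
quantized = convert_fx(prepared)

x = torch.randn(2, 3, 32, 32)
with torch.no_grad():
    float_out = float_model(x)
    quant_out = quantized(x)
print("max abs difference:", (float_out - quant_out).abs().max().item())
```

The quantized model still takes and returns float tensors, because convert_fx inserts the quantize and dequantize steps at the graph boundaries.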

Motivation for FX Graph Mode Quantization

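At the time this tutorial was written, the only alternative in PyTorch was eager mode quantization: Static Quantization with Eager Mode in PyTorch.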

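Eager mode quantization involves several manual steps, including:

  • Explicitly quantizing and dequantizing activations, which is time-consuming when floating-point and quantized operations are mixed in the model.

  • Explicitly fusing modules, which requires manually identifying sequences of convolutions, batch norms, ReLUs and other fusion patterns.

  • Special handling for PyTorch tensor operations (such as add, concat, etc.).

  • No first-class support for functionals (functional.conv2d and functional.linear are not quantized).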

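Most of these required modifications come from the underlying limitations of eager mode quantization. Eager mode works at the module level because it cannot inspect the code that actually runs (the forward function). Quantization is achieved by swapping modules, and since eager mode does not know how the modules are used inside forward, the user has to insert QuantStub and DeQuantStub manually to mark the points where tensors should be quantized or dequantized. In graph mode, the code actually executed in forward (for example, aten function calls) can be inspected, and quantization is achieved through module and graph manipulations. Because graph mode has full visibility into the running code, it can automatically figure out which modules to fuse and where to insert observer calls and quantize/dequantize functions, which makes it possible to automate the whole quantization process.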
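To make the point about visibility concrete, the sketch below runs torch.fx.symbolic_trace (the tracing machinery FX graph mode quantization builds on) over a small hypothetical module. The + operator and the F.relu call inside forward appear as explicit graph nodes next to the submodule call, which is exactly the information eager mode never sees.

```python
import operator

import torch
import torch.nn.functional as F
from torch import nn
from torch.fx import symbolic_trace


class Block(nn.Module):
    """Hypothetical module mixing a submodule, a tensor add and a functional call."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)

    def forward(self, x):
        y = self.conv(x)  # recorded as a call_module node
        y = y + x         # recorded as a call_function node (operator.add)
        return F.relu(y)  # recorded as a call_function node (F.relu)


traced = symbolic_trace(Block())
for node in traced.graph.nodes:
    print(node.op, node.target)
```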

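The advantages of FX graph mode quantization are:

  • A simpler quantization flow with minimal manual steps

  • It unlocks the possibility of higher-level optimizations, such as automatic precision selection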

Define Helper Functions and Prepare the Dataset

Define the helper functions:

from copy import deepcopy
import torch
from torch.ao.quantization import get_default_qconfig, QConfigMapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx, fuse_fx
from torch import nn
from tqdm import tqdm

# Fix the random seed for reproducible results
_ = torch.manual_seed(191009)
import os
from dataclasses import dataclass

def size_of_model(model):
    """Return the size of the model's serialized state_dict in bytes."""
    torch.save(model.state_dict(), "temp.p")
    size = os.path.getsize('temp.p')
    os.remove("temp.p")
    return size

@dataclass
class AverageMeter:
    """Compute and store the average and the current value."""
    name: str
    fmt: str = ".3g"

    def __post_init__(self):
        self.reset()

    def reset(self):
        self.value = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, value, n=1):
        self.value = value
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        return f"{self.name} {self.value:{self.fmt}} ({self.avg:{self.fmt}})"

@torch.no_grad()
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    maxk = max(topk)
    batch_size = target.shape[0]

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

@torch.no_grad()
def evaluate(model, data_loader):
    model.eval()
    # fmt is a format spec without the leading colon, e.g. '6.2f'
    top1 = AverageMeter('Acc@1', '6.2f')
    top5 = AverageMeter('Acc@5', '6.2f')
    for image, target in tqdm(data_loader):
        output = model(image)
        # loss = criterion(output, target)  # compute the loss if needed
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        top1.update(acc1[0].item(), image.size(0))
        top5.update(acc5[0].item(), image.size(0))
    return top1, top5

Prepare the data and the floating-point model:

from torchvision.models.resnet import resnet18, ResNet18_Weights
# ImageNet is a dataset helper from the torch_book package (not part of torchvision)
from torch_book.testing.imagenet import ImageNet

saved_model_dir = 'data/'  # directory for saved models
train_batch_size = 30  # batch size of the training (calibration) loader
eval_batch_size = 50  # batch size of the test loader
dataset = ImageNet('/media/pc/data/lxw/home/data/datasets/ILSVRC')
data_loader = dataset.trainloader(train_batch_size)
data_loader_test = dataset.testloader(eval_batch_size)
example_inputs = (next(iter(data_loader))[0],)  # prepare_fx expects a tuple of inputs
criterion = nn.CrossEntropyLoss()
float_model = resnet18(weights=ResNet18_Weights.DEFAULT)
float_model = float_model.to("cpu").eval()
# Deep-copy the model because we need to keep the original float model
model_to_quantize = deepcopy(float_model)

Set the Model to Evaluation Mode

For post-training quantization, the model must be set to evaluation mode.

model_to_quantize.eval();

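Specify How to Quantize the Model with QConfigMapping

from torch.ao.quantization import default_qconfig, QConfigMapping
qconfig_mapping = QConfigMapping().set_global(default_qconfig)

We use the same qconfig as in eager mode quantization; a qconfig is just a named tuple of the observers for activations and weights. QConfigMapping holds the mapping from ops to qconfigs: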

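from torch.ao.quantization import QConfigMapping, get_default_qconfig
qconfig_opt = get_default_qconfig("x86")  # an optional qconfig: either a valid qconfig or None
model_to_quantize.eval()
qconfig_mapping = (QConfigMapping()
    .set_global(qconfig_opt)
    .set_object_type(torch.nn.Conv2d, qconfig_opt)  # the object type can be a callable...
    .set_object_type("reshape", qconfig_opt)  # ...or a string naming a Tensor method
    .set_module_name_regex("foo.*bar.*conv[0-9]+", qconfig_opt)  # matched in order, first match takes precedence
    .set_module_name("foo.bar", qconfig_opt)
    # qconfig for the first (index 0) call of torch.nn.functional.linear inside module "foo.bar"
    .set_module_name_object_type_order("foo.bar", torch.nn.functional.linear, 0, qconfig_opt)
)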

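Notes

  • Priority (in increasing order): global, object_type, module_name_regex, module_name, module_name_object_type_order

  • qconfig == None means fusion and quantization should be skipped for anything matching the rule (unless a higher-priority match is found)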
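As a minimal sketch of the None rule, assuming a hypothetical TinyNet, the mapping below quantizes everything with the x86 default qconfig except the submodule named fc, which stays in floating point after conversion.

```python
import torch
from torch import nn
from torch.ao.quantization import QConfigMapping, get_default_qconfig
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx


class TinyNet(nn.Module):
    """Hypothetical toy model used only for this illustration."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 4)

    def forward(self, x):
        x = self.pool(self.relu(self.conv(x)))
        return self.fc(torch.flatten(x, 1))


qconfig_mapping = (
    QConfigMapping()
    .set_global(get_default_qconfig("x86"))  # quantize everything by default...
    .set_module_name("fc", None)             # ...but a None qconfig keeps "fc" in float
)
example_inputs = (torch.randn(2, 3, 16, 16),)
prepared = prepare_fx(TinyNet().eval(), qconfig_mapping, example_inputs)
with torch.no_grad():
    prepared(*example_inputs)  # a single calibration pass
quantized = convert_fx(prepared)

print(type(quantized.conv))  # fused and quantized Conv + ReLU
print(type(quantized.fc))    # still a float nn.Linear
```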

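Utility functions related to qconfig can be found in the qconfig file (torch/ao/quantization/qconfig.py), while those related to QConfigMapping can be found in qconfig_mapping (torch/ao/quantization/qconfig_mapping.py).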

# The old 'fbgemm' is still available, but 'x86' is the recommended default.
qconfig = get_default_qconfig("x86")
qconfig_mapping = QConfigMapping().set_global(qconfig)
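To see that a qconfig really is just a named tuple of observer constructors, the sketch below inspects get_default_qconfig("x86"). The specific observer classes checked here (HistogramObserver for activations, PerChannelMinMaxObserver for weights) are what recent PyTorch versions use for this backend; treat them as an assumption if your version differs.

```python
import torch
from torch.ao.quantization import QConfig, get_default_qconfig
from torch.ao.quantization.observer import HistogramObserver, PerChannelMinMaxObserver

qconfig = get_default_qconfig("x86")
print(qconfig._fields)  # the two named-tuple fields

# Each field is an observer constructor; calling it creates a fresh observer.
act_observer = qconfig.activation()
weight_observer = qconfig.weight()
print(type(act_observer).__name__, type(weight_observer).__name__)
print(weight_observer.dtype, weight_observer.qscheme)
```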

Prepare the Model for Post-Training Static Quantization


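prepare_fx folds BatchNorm modules into the preceding Conv2d modules, fuses patterns such as Conv2d + ReLU, and inserts observers at the appropriate places in the model.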
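The BatchNorm fold works because, in eval mode, Conv2d followed by BatchNorm2d computes γ·(W∗x + b − μ)/√(σ² + ε) + β with fixed running statistics μ and σ². Per output channel this equals a single convolution with W' = W·γ/√(σ² + ε) and b' = (b − μ)·γ/√(σ² + ε) + β. This is also why the model has to be in eval mode: in training mode BatchNorm uses per-batch statistics and the fold no longer holds. The sketch below is a hand-written illustration of that arithmetic (fold_bn_into_conv is a hypothetical helper, not the code prepare_fx runs internally) and checks the folded convolution against the original pair.

```python
import torch
from torch import nn


@torch.no_grad()
def fold_bn_into_conv(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """Hypothetical helper: return a Conv2d equivalent to bn(conv(x)) in eval mode."""
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)  # gamma / sqrt(var + eps)
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      stride=conv.stride, padding=conv.padding, dilation=conv.dilation,
                      groups=conv.groups, bias=True)
    # W' = W * scale, broadcast over each output channel
    fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
    conv_bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    # b' = (b - mean) * scale + beta
    fused.bias.copy_((conv_bias - bn.running_mean) * scale + bn.bias)
    return fused.eval()


torch.manual_seed(0)
conv = nn.Conv2d(3, 4, kernel_size=3)
bn = nn.BatchNorm2d(4)
with torch.no_grad():
    # Give BatchNorm non-trivial running statistics and affine parameters.
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
conv.eval()
bn.eval()  # eval mode: BatchNorm uses the fixed running statistics

x = torch.randn(2, 3, 8, 8)
with torch.no_grad():
    reference = bn(conv(x))
    folded = fold_bn_into_conv(conv, bn)(x)
print("max abs difference:", (reference - folded).abs().max().item())
```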

prepared_model = prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
print(prepared_model.graph)
graph():
    %x : torch.Tensor [num_users=1] = placeholder[target=x]
    %activation_post_process_0 : [num_users=1] = call_module[target=activation_post_process_0](args = (%x,), kwargs = {})
    %conv1 : [num_users=1] = call_module[target=conv1](args = (%activation_post_process_0,), kwargs = {})
    %activation_post_process_1 : [num_users=1] = call_module[target=activation_post_process_1](args = (%conv1,), kwargs = {})
    %maxpool : [num_users=1] = call_module[target=maxpool](args = (%activation_post_process_1,), kwargs = {})
    %activation_post_process_2 : [num_users=2] = call_module[target=activation_post_process_2](args = (%maxpool,), kwargs = {})
    %layer1_0_conv1 : [num_users=1] = call_module[target=layer1.0.conv1](args = (%activation_post_process_2,), kwargs = {})
    %activation_post_process_3 : [num_users=1] = call_module[target=activation_post_process_3](args = (%layer1_0_conv1,), kwargs = {})
    %layer1_0_conv2 : [num_users=1] = call_module[target=layer1.0.conv2](args = (%activation_post_process_3,), kwargs = {})
    %activation_post_process_4 : [num_users=1] = call_module[target=activation_post_process_4](args = (%layer1_0_conv2,), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (%activation_post_process_4, %activation_post_process_2), kwargs = {})
    %layer1_0_relu_1 : [num_users=1] = call_module[target=layer1.0.relu](args = (%add,), kwargs = {})
    %activation_post_process_5 : [num_users=2] = call_module[target=activation_post_process_5](args = (%layer1_0_relu_1,), kwargs = {})
    %layer1_1_conv1 : [num_users=1] = call_module[target=layer1.1.conv1](args = (%activation_post_process_5,), kwargs = {})
    %activation_post_process_6 : [num_users=1] = call_module[target=activation_post_process_6](args = (%layer1_1_conv1,), kwargs = {})
    %layer1_1_conv2 : [num_users=1] = call_module[target=layer1.1.conv2](args = (%activation_post_process_6,), kwargs = {})
    %activation_post_process_7 : [num_users=1] = call_module[target=activation_post_process_7](args = (%layer1_1_conv2,), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=operator.add](args = (%activation_post_process_7, %activation_post_process_5), kwargs = {})
    %layer1_1_relu_1 : [num_users=1] = call_module[target=layer1.1.relu](args = (%add_1,), kwargs = {})
    %activation_post_process_8 : [num_users=2] = call_module[target=activation_post_process_8](args = (%layer1_1_relu_1,), kwargs = {})
    %layer2_0_conv1 : [num_users=1] = call_module[target=layer2.0.conv1](args = (%activation_post_process_8,), kwargs = {})
    %activation_post_process_9 : [num_users=1] = call_module[target=activation_post_process_9](args = (%layer2_0_conv1,), kwargs = {})
    %layer2_0_conv2 : [num_users=1] = call_module[target=layer2.0.conv2](args = (%activation_post_process_9,), kwargs = {})
    %activation_post_process_10 : [num_users=1] = call_module[target=activation_post_process_10](args = (%layer2_0_conv2,), kwargs = {})
    %layer2_0_downsample_0 : [num_users=1] = call_module[target=layer2.0.downsample.0](args = (%activation_post_process_8,), kwargs = {})
    %activation_post_process_11 : [num_users=1] = call_module[target=activation_post_process_11](args = (%layer2_0_downsample_0,), kwargs = {})
    %add_2 : [num_users=1] = call_function[target=operator.add](args = (%activation_post_process_10, %activation_post_process_11), kwargs = {})
    %layer2_0_relu_1 : [num_users=1] = call_module[target=layer2.0.relu](args = (%add_2,), kwargs = {})
    %activation_post_process_12 : [num_users=2] = call_module[target=activation_post_process_12](args = (%layer2_0_relu_1,), kwargs = {})
    %layer2_1_conv1 : [num_users=1] = call_module[target=layer2.1.conv1](args = (%activation_post_process_12,), kwargs = {})
    %activation_post_process_13 : [num_users=1] = call_module[target=activation_post_process_13](args = (%layer2_1_conv1,), kwargs = {})
    %layer2_1_conv2 : [num_users=1] = call_module[target=layer2.1.conv2](args = (%activation_post_process_13,), kwargs = {})
    %activation_post_process_14 : [num_users=1] = call_module[target=activation_post_process_14](args = (%layer2_1_conv2,), kwargs = {})
    %add_3 : [num_users=1] = call_function[target=operator.add](args = (%activation_post_process_14, %activation_post_process_12), kwargs = {})
    %layer2_1_relu_1 : [num_users=1] = call_module[target=layer2.1.relu](args = (%add_3,), kwargs = {})
    %activation_post_process_15 : [num_users=2] = call_module[target=activation_post_process_15](args = (%layer2_1_relu_1,), kwargs = {})
    %layer3_0_conv1 : [num_users=1] = call_module[target=layer3.0.conv1](args = (%activation_post_process_15,), kwargs = {})
    %activation_post_process_16 : [num_users=1] = call_module[target=activation_post_process_16](args = (%layer3_0_conv1,), kwargs = {})
    %layer3_0_conv2 : [num_users=1] = call_module[target=layer3.0.conv2](args = (%activation_post_process_16,), kwargs = {})
    %activation_post_process_17 : [num_users=1] = call_module[target=activation_post_process_17](args = (%layer3_0_conv2,), kwargs = {})
    %layer3_0_downsample_0 : [num_users=1] = call_module[target=layer3.0.downsample.0](args = (%activation_post_process_15,), kwargs = {})
    %activation_post_process_18 : [num_users=1] = call_module[target=activation_post_process_18](args = (%layer3_0_downsample_0,), kwargs = {})
    %add_4 : [num_users=1] = call_function[target=operator.add](args = (%activation_post_process_17, %activation_post_process_18), kwargs = {})
    %layer3_0_relu_1 : [num_users=1] = call_module[target=layer3.0.relu](args = (%add_4,), kwargs = {})
    %activation_post_process_19 : [num_users=2] = call_module[target=activation_post_process_19](args = (%layer3_0_relu_1,), kwargs = {})
    %layer3_1_conv1 : [num_users=1] = call_module[target=layer3.1.conv1](args = (%activation_post_process_19,), kwargs = {})
    %activation_post_process_20 : [num_users=1] = call_module[target=activation_post_process_20](args = (%layer3_1_conv1,), kwargs = {})
    %layer3_1_conv2 : [num_users=1] = call_module[target=layer3.1.conv2](args = (%activation_post_process_20,), kwargs = {})
    %activation_post_process_21 : [num_users=1] = call_module[target=activation_post_process_21](args = (%layer3_1_conv2,), kwargs = {})
    %add_5 : [num_users=1] = call_function[target=operator.add](args = (%activation_post_process_21, %activation_post_process_19), kwargs = {})
    %layer3_1_relu_1 : [num_users=1] = call_module[target=layer3.1.relu](args = (%add_5,), kwargs = {})
    %activation_post_process_22 : [num_users=2] = call_module[target=activation_post_process_22](args = (%layer3_1_relu_1,), kwargs = {})
    %layer4_0_conv1 : [num_users=1] = call_module[target=layer4.0.conv1](args = (%activation_post_process_22,), kwargs = {})
    %activation_post_process_23 : [num_users=1] = call_module[target=activation_post_process_23](args = (%layer4_0_conv1,), kwargs = {})
    %layer4_0_conv2 : [num_users=1] = call_module[target=layer4.0.conv2](args = (%activation_post_process_23,), kwargs = {})
    %activation_post_process_24 : [num_users=1] = call_module[target=activation_post_process_24](args = (%layer4_0_conv2,), kwargs = {})
    %layer4_0_downsample_0 : [num_users=1] = call_module[target=layer4.0.downsample.0](args = (%activation_post_process_22,), kwargs = {})
    %activation_post_process_25 : [num_users=1] = call_module[target=activation_post_process_25](args = (%layer4_0_downsample_0,), kwargs = {})
    %add_6 : [num_users=1] = call_function[target=operator.add](args = (%activation_post_process_24, %activation_post_process_25), kwargs = {})
    %layer4_0_relu_1 : [num_users=1] = call_module[target=layer4.0.relu](args = (%add_6,), kwargs = {})
    %activation_post_process_26 : [num_users=2] = call_module[target=activation_post_process_26](args = (%layer4_0_relu_1,), kwargs = {})
    %layer4_1_conv1 : [num_users=1] = call_module[target=layer4.1.conv1](args = (%activation_post_process_26,), kwargs = {})
    %activation_post_process_27 : [num_users=1] = call_module[target=activation_post_process_27](args = (%layer4_1_conv1,), kwargs = {})
    %layer4_1_conv2 : [num_users=1] = call_module[target=layer4.1.conv2](args = (%activation_post_process_27,), kwargs = {})
    %activation_post_process_28 : [num_users=1] = call_module[target=activation_post_process_28](args = (%layer4_1_conv2,), kwargs = {})
    %add_7 : [num_users=1] = call_function[target=operator.add](args = (%activation_post_process_28, %activation_post_process_26), kwargs = {})
    %layer4_1_relu_1 : [num_users=1] = call_module[target=layer4.1.relu](args = (%add_7,), kwargs = {})
    %activation_post_process_29 : [num_users=1] = call_module[target=activation_post_process_29](args = (%layer4_1_relu_1,), kwargs = {})
    %avgpool : [num_users=1] = call_module[target=avgpool](args = (%activation_post_process_29,), kwargs = {})
    %activation_post_process_30 : [num_users=1] = call_module[target=activation_post_process_30](args = (%avgpool,), kwargs = {})
    %flatten : [num_users=1] = call_function[target=torch.flatten](args = (%activation_post_process_30, 1), kwargs = {})
    %activation_post_process_31 : [num_users=1] = call_module[target=activation_post_process_31](args = (%flatten,), kwargs = {})
    %fc : [num_users=1] = call_module[target=fc](args = (%activation_post_process_31,), kwargs = {})
    %activation_post_process_32 : [num_users=1] = call_module[target=activation_post_process_32](args = (%fc,), kwargs = {})
    return activation_post_process_32

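Calibration

Run the calibration function after the observers have been inserted into the model. The purpose of calibration is to run a few samples that are representative of the workload (for example, a sample of the training dataset) through the model, so that the observers can record the statistics of the tensors; these statistics are later used to compute the quantization parameters.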
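To see what an observer does with calibration data, the sketch below feeds a standalone MinMaxObserver two small tensors and asks it for quantization parameters. This is a simplification: the x86 default qconfig used in this tutorial uses HistogramObserver (with reduce_range=True) for activations, which chooses the range differently, but mapping an observed [min, max] range to scale and zero_point follows the same affine scheme.

```python
import torch
from torch.ao.quantization.observer import MinMaxObserver

# A simple min/max observer for unsigned 8-bit, per-tensor affine quantization.
observer = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine)

# "Calibration": feed the observer some representative activations.
observer(torch.tensor([-1.0, 0.0, 0.5]))
observer(torch.tensor([2.0, 3.0]))  # the running min/max become -1.0 and 3.0

scale, zero_point = observer.calculate_qparams()
print(scale.item(), zero_point.item())

# The same arithmetic by hand: the range [-1, 3] is mapped onto [0, 255].
qmin, qmax = 0, 255
manual_scale = (3.0 - (-1.0)) / (qmax - qmin)
manual_zero_point = qmin - round(-1.0 / manual_scale)
print(manual_scale, manual_zero_point)
```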

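@torch.no_grad()
def calibrate(model, data_loader, num=200):
    """Run the model on at most `num` images so the observers can record statistics."""
    model.eval()
    m = 0
    for k, (image, _) in tqdm(enumerate(data_loader)):
        m += image.shape[0]
        if m > num:  # stop before exceeding `num` images
            break
        model(image)
calibrate(prepared_model, data_loader)  # run calibration on sample data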
6it [00:12,  2.06s/it]

Convert the Model to a Quantized Model

convert_fx takes the calibrated model and produces a quantized model.

quantized_model = convert_fx(prepared_model)
print(quantized_model.graph)
graph():
    %x : torch.Tensor [num_users=1] = placeholder[target=x]
    %conv1_input_scale_0 : [num_users=1] = get_attr[target=conv1_input_scale_0]
    %conv1_input_zero_point_0 : [num_users=1] = get_attr[target=conv1_input_zero_point_0]
    %quantize_per_tensor : [num_users=1] = call_function[target=torch.quantize_per_tensor](args = (%x, %conv1_input_scale_0, %conv1_input_zero_point_0, torch.quint8), kwargs = {})
    %conv1 : [num_users=1] = call_module[target=conv1](args = (%quantize_per_tensor,), kwargs = {})
    %maxpool : [num_users=2] = call_module[target=maxpool](args = (%conv1,), kwargs = {})
    %layer1_0_conv1 : [num_users=1] = call_module[target=layer1.0.conv1](args = (%maxpool,), kwargs = {})
    %layer1_0_conv2 : [num_users=1] = call_module[target=layer1.0.conv2](args = (%layer1_0_conv1,), kwargs = {})
    %layer1_0_relu_scale_0 : [num_users=1] = get_attr[target=layer1_0_relu_scale_0]
    %layer1_0_relu_zero_point_0 : [num_users=1] = get_attr[target=layer1_0_relu_zero_point_0]
    %add_relu : [num_users=2] = call_function[target=torch.ops.quantized.add_relu](args = (%layer1_0_conv2, %maxpool, %layer1_0_relu_scale_0, %layer1_0_relu_zero_point_0), kwargs = {})
    %layer1_1_conv1 : [num_users=1] = call_module[target=layer1.1.conv1](args = (%add_relu,), kwargs = {})
    %layer1_1_conv2 : [num_users=1] = call_module[target=layer1.1.conv2](args = (%layer1_1_conv1,), kwargs = {})
    %layer1_1_relu_scale_0 : [num_users=1] = get_attr[target=layer1_1_relu_scale_0]
    %layer1_1_relu_zero_point_0 : [num_users=1] = get_attr[target=layer1_1_relu_zero_point_0]
    %add_relu_1 : [num_users=2] = call_function[target=torch.ops.quantized.add_relu](args = (%layer1_1_conv2, %add_relu, %layer1_1_relu_scale_0, %layer1_1_relu_zero_point_0), kwargs = {})
    %layer2_0_conv1 : [num_users=1] = call_module[target=layer2.0.conv1](args = (%add_relu_1,), kwargs = {})
    %layer2_0_conv2 : [num_users=1] = call_module[target=layer2.0.conv2](args = (%layer2_0_conv1,), kwargs = {})
    %layer2_0_downsample_0 : [num_users=1] = call_module[target=layer2.0.downsample.0](args = (%add_relu_1,), kwargs = {})
    %layer2_0_relu_scale_0 : [num_users=1] = get_attr[target=layer2_0_relu_scale_0]
    %layer2_0_relu_zero_point_0 : [num_users=1] = get_attr[target=layer2_0_relu_zero_point_0]
    %add_relu_2 : [num_users=2] = call_function[target=torch.ops.quantized.add_relu](args = (%layer2_0_conv2, %layer2_0_downsample_0, %layer2_0_relu_scale_0, %layer2_0_relu_zero_point_0), kwargs = {})
    %layer2_1_conv1 : [num_users=1] = call_module[target=layer2.1.conv1](args = (%add_relu_2,), kwargs = {})
    %layer2_1_conv2 : [num_users=1] = call_module[target=layer2.1.conv2](args = (%layer2_1_conv1,), kwargs = {})
    %layer2_1_relu_scale_0 : [num_users=1] = get_attr[target=layer2_1_relu_scale_0]
    %layer2_1_relu_zero_point_0 : [num_users=1] = get_attr[target=layer2_1_relu_zero_point_0]
    %add_relu_3 : [num_users=2] = call_function[target=torch.ops.quantized.add_relu](args = (%layer2_1_conv2, %add_relu_2, %layer2_1_relu_scale_0, %layer2_1_relu_zero_point_0), kwargs = {})
    %layer3_0_conv1 : [num_users=1] = call_module[target=layer3.0.conv1](args = (%add_relu_3,), kwargs = {})
    %layer3_0_conv2 : [num_users=1] = call_module[target=layer3.0.conv2](args = (%layer3_0_conv1,), kwargs = {})
    %layer3_0_downsample_0 : [num_users=1] = call_module[target=layer3.0.downsample.0](args = (%add_relu_3,), kwargs = {})
    %layer3_0_relu_scale_0 : [num_users=1] = get_attr[target=layer3_0_relu_scale_0]
    %layer3_0_relu_zero_point_0 : [num_users=1] = get_attr[target=layer3_0_relu_zero_point_0]
    %add_relu_4 : [num_users=2] = call_function[target=torch.ops.quantized.add_relu](args = (%layer3_0_conv2, %layer3_0_downsample_0, %layer3_0_relu_scale_0, %layer3_0_relu_zero_point_0), kwargs = {})
    %layer3_1_conv1 : [num_users=1] = call_module[target=layer3.1.conv1](args = (%add_relu_4,), kwargs = {})
    %layer3_1_conv2 : [num_users=1] = call_module[target=layer3.1.conv2](args = (%layer3_1_conv1,), kwargs = {})
    %layer3_1_relu_scale_0 : [num_users=1] = get_attr[target=layer3_1_relu_scale_0]
    %layer3_1_relu_zero_point_0 : [num_users=1] = get_attr[target=layer3_1_relu_zero_point_0]
    %add_relu_5 : [num_users=2] = call_function[target=torch.ops.quantized.add_relu](args = (%layer3_1_conv2, %add_relu_4, %layer3_1_relu_scale_0, %layer3_1_relu_zero_point_0), kwargs = {})
    %layer4_0_conv1 : [num_users=1] = call_module[target=layer4.0.conv1](args = (%add_relu_5,), kwargs = {})
    %layer4_0_conv2 : [num_users=1] = call_module[target=layer4.0.conv2](args = (%layer4_0_conv1,), kwargs = {})
    %layer4_0_downsample_0 : [num_users=1] = call_module[target=layer4.0.downsample.0](args = (%add_relu_5,), kwargs = {})
    %layer4_0_relu_scale_0 : [num_users=1] = get_attr[target=layer4_0_relu_scale_0]
    %layer4_0_relu_zero_point_0 : [num_users=1] = get_attr[target=layer4_0_relu_zero_point_0]
    %add_relu_6 : [num_users=2] = call_function[target=torch.ops.quantized.add_relu](args = (%layer4_0_conv2, %layer4_0_downsample_0, %layer4_0_relu_scale_0, %layer4_0_relu_zero_point_0), kwargs = {})
    %layer4_1_conv1 : [num_users=1] = call_module[target=layer4.1.conv1](args = (%add_relu_6,), kwargs = {})
    %layer4_1_conv2 : [num_users=1] = call_module[target=layer4.1.conv2](args = (%layer4_1_conv1,), kwargs = {})
    %layer4_1_relu_scale_0 : [num_users=1] = get_attr[target=layer4_1_relu_scale_0]
    %layer4_1_relu_zero_point_0 : [num_users=1] = get_attr[target=layer4_1_relu_zero_point_0]
    %add_relu_7 : [num_users=1] = call_function[target=torch.ops.quantized.add_relu](args = (%layer4_1_conv2, %add_relu_6, %layer4_1_relu_scale_0, %layer4_1_relu_zero_point_0), kwargs = {})
    %avgpool : [num_users=1] = call_module[target=avgpool](args = (%add_relu_7,), kwargs = {})
    %flatten : [num_users=1] = call_function[target=torch.flatten](args = (%avgpool, 1), kwargs = {})
    %fc : [num_users=1] = call_module[target=fc](args = (%flatten,), kwargs = {})
    %dequantize_32 : [num_users=1] = call_method[target=dequantize](args = (%fc,), kwargs = {})
    return dequantize_32
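The converted graph begins with torch.quantize_per_tensor on the input and ends with a dequantize call. The sketch below shows the affine mapping those two nodes perform, using made-up scale and zero_point values (assumptions, not the values calibrated for resnet18). Note that 0.26 comes back as 0.5: that rounding error is the price of quantization.

```python
import torch

x = torch.tensor([-1.0, 0.26, 2.0, 3.0])
scale, zero_point = 0.5, 10  # hypothetical quantization parameters

# What the quantize_per_tensor node at the graph input does.
q = torch.quantize_per_tensor(x, scale, zero_point, torch.quint8)
print(q.int_repr())    # stored uint8 values
print(q.dequantize())  # what the final dequantize node produces

# The same affine mapping written out by hand:
#   q = clamp(round(x / scale) + zero_point, 0, 255)
#   x_hat = (q - zero_point) * scale
manual_q = torch.clamp(torch.round(x / scale) + zero_point, 0, 255)
manual_x_hat = (manual_q - zero_point) * scale
print(manual_q, manual_x_hat)
```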

Evaluation

Print the size and accuracy of the quantized model.

print(f"Model size before quantization: {size_of_model(float_model)/(1<<20)} MB")
print(f"Model size after quantization: {size_of_model(quantized_model)/(1<<20)} MB")
Model size before quantization: 44.658939361572266 MB
Model size after quantization: 11.283956527709961 MB
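The roughly 4x reduction (44.66 MB to 11.28 MB) is what storing weights as 8-bit integers instead of 32-bit floats predicts. The sketch below reproduces the effect on a hypothetical single-layer model whose size is dominated by one 1024 x 1024 weight matrix; state_dict_bytes is a hypothetical helper that, like size_of_model, measures the serialized state_dict, but in memory instead of through a temporary file.

```python
import io
from copy import deepcopy

import torch
from torch import nn
from torch.ao.quantization import QConfigMapping, get_default_qconfig
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx


def state_dict_bytes(model: nn.Module) -> int:
    """Serialized size of the state_dict, measured in memory."""
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    return len(buffer.getvalue())


# Hypothetical model dominated by one large weight matrix (1024 x 1024 float32 = 4 MiB).
float_model = nn.Sequential(nn.Linear(1024, 1024)).eval()
example_inputs = (torch.randn(16, 1024),)
prepared = prepare_fx(deepcopy(float_model),
                      QConfigMapping().set_global(get_default_qconfig("x86")),
                      example_inputs)
with torch.no_grad():
    prepared(*example_inputs)  # calibration pass
quantized = convert_fx(prepared)

ratio = state_dict_bytes(float_model) / state_dict_bytes(quantized)
print(f"size ratio float / int8: {ratio:.2f}")
```

The ratio stays slightly below 4 because the bias and the quantization parameters (scales and zero points) are still stored at higher precision.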
top1, top5 = evaluate(quantized_model, data_loader_test)
print(f"[Before serialization] Test set accuracy: {top1.avg: 2.2g}, {top5.avg: 2.2g}")
fx_graph_mode_model_file_path = saved_model_dir + "resnet18_fx_graph_mode_quantized.pth"
# this does not run due to some errors loading the convrelu module:
# ModuleAttributeError: 'ConvReLU2d' object has no attribute '_modules'
# save the whole model directly
# torch.save(quantized_model, fx_graph_mode_model_file_path)
# loaded_quantized_model = torch.load(fx_graph_mode_model_file_path)

# save with state_dict
# torch.save(quantized_model.state_dict(), fx_graph_mode_model_file_path)
# import copy
# model_to_quantize = copy.deepcopy(float_model)
# prepared_model = prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
# loaded_quantized_model = convert_fx(prepared_model)
# loaded_quantized_model.load_state_dict(torch.load(fx_graph_mode_model_file_path))

# save the scripted model
torch.jit.save(torch.jit.script(quantized_model), fx_graph_mode_model_file_path)
loaded_quantized_model = torch.jit.load(fx_graph_mode_model_file_path)
top1, top5 = evaluate(loaded_quantized_model, data_loader_test)
print(f"[After serialization] Test set accuracy: {top1.avg: 2.5g}, {top5.avg: 2.5g}")
100%|██████████| 1000/1000 [11:21<00:00,  1.47it/s]
/media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/torch/jit/_check.py:178: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.
  warnings.warn(
[Before serialization] Test set accuracy:  69,  89
100%|██████████| 1000/1000 [10:45<00:00,  1.55it/s]
[After serialization] Test set accuracy:  69.466,  88.942

If you want better accuracy or performance, try changing the qconfig_mapping.

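Debugging the Quantized Model

We can also print the weights of the quantized and the non-quantized convolution to see the difference. First call fuse_fx explicitly to fuse the convolutions and batch norms in the model; note that fuse_fx only works in eval mode.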

fused = fuse_fx(float_model)
conv1_weight_after_fuse = fused.conv1[0].weight[0]
conv1_weight_after_quant = quantized_model.conv1.weight().dequantize()[0]
print(torch.max(abs(conv1_weight_after_fuse - conv1_weight_after_quant)))
tensor(0.0007, grad_fn=<MaxBackward1>)

Comparison with the Baseline Float Model and Eager Mode Quantization

scripted_float_model_file = "resnet18_scripted.pth"
print(f"Baseline model size: {size_of_model(float_model)/(1<<20)} MB")
top1, top5 = evaluate(float_model, data_loader_test)
print(f"Baseline float model accuracy: {top1.avg: 2.5g}, {top5.avg: 2.5g}")
torch.jit.save(torch.jit.script(float_model), saved_model_dir + scripted_float_model_file)
Baseline model size: 44.658939361572266 MB
100%|██████████| 1000/1000 [16:48<00:00,  1.01s/it]
Baseline float model accuracy:  69.758,  89.078

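In this section, we compare the model quantized with FX graph mode quantization to the one quantized with eager mode quantization. FX graph mode and eager mode produce very similar quantized models, so their accuracy and speedup are expected to be similar as well.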

print(f"FX graph mode quantized model size: {size_of_model(quantized_model)/(1<<20)} MB")
top1, top5 = evaluate(quantized_model, data_loader_test)
print(f"FX graph mode quantized model accuracy: {top1.avg: 2.5g}, {top5.avg: 2.5g}")

from torchvision.models.quantization.resnet import resnet18 as quantized_resnet18, ResNet18_QuantizedWeights
# Equivalent to the deprecated resnet18(pretrained=True, quantize=True)
eager_quantized_model = quantized_resnet18(weights=ResNet18_QuantizedWeights.IMAGENET1K_FBGEMM_V1, quantize=True).eval()
print("Size of eager mode quantized model")
eager_quantized_model = torch.jit.script(eager_quantized_model)
print(f"Eager mode quantized model size: {size_of_model(eager_quantized_model)/(1<<20)} MB")
top1, top5 = evaluate(eager_quantized_model, data_loader_test)
print(f"Eager mode quantized model accuracy: {top1.avg: 2.5g}, {top5.avg: 2.5g}")
eager_mode_model_file = "resnet18_eager_mode_quantized.pth"
torch.jit.save(eager_quantized_model, saved_model_dir + eager_mode_model_file)
FX graph mode quantized model size: 11.283956527709961 MB
100%|██████████| 1000/1000 [12:50<00:00,  1.30it/s]
FX graph mode quantized model accuracy:  69.466,  88.942
/media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/torch/_utils.py:361: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  scales = torch.tensor(scales, dtype=torch.double, device=storage.device)
Size of eager mode quantized model
Eager mode quantized model size: 0.0039119720458984375 MB
100%|██████████| 1000/1000 [15:48<00:00,  1.05it/s]
Eager mode quantized model accuracy:  69.498,  88.882

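As we can see, the FX graph mode and eager mode quantized models reach very similar accuracy (69.466 / 88.942 versus 69.498 / 88.882). The two size figures are not directly comparable, however: for the eager mode model, size_of_model was called after torch.jit.script, and the state_dict of the scripted module apparently does not contain the packed quantized weights, which is why it reports only about 0.004 MB.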

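For resnet18, both the FX graph mode and the eager mode quantized models are expected to achieve a similar speedup over the float model, roughly 2-4x. The end-to-end evaluation times above, which also include data loading, show a much smaller gap, and the actual speedup can vary with the model, device, build, input batch size, number of threads, and so on.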
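Because the evaluation timings above mix model execution with data loading, a rough way to look at the speedup on your own machine is to time forward passes directly. The sketch below uses torch.utils.benchmark.Timer on a small hypothetical conv model (not resnet18) at two thread counts; the numbers it prints depend on the CPU, the PyTorch build and the quantized engine, so treat them as indicative only.

```python
from copy import deepcopy

import torch
from torch import nn
from torch.ao.quantization import QConfigMapping, get_default_qconfig
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
from torch.utils.benchmark import Timer

# Hypothetical conv-heavy model standing in for resnet18.
float_model = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(),
    nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(64, 10),
).eval()
x = torch.randn(8, 3, 56, 56)

prepared = prepare_fx(deepcopy(float_model),
                      QConfigMapping().set_global(get_default_qconfig("x86")), (x,))
with torch.no_grad():
    prepared(x)  # calibration pass
quantized = convert_fx(prepared)


def median_seconds(model: nn.Module, num_threads: int) -> float:
    """Median latency of one forward pass, measured by torch.utils.benchmark."""
    timer = Timer(stmt="with torch.no_grad(): model(x)",
                  globals={"model": model, "x": x, "torch": torch},
                  num_threads=num_threads)
    return timer.blocked_autorange(min_run_time=0.5).median


for threads in (1, 4):
    t_float = median_seconds(float_model, threads)
    t_quant = median_seconds(quantized, threads)
    print(f"threads={threads}: float {t_float * 1e3:.2f} ms, "
          f"int8 {t_quant * 1e3:.2f} ms, speedup {t_float / t_quant:.2f}x")
```

Timing the model in isolation like this, with a fixed thread count, removes the data-loading noise that the end-to-end evaluation loop includes.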