Quick Start#
This tutorial describes the steps for post-training static quantization in graph mode, based on torch.fx. The advantage of FX Graph Mode quantization is that quantization can be performed fully automatically on the model, although some effort may be required to make the model compatible with FX Graph Mode quantization (that is, symbolically traceable with torch.fx); a separate tutorial will show how to make the parts of a model we want to quantize compatible with FX Graph Mode quantization. There is also an FX Graph Mode post-training dynamic quantization tutorial. The FX Graph Mode API looks like this:
import torch
from torch.quantization import get_default_qconfig
# Note that this is temporary,
# we'll expose these functions to torch.quantization after official release
from torch.quantization.quantize_fx import prepare_fx, convert_fx

float_model.eval()
qconfig = get_default_qconfig("fbgemm")
qconfig_dict = {"": qconfig}

def calibrate(model, data_loader):
    model.eval()
    with torch.no_grad():
        for image, target in data_loader:
            model(image)

prepared_model = prepare_fx(float_model, qconfig_dict)  # fuse modules and insert observers
calibrate(prepared_model, valset)  # run calibration on sample data
quantized_model = convert_fx(prepared_model)  # convert the calibrated model to a quantized model
Motivation for FX Graph Mode Quantization#
PyTorch currently has eager mode quantization: Static Quantization with Eager Mode in PyTorch.
As we can see, that flow involves multiple manual steps, including:

- Explicitly quantize and dequantize activations; this is time consuming when floating point and quantized operations are mixed in a model.
- Explicitly fuse modules; this requires manually identifying sequences of convolutions, batch norms, relus, and other fusion patterns.
- Special handling is needed for PyTorch tensor operations (such as add, concat, etc.).
- Functionals do not have first class support (functional.conv2d and functional.linear will not be quantized).
Most of these required modifications come from the underlying limitations of eager mode quantization. Eager mode works at the module level because it cannot inspect the code that is actually run (in the forward function); quantization is achieved by module swapping, and we do not know how the modules are used in the forward function in eager mode. It therefore requires users to insert QuantStub and DeQuantStub manually to mark the points where they want to quantize or dequantize. In graph mode, we can inspect the actual code executed in the forward function (e.g. aten function calls), and quantization is achieved by module and graph manipulations. Since graph mode has full visibility of the code that is run, it can automatically figure out which modules to fuse and where to insert observer calls, quantize/dequantize functions, etc., so the whole quantization flow can be automated.
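To make "inspecting the actual code" concrete, here is a minimal sketch (the TinyNet module is our own illustration, not part of the tutorial) of what torch.fx symbolic tracing exposes:

from torch import nn
from torch.fx import symbolic_trace

# A hypothetical toy module, only to illustrate what tracing records.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))

# The traced graph lists the ops actually run in forward (call_module nodes
# here); this is the representation FX Graph Mode quantization inspects and rewrites.
print(symbolic_trace(TinyNet()).graph)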
The advantages of FX Graph Mode quantization are:

- It simplifies the quantization flow, with minimal manual steps.
- It unlocks the possibility of doing higher level optimizations, such as automatic precision selection.
Define Helper Functions and Prepare the Dataset#
We first do the necessary imports, define some helper functions, and prepare the data. These steps are identical to Static Quantization with Eager Mode in PyTorch.
To run the code in this tutorial with the whole ImageNet dataset, first download ImageNet by following the instructions in ImageNet Data, then unzip the downloaded file into the data_path folder.
Download the torchvision resnet18 model and rename it to models/resnet18_pretrained_float.pth.
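One way to produce that file, assuming network access and the models/ directory used below, is a short sketch like this:

import os
import torch
from torchvision import models

os.makedirs("models", exist_ok=True)
# download the pretrained weights via torchvision and save them under the expected name
weights = models.resnet18(pretrained=True).state_dict()
torch.save(weights, "models/resnet18_pretrained_float.pth")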
from torch_book.data import ImageNet
root = "/media/pc/data/4tb/lxw/datasets/ILSVRC"
saved_model_dir = 'models/'
dataset = ImageNet(root)
trainset = dataset.loader(batch_size=30, split="train")
valset = dataset.loader(batch_size=50, split="val")
import copy
from torchvision import models
model_name = "resnet18"
float_model = getattr(models, model_name)(pretrained=True)
float_model.eval()
# deepcopy the model since we need to keep the original model around
model_to_quantize = copy.deepcopy(float_model)
Model in Evaluation Mode#
For post-training quantization, we need to set the model to evaluation mode.
model_to_quantize.eval();
Specify How to Quantize the Model with qconfig_dict#
qconfig_dict = {"" : default_qconfig}
We use the same qconfig as in eager mode quantization; qconfig is just a named tuple of the observers for activations and weights. qconfig_dict is a dictionary with the following configurations:
qconfig = {
    "": qconfig_global,
    "sub": qconfig_sub,
    "sub.fc": qconfig_fc,
    "sub.conv": None
}
qconfig_dict = {
    # qconfig? means either a valid qconfig or None
    # optional, global config
    "": qconfig?,
    # optional, used for module and function types
    # could also be split into module_types and function_types if we prefer
    "object_type": [
        (torch.nn.Conv2d, qconfig?),
        (torch.nn.functional.add, qconfig?),
        ...,
    ],
    # optional, used for module names
    "module_name": [
        ("foo.bar", qconfig?),
        ...,
    ],
    # optional, matched in order, first match takes precedence
    "module_name_regex": [
        ("foo.*bar.*conv[0-9]+", qconfig?),
        ...,
    ],
    # priority (in increasing order): global, object_type, module_name_regex, module_name
    # qconfig == None means fusion and quantization should be skipped for anything
    # matching the rule
    # **api subject to change**
    # optional: specify the path for standalone modules
    # These modules are symbolically traced and quantized as one unit
    # so that the call to the submodule appears as one call_module
    # node in the forward graph of the GraphModule
    "standalone_module_name": [
        "submodule.standalone"
    ],
    "standalone_module_class": [
        StandaloneModuleClass
    ]
}
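As a concrete illustration (our own sketch, not from the tutorial): the following qconfig_dict quantizes the whole model with the global fbgemm qconfig but skips the final classifier, using the "qconfig == None means skip" rule described above:

from torch.quantization import get_default_qconfig

qconfig = get_default_qconfig("fbgemm")
qconfig_dict = {
    "": qconfig,          # global config
    "module_name": [
        ("fc", None),     # skip fusion and quantization for the fc layer
    ],
}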
Utility functions related to qconfig can be found in the qconfig file:
from torch.quantization import get_default_qconfig
qconfig = get_default_qconfig("fbgemm")
qconfig_dict = {"": qconfig}
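If the defaults are not a good fit, a custom QConfig can also be assembled from observers; a minimal sketch (the observer choices here are illustrative, not a recommendation):

import torch
from torch.quantization import QConfig, MinMaxObserver, PerChannelMinMaxObserver

custom_qconfig = QConfig(
    # observer for activations: unsigned 8-bit, min/max based
    activation=MinMaxObserver.with_args(dtype=torch.quint8),
    # observer for weights: signed 8-bit, symmetric per-channel
    weight=PerChannelMinMaxObserver.with_args(
        dtype=torch.qint8, qscheme=torch.per_channel_symmetric),
)
qconfig_dict = {"": custom_qconfig}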
Prepare the Model for Post-Training Static Quantization#
import warnings
from torch.quantization.quantize_fx import prepare_fx
warnings.filterwarnings('ignore')
prepared_model = prepare_fx(model_to_quantize, qconfig_dict)
prepare_fx folds BatchNorm modules into the preceding Conv2d modules and inserts observers at the appropriate places in the model.
print(prepared_model.graph)
graph():
    %x : torch.Tensor [#users=1] = placeholder[target=x]
    %activation_post_process_0 : [#users=1] = call_module[target=activation_post_process_0](args = (%x,), kwargs = {})
    %conv1 : [#users=1] = call_module[target=conv1](args = (%activation_post_process_0,), kwargs = {})
    %activation_post_process_1 : [#users=1] = call_module[target=activation_post_process_1](args = (%conv1,), kwargs = {})
    %maxpool : [#users=1] = call_module[target=maxpool](args = (%activation_post_process_1,), kwargs = {})
    %activation_post_process_2 : [#users=2] = call_module[target=activation_post_process_2](args = (%maxpool,), kwargs = {})
    %layer1_0_conv1 : [#users=1] = call_module[target=layer1.0.conv1](args = (%activation_post_process_2,), kwargs = {})
    %activation_post_process_3 : [#users=1] = call_module[target=activation_post_process_3](args = (%layer1_0_conv1,), kwargs = {})
    %layer1_0_conv2 : [#users=1] = call_module[target=layer1.0.conv2](args = (%activation_post_process_3,), kwargs = {})
    %activation_post_process_4 : [#users=1] = call_module[target=activation_post_process_4](args = (%layer1_0_conv2,), kwargs = {})
    %add : [#users=1] = call_function[target=operator.add](args = (%activation_post_process_4, %activation_post_process_2), kwargs = {})
    %layer1_0_relu_1 : [#users=1] = call_module[target=layer1.0.relu](args = (%add,), kwargs = {})
    %activation_post_process_5 : [#users=2] = call_module[target=activation_post_process_5](args = (%layer1_0_relu_1,), kwargs = {})
    %layer1_1_conv1 : [#users=1] = call_module[target=layer1.1.conv1](args = (%activation_post_process_5,), kwargs = {})
    %activation_post_process_6 : [#users=1] = call_module[target=activation_post_process_6](args = (%layer1_1_conv1,), kwargs = {})
    %layer1_1_conv2 : [#users=1] = call_module[target=layer1.1.conv2](args = (%activation_post_process_6,), kwargs = {})
    %activation_post_process_7 : [#users=1] = call_module[target=activation_post_process_7](args = (%layer1_1_conv2,), kwargs = {})
    %add_1 : [#users=1] = call_function[target=operator.add](args = (%activation_post_process_7, %activation_post_process_5), kwargs = {})
    %layer1_1_relu_1 : [#users=1] = call_module[target=layer1.1.relu](args = (%add_1,), kwargs = {})
    %activation_post_process_8 : [#users=2] = call_module[target=activation_post_process_8](args = (%layer1_1_relu_1,), kwargs = {})
    %layer2_0_conv1 : [#users=1] = call_module[target=layer2.0.conv1](args = (%activation_post_process_8,), kwargs = {})
    %activation_post_process_9 : [#users=1] = call_module[target=activation_post_process_9](args = (%layer2_0_conv1,), kwargs = {})
    %layer2_0_conv2 : [#users=1] = call_module[target=layer2.0.conv2](args = (%activation_post_process_9,), kwargs = {})
    %activation_post_process_10 : [#users=1] = call_module[target=activation_post_process_10](args = (%layer2_0_conv2,), kwargs = {})
    %layer2_0_downsample_0 : [#users=1] = call_module[target=layer2.0.downsample.0](args = (%activation_post_process_8,), kwargs = {})
    %activation_post_process_11 : [#users=1] = call_module[target=activation_post_process_11](args = (%layer2_0_downsample_0,), kwargs = {})
    %add_2 : [#users=1] = call_function[target=operator.add](args = (%activation_post_process_10, %activation_post_process_11), kwargs = {})
    %layer2_0_relu_1 : [#users=1] = call_module[target=layer2.0.relu](args = (%add_2,), kwargs = {})
    %activation_post_process_12 : [#users=2] = call_module[target=activation_post_process_12](args = (%layer2_0_relu_1,), kwargs = {})
    %layer2_1_conv1 : [#users=1] = call_module[target=layer2.1.conv1](args = (%activation_post_process_12,), kwargs = {})
    %activation_post_process_13 : [#users=1] = call_module[target=activation_post_process_13](args = (%layer2_1_conv1,), kwargs = {})
    %layer2_1_conv2 : [#users=1] = call_module[target=layer2.1.conv2](args = (%activation_post_process_13,), kwargs = {})
    %activation_post_process_14 : [#users=1] = call_module[target=activation_post_process_14](args = (%layer2_1_conv2,), kwargs = {})
    %add_3 : [#users=1] = call_function[target=operator.add](args = (%activation_post_process_14, %activation_post_process_12), kwargs = {})
    %layer2_1_relu_1 : [#users=1] = call_module[target=layer2.1.relu](args = (%add_3,), kwargs = {})
    %activation_post_process_15 : [#users=2] = call_module[target=activation_post_process_15](args = (%layer2_1_relu_1,), kwargs = {})
    %layer3_0_conv1 : [#users=1] = call_module[target=layer3.0.conv1](args = (%activation_post_process_15,), kwargs = {})
    %activation_post_process_16 : [#users=1] = call_module[target=activation_post_process_16](args = (%layer3_0_conv1,), kwargs = {})
    %layer3_0_conv2 : [#users=1] = call_module[target=layer3.0.conv2](args = (%activation_post_process_16,), kwargs = {})
    %activation_post_process_17 : [#users=1] = call_module[target=activation_post_process_17](args = (%layer3_0_conv2,), kwargs = {})
    %layer3_0_downsample_0 : [#users=1] = call_module[target=layer3.0.downsample.0](args = (%activation_post_process_15,), kwargs = {})
    %activation_post_process_18 : [#users=1] = call_module[target=activation_post_process_18](args = (%layer3_0_downsample_0,), kwargs = {})
    %add_4 : [#users=1] = call_function[target=operator.add](args = (%activation_post_process_17, %activation_post_process_18), kwargs = {})
    %layer3_0_relu_1 : [#users=1] = call_module[target=layer3.0.relu](args = (%add_4,), kwargs = {})
    %activation_post_process_19 : [#users=2] = call_module[target=activation_post_process_19](args = (%layer3_0_relu_1,), kwargs = {})
    %layer3_1_conv1 : [#users=1] = call_module[target=layer3.1.conv1](args = (%activation_post_process_19,), kwargs = {})
    %activation_post_process_20 : [#users=1] = call_module[target=activation_post_process_20](args = (%layer3_1_conv1,), kwargs = {})
    %layer3_1_conv2 : [#users=1] = call_module[target=layer3.1.conv2](args = (%activation_post_process_20,), kwargs = {})
    %activation_post_process_21 : [#users=1] = call_module[target=activation_post_process_21](args = (%layer3_1_conv2,), kwargs = {})
    %add_5 : [#users=1] = call_function[target=operator.add](args = (%activation_post_process_21, %activation_post_process_19), kwargs = {})
    %layer3_1_relu_1 : [#users=1] = call_module[target=layer3.1.relu](args = (%add_5,), kwargs = {})
    %activation_post_process_22 : [#users=2] = call_module[target=activation_post_process_22](args = (%layer3_1_relu_1,), kwargs = {})
    %layer4_0_conv1 : [#users=1] = call_module[target=layer4.0.conv1](args = (%activation_post_process_22,), kwargs = {})
    %activation_post_process_23 : [#users=1] = call_module[target=activation_post_process_23](args = (%layer4_0_conv1,), kwargs = {})
    %layer4_0_conv2 : [#users=1] = call_module[target=layer4.0.conv2](args = (%activation_post_process_23,), kwargs = {})
    %activation_post_process_24 : [#users=1] = call_module[target=activation_post_process_24](args = (%layer4_0_conv2,), kwargs = {})
    %layer4_0_downsample_0 : [#users=1] = call_module[target=layer4.0.downsample.0](args = (%activation_post_process_22,), kwargs = {})
    %activation_post_process_25 : [#users=1] = call_module[target=activation_post_process_25](args = (%layer4_0_downsample_0,), kwargs = {})
    %add_6 : [#users=1] = call_function[target=operator.add](args = (%activation_post_process_24, %activation_post_process_25), kwargs = {})
    %layer4_0_relu_1 : [#users=1] = call_module[target=layer4.0.relu](args = (%add_6,), kwargs = {})
    %activation_post_process_26 : [#users=2] = call_module[target=activation_post_process_26](args = (%layer4_0_relu_1,), kwargs = {})
    %layer4_1_conv1 : [#users=1] = call_module[target=layer4.1.conv1](args = (%activation_post_process_26,), kwargs = {})
    %activation_post_process_27 : [#users=1] = call_module[target=activation_post_process_27](args = (%layer4_1_conv1,), kwargs = {})
    %layer4_1_conv2 : [#users=1] = call_module[target=layer4.1.conv2](args = (%activation_post_process_27,), kwargs = {})
    %activation_post_process_28 : [#users=1] = call_module[target=activation_post_process_28](args = (%layer4_1_conv2,), kwargs = {})
    %add_7 : [#users=1] = call_function[target=operator.add](args = (%activation_post_process_28, %activation_post_process_26), kwargs = {})
    %layer4_1_relu_1 : [#users=1] = call_module[target=layer4.1.relu](args = (%add_7,), kwargs = {})
    %activation_post_process_29 : [#users=1] = call_module[target=activation_post_process_29](args = (%layer4_1_relu_1,), kwargs = {})
    %avgpool : [#users=1] = call_module[target=avgpool](args = (%activation_post_process_29,), kwargs = {})
    %activation_post_process_30 : [#users=1] = call_module[target=activation_post_process_30](args = (%avgpool,), kwargs = {})
    %flatten : [#users=1] = call_function[target=torch.flatten](args = (%activation_post_process_30, 1), kwargs = {})
    %activation_post_process_31 : [#users=1] = call_module[target=activation_post_process_31](args = (%flatten,), kwargs = {})
    %fc : [#users=1] = call_module[target=fc](args = (%activation_post_process_31,), kwargs = {})
    %activation_post_process_32 : [#users=1] = call_module[target=activation_post_process_32](args = (%fc,), kwargs = {})
    return activation_post_process_32
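Besides printing the graph, the inserted observers can also be listed directly from the prepared model; a small sketch (relying on the activation_post_process naming visible in the graph above):

observer_names = [name for name, _ in prepared_model.named_modules()
                  if "activation_post_process" in name]
print(f"{len(observer_names)} observers inserted")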
Calibration#
After the observers are inserted in the model, we run the calibration function. The purpose of calibration is to run some representative workloads through the model (for example a sample of the training data set), so that the observers can record the statistics of the tensors flowing through the model; this information is later used to compute the quantization parameters.
import torch

def calibrate(model, data_loader, samples=500):
    model.eval()
    with torch.no_grad():
        k = 0
        for image, _ in data_loader:
            if k > samples:
                break
            model(image)
            k += len(image)
calibrate(prepared_model, trainset) # run calibration on sample data
Convert the Model to a Quantized Model#
convert_fx takes a calibrated model and produces a quantized model.
from torch.quantization.quantize_fx import convert_fx
quantized_model = convert_fx(prepared_model)
print(quantized_model)
GraphModule(
  (conv1): QuantizedConvReLU2d(3, 64, kernel_size=(7, 7), stride=(2, 2), scale=0.03267836198210716, zero_point=0, padding=(3, 3))
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Module(
    (0): Module(
      (conv1): QuantizedConvReLU2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.019193191081285477, zero_point=0, padding=(1, 1))
      (conv2): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.051562923938035965, zero_point=75, padding=(1, 1))
    )
    (1): Module(
      (conv1): QuantizedConvReLU2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.019093887880444527, zero_point=0, padding=(1, 1))
      (conv2): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.06979087740182877, zero_point=78, padding=(1, 1))
    )
  )
  (layer2): Module(
    (0): Module(
      (conv1): QuantizedConvReLU2d(64, 128, kernel_size=(3, 3), stride=(2, 2), scale=0.01557458657771349, zero_point=0, padding=(1, 1))
      (conv2): QuantizedConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.050476107746362686, zero_point=68, padding=(1, 1))
      (downsample): Module(
        (0): QuantizedConv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), scale=0.039443813264369965, zero_point=60)
      )
    )
    (1): Module(
      (conv1): QuantizedConvReLU2d(128, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.016193654388189316, zero_point=0, padding=(1, 1))
      (conv2): QuantizedConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.05214320868253708, zero_point=68, padding=(1, 1))
    )
  )
  (layer3): Module(
    (0): Module(
      (conv1): QuantizedConvReLU2d(128, 256, kernel_size=(3, 3), stride=(2, 2), scale=0.018163194879889488, zero_point=0, padding=(1, 1))
      (conv2): QuantizedConv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), scale=0.05316956341266632, zero_point=51, padding=(1, 1))
      (downsample): Module(
        (0): QuantizedConv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), scale=0.01836947724223137, zero_point=107)
      )
    )
    (1): Module(
      (conv1): QuantizedConvReLU2d(256, 256, kernel_size=(3, 3), stride=(1, 1), scale=0.013543782755732536, zero_point=0, padding=(1, 1))
      (conv2): QuantizedConv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), scale=0.048523254692554474, zero_point=71, padding=(1, 1))
    )
  )
  (layer4): Module(
    (0): Module(
      (conv1): QuantizedConvReLU2d(256, 512, kernel_size=(3, 3), stride=(2, 2), scale=0.014485283754765987, zero_point=0, padding=(1, 1))
      (conv2): QuantizedConv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), scale=0.0517863854765892, zero_point=64, padding=(1, 1))
      (downsample): Module(
        (0): QuantizedConv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), scale=0.04331441596150398, zero_point=58)
      )
    )
    (1): Module(
      (conv1): QuantizedConvReLU2d(512, 512, kernel_size=(3, 3), stride=(1, 1), scale=0.021167000755667686, zero_point=0, padding=(1, 1))
      (conv2): QuantizedConv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), scale=0.22766999900341034, zero_point=45, padding=(1, 1))
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): QuantizedLinear(in_features=512, out_features=1000, scale=0.27226582169532776, zero_point=35, qscheme=torch.per_channel_affine)
)
def forward(self, x : torch.Tensor):
    conv1_input_scale_0 = self.conv1_input_scale_0
    conv1_input_zero_point_0 = self.conv1_input_zero_point_0
    quantize_per_tensor = torch.quantize_per_tensor(x, conv1_input_scale_0, conv1_input_zero_point_0, torch.quint8); x = conv1_input_scale_0 = conv1_input_zero_point_0 = None
    conv1 = self.conv1(quantize_per_tensor); quantize_per_tensor = None
    maxpool = self.maxpool(conv1); conv1 = None
    layer1_0_conv1 = getattr(self.layer1, "0").conv1(maxpool)
    layer1_0_conv2 = getattr(self.layer1, "0").conv2(layer1_0_conv1); layer1_0_conv1 = None
    layer1_0_relu_scale_0 = self.layer1_0_relu_scale_0
    layer1_0_relu_zero_point_0 = self.layer1_0_relu_zero_point_0
    add_relu = torch.ops.quantized.add_relu(layer1_0_conv2, maxpool, layer1_0_relu_scale_0, layer1_0_relu_zero_point_0); layer1_0_conv2 = maxpool = layer1_0_relu_scale_0 = layer1_0_relu_zero_point_0 = None
    layer1_1_conv1 = getattr(self.layer1, "1").conv1(add_relu)
    layer1_1_conv2 = getattr(self.layer1, "1").conv2(layer1_1_conv1); layer1_1_conv1 = None
    layer1_1_relu_scale_0 = self.layer1_1_relu_scale_0
    layer1_1_relu_zero_point_0 = self.layer1_1_relu_zero_point_0
    add_relu_1 = torch.ops.quantized.add_relu(layer1_1_conv2, add_relu, layer1_1_relu_scale_0, layer1_1_relu_zero_point_0); layer1_1_conv2 = add_relu = layer1_1_relu_scale_0 = layer1_1_relu_zero_point_0 = None
    layer2_0_conv1 = getattr(self.layer2, "0").conv1(add_relu_1)
    layer2_0_conv2 = getattr(self.layer2, "0").conv2(layer2_0_conv1); layer2_0_conv1 = None
    layer2_0_downsample_0 = getattr(getattr(self.layer2, "0").downsample, "0")(add_relu_1); add_relu_1 = None
    layer2_0_relu_scale_0 = self.layer2_0_relu_scale_0
    layer2_0_relu_zero_point_0 = self.layer2_0_relu_zero_point_0
    add_relu_2 = torch.ops.quantized.add_relu(layer2_0_conv2, layer2_0_downsample_0, layer2_0_relu_scale_0, layer2_0_relu_zero_point_0); layer2_0_conv2 = layer2_0_downsample_0 = layer2_0_relu_scale_0 = layer2_0_relu_zero_point_0 = None
    layer2_1_conv1 = getattr(self.layer2, "1").conv1(add_relu_2)
    layer2_1_conv2 = getattr(self.layer2, "1").conv2(layer2_1_conv1); layer2_1_conv1 = None
    layer2_1_relu_scale_0 = self.layer2_1_relu_scale_0
    layer2_1_relu_zero_point_0 = self.layer2_1_relu_zero_point_0
    add_relu_3 = torch.ops.quantized.add_relu(layer2_1_conv2, add_relu_2, layer2_1_relu_scale_0, layer2_1_relu_zero_point_0); layer2_1_conv2 = add_relu_2 = layer2_1_relu_scale_0 = layer2_1_relu_zero_point_0 = None
    layer3_0_conv1 = getattr(self.layer3, "0").conv1(add_relu_3)
    layer3_0_conv2 = getattr(self.layer3, "0").conv2(layer3_0_conv1); layer3_0_conv1 = None
    layer3_0_downsample_0 = getattr(getattr(self.layer3, "0").downsample, "0")(add_relu_3); add_relu_3 = None
    layer3_0_relu_scale_0 = self.layer3_0_relu_scale_0
    layer3_0_relu_zero_point_0 = self.layer3_0_relu_zero_point_0
    add_relu_4 = torch.ops.quantized.add_relu(layer3_0_conv2, layer3_0_downsample_0, layer3_0_relu_scale_0, layer3_0_relu_zero_point_0); layer3_0_conv2 = layer3_0_downsample_0 = layer3_0_relu_scale_0 = layer3_0_relu_zero_point_0 = None
    layer3_1_conv1 = getattr(self.layer3, "1").conv1(add_relu_4)
    layer3_1_conv2 = getattr(self.layer3, "1").conv2(layer3_1_conv1); layer3_1_conv1 = None
    layer3_1_relu_scale_0 = self.layer3_1_relu_scale_0
    layer3_1_relu_zero_point_0 = self.layer3_1_relu_zero_point_0
    add_relu_5 = torch.ops.quantized.add_relu(layer3_1_conv2, add_relu_4, layer3_1_relu_scale_0, layer3_1_relu_zero_point_0); layer3_1_conv2 = add_relu_4 = layer3_1_relu_scale_0 = layer3_1_relu_zero_point_0 = None
    layer4_0_conv1 = getattr(self.layer4, "0").conv1(add_relu_5)
    layer4_0_conv2 = getattr(self.layer4, "0").conv2(layer4_0_conv1); layer4_0_conv1 = None
    layer4_0_downsample_0 = getattr(getattr(self.layer4, "0").downsample, "0")(add_relu_5); add_relu_5 = None
    layer4_0_relu_scale_0 = self.layer4_0_relu_scale_0
    layer4_0_relu_zero_point_0 = self.layer4_0_relu_zero_point_0
    add_relu_6 = torch.ops.quantized.add_relu(layer4_0_conv2, layer4_0_downsample_0, layer4_0_relu_scale_0, layer4_0_relu_zero_point_0); layer4_0_conv2 = layer4_0_downsample_0 = layer4_0_relu_scale_0 = layer4_0_relu_zero_point_0 = None
    layer4_1_conv1 = getattr(self.layer4, "1").conv1(add_relu_6)
    layer4_1_conv2 = getattr(self.layer4, "1").conv2(layer4_1_conv1); layer4_1_conv1 = None
    layer4_1_relu_scale_0 = self.layer4_1_relu_scale_0
    layer4_1_relu_zero_point_0 = self.layer4_1_relu_zero_point_0
    add_relu_7 = torch.ops.quantized.add_relu(layer4_1_conv2, add_relu_6, layer4_1_relu_scale_0, layer4_1_relu_zero_point_0); layer4_1_conv2 = add_relu_6 = layer4_1_relu_scale_0 = layer4_1_relu_zero_point_0 = None
    avgpool = self.avgpool(add_relu_7); add_relu_7 = None
    flatten = torch.flatten(avgpool, 1); avgpool = None
    fc = self.fc(flatten); flatten = None
    dequantize_14 = fc.dequantize(); fc = None
    return dequantize_14
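As a quick sanity check (our own sketch, assuming valset yields (image, target) batches as set up above), the quantized model can be run on one batch:

with torch.no_grad():
    images, _ = next(iter(valset))
    print(quantized_model(images).shape)  # e.g. torch.Size([50, 1000]) for batch_size=50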
Evaluation#
We can now print the size and accuracy of the quantized model.
from torch_book.contrib.helper import evaluate, print_size_of_model
from torch import nn
criterion = nn.CrossEntropyLoss()
print("Size of model before quantization")
print_size_of_model(float_model)
print("Size of model after quantization")
print_size_of_model(quantized_model)
top1, top5 = evaluate(quantized_model, criterion, valset)
Size of model before quantization
Model size (MB): 46.873073 MB
Size of model after quantization
Model size (MB): 11.853109 MB
print(f"[before serilaization] Evaluation accuracy on test dataset: {top1.avg:2.2f}, {top5.avg:2.2f}")
[before serilaization] Evaluation accuracy on test dataset: 69.37, 88.89
fx_graph_mode_model_file_path = saved_model_dir + f"{model_name}_fx_graph_mode_quantized.pth"
# this does not run due to some errors loading convrelu module:
# ModuleAttributeError: 'ConvReLU2d' object has no attribute '_modules'
# save the whole model directly
# torch.save(quantized_model, fx_graph_mode_model_file_path)
# loaded_quantized_model = torch.load(fx_graph_mode_model_file_path)
# save with state_dict
# torch.save(quantized_model.state_dict(), fx_graph_mode_model_file_path)
# import copy
# model_to_quantize = copy.deepcopy(float_model)
# prepared_model = prepare_fx(model_to_quantize, {"": qconfig})
# loaded_quantized_model = convert_fx(prepared_model)
# loaded_quantized_model.load_state_dict(torch.load(fx_graph_mode_model_file_path))
# save with script
torch.jit.save(torch.jit.script(quantized_model), fx_graph_mode_model_file_path)
loaded_quantized_model = torch.jit.load(fx_graph_mode_model_file_path)
top1, top5 = evaluate(loaded_quantized_model, criterion, valset)
print(f"[after serialization/deserialization] Evaluation accuracy on test dataset: {top1.avg:2.2f}, {top5.avg:2.2f}")
[after serialization/deserialization] Evaluation accuracy on test dataset: 69.37, 88.89
If you want better accuracy or performance, try changing the qconfig_dict.
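For example (a sketch; whether it helps is model and hardware dependent), one could switch to the default qconfig of the qnnpack backend:

from torch.quantization import get_default_qconfig

qconfig_dict = {"": get_default_qconfig("qnnpack")}
torch.backends.quantized.engine = "qnnpack"  # select the matching quantized kernel backend
# then re-run prepare_fx, calibration, and convert_fx with this qconfig_dict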
Debugging the Quantized Model#
We can also print the weights of the quantized and non-quantized conv to see the difference. First we call fuse explicitly to fuse the conv and bn in the float model; note that fuse_fx only works in eval mode.
from torch.quantization.quantize_fx import fuse_fx
fused = fuse_fx(float_model)
conv1_weight_after_fuse = fused.conv1[0].weight[0]
conv1_weight_after_quant = quantized_model.conv1.weight().dequantize()[0]
print(torch.max(abs(conv1_weight_after_fuse - conv1_weight_after_quant)))
tensor(0.0007, grad_fn=<MaxBackward1>)
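To put that difference on a more interpretable scale, a small helper (our own sketch, not part of the tutorial) can report it as SQNR in decibels:

def sqnr(x, y):
    # signal-to-quantization-noise ratio in dB; higher means smaller error
    return 20 * torch.log10(torch.norm(x) / torch.norm(x - y))

print(sqnr(conv1_weight_after_fuse, conv1_weight_after_quant))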
Comparison with Baseline Float Model and Eager Mode Quantization#
scripted_float_model_file = "resnet18_scripted.pth"
print("Size of baseline model")
print_size_of_model(float_model)
top1, top5 = evaluate(float_model, criterion, valset)
print("Baseline Float Model Evaluation accuracy: %2.2f, %2.2f"%(top1.avg, top5.avg))
torch.jit.save(torch.jit.script(float_model), saved_model_dir + scripted_float_model_file)
Size of baseline model
Model size (MB): 46.874273 MB
Baseline Float Model Evaluation accuracy: 69.76, 89.08
In this section we compare the model quantized with FX graph mode quantization with the model quantized in eager mode. FX graph mode and eager mode produce very similar quantized models, so we expect the accuracy and speedup to be similar as well.
print("Size of Fx graph mode quantized model")
print_size_of_model(quantized_model)
top1, top5 = evaluate(quantized_model, criterion, valset)
print("FX graph mode quantized model Evaluation accuracy on test dataset: %2.2f, %2.2f"%(top1.avg, top5.avg))
from torchvision.models.quantization.resnet import resnet18
eager_quantized_model = resnet18(pretrained=True, quantize=True).eval()
print("Size of eager mode quantized model")
eager_quantized_model = torch.jit.script(eager_quantized_model)
print_size_of_model(eager_quantized_model)
top1, top5 = evaluate(eager_quantized_model, criterion, valset)
print("eager mode quantized model Evaluation accuracy on test dataset: %2.2f, %2.2f"%(top1.avg, top5.avg))
eager_mode_model_file = "resnet18_eager_mode_quantized.pth"
torch.jit.save(eager_quantized_model, saved_model_dir + eager_mode_model_file)
Size of Fx graph mode quantized model
Model size (MB): 11.855297 MB
FX graph mode quantized model Evaluation accuracy on test dataset: 69.37, 88.89
Size of eager mode quantized model
Model size (MB): 11.850395 MB
eager mode quantized model Evaluation accuracy on test dataset: 69.50, 88.88
We can see that the model size and accuracy of the FX graph mode and eager mode quantized models are very similar.
Running the models in AIBench (with single threading) gives the following results:
Scripted Float Model:
Self CPU time total: 192.48ms
Scripted Eager Mode Quantized Model:
Self CPU time total: 50.76ms
Scripted FX Graph Mode Quantized Model:
Self CPU time total: 50.63ms
As we can see, for resnet18 the FX graph mode and eager mode quantized models both get a similar speedup over the float model, roughly 2-4x faster. The actual speedup over the float model may vary depending on the model, device, build, input batch size, threading, and so on.
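If AIBench is not available, a rough local comparison can be made with a simple single-threaded timing sketch like the following (our own; absolute numbers will differ by machine):

import time
import torch

torch.set_num_threads(1)  # match the single-threaded setting above

def benchmark_ms(model, inp, iters=50):
    with torch.no_grad():
        for _ in range(5):        # warm up
            model(inp)
        start = time.time()
        for _ in range(iters):
            model(inp)
    return (time.time() - start) / iters * 1000

inp = torch.rand(1, 3, 224, 224)
float_scripted = torch.jit.load(saved_model_dir + scripted_float_model_file)
print(f"scripted float model: {benchmark_ms(float_scripted, inp):.2f} ms/iter")
print(f"scripted FX quantized model: {benchmark_ms(loaded_quantized_model, inp):.2f} ms/iter")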