自定义量化
导航
自定义量化
# 提供注解的向前兼容
from __future__ import annotations
量化流程
重要
将模型转换为可量化模型的流程如下:

1. 用 `FloatFunctional` 替换加法运算;
2. 使用 `fuse_modules()` 或者 `fuse_modules_qat()` 融合如下模块序列:
   - conv, bn
   - conv, bn, relu
   - conv, relu
   - linear, bn
   - linear, relu
3. 在网络的开头和结尾分别插入 `QuantStub` 和 `DeQuantStub`;
4. 将 `torch.nn.ReLU6` 替换为 `torch.nn.ReLU`。
载入一些库
'''参考 torchvision/models/quantization/mobilenetv2.py
'''
from typing import Any

import torch
from torch import Tensor
from torch import nn
from torch.ao.quantization import QuantStub, DeQuantStub
from torchvision._internally_replaced_utils import load_state_dict_from_url
from torchvision.models.quantization.utils import _fuse_modules, _replace_relu, quantize_model
from torchvision.ops.misc import ConvNormActivation
融合模块
备注
`_fuse_modules()` 为 `fuse_modules()` 和 `fuse_modules_qat()` 提供了统一接口:当 `is_qat` 为 `None` 时,会根据 `model.training` 自动选择其中之一。
from torch.ao.quantization import fuse_modules_qat, fuse_modules
def _fuse_modules(
model: nn.Module,
modules_to_fuse: list[str] | list[list[str]],
is_qat: bool | None,
**kwargs: Any
):
if is_qat is None:
is_qat = model.training
method = fuse_modules_qat if is_qat else fuse_modules
return method(model, modules_to_fuse, **kwargs)
`FloatFunctional` 算子比普通的 `torch` 运算(如 `torch.add`)多了一步后处理操作(`activation_post_process`),比如:
class FloatFunctional(torch.nn.Module):
    """Simplified sketch of a float functional wrapper with a post-process hook.

    Operations are exposed as methods (only ``add`` here) rather than via
    ``forward``. Each method applies ``self.activation_post_process`` to its
    result, so replacing that attribute changes the op's post-processing.
    """

    def __init__(self):
        super().__init__()
        # Identity by default, which makes ``add`` equivalent to ``torch.add``.
        # Presumably quantization tooling swaps this for an observer — confirm.
        self.activation_post_process = torch.nn.Identity()

    def forward(self, x):
        # Calling the module directly is deliberately unsupported.
        raise RuntimeError("FloatFunctional is not intended to use the " +
                           "'forward'. Please use the underlying operation")

    def add(self, x: Tensor, y: Tensor) -> Tensor:
        r"""Operation equivalent to ``torch.add(Tensor, Tensor)``"""
        r = torch.add(x, y)
        r = self.activation_post_process(r)
        return r
由于 `self.activation_post_process = torch.nn.Identity()` 是恒等映射,所以 `add()` 等价于 `torch.add()`。
小技巧
推测 `FloatFunctional` 算子提供了自定义算子的官方接口:只需对 `self.activation_post_process` 赋值,即可为算子添加后处理操作。
特定于 MobileNetV2 的量化
下面以 `MobileNetV2` 为例,介绍如何将其转换为量化模型 `QuantizableMobileNetV2`。
from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, model_urls
# Download URLs for pretrained quantized weights, keyed by
# "mobilenet_v2_" + backend name (see how mobilenet_v2 builds the key).
quant_model_urls = dict(
    mobilenet_v2_qnnpack="https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth",
)
class QuantizableInvertedResidual(InvertedResidual):
    """Inverted residual block whose skip connection is quantization-friendly.

    The residual addition goes through ``nn.quantized.FloatFunctional`` instead
    of a bare ``+`` so that it can take part in quantization.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # PyTorch's own FloatFunctional (nn.quantized), not a local class.
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        # No residual path: just the convolutional branch.
        if not self.use_res_connect:
            return self.conv(x)
        return self.skip_add.add(x, self.conv(x))

    def fuse_model(self, is_qat: bool | None = None) -> None:
        """Fuse each plain ``nn.Conv2d`` in ``self.conv`` with its successor.

        Args:
            is_qat: Forwarded to ``_fuse_modules``; ``None`` lets it decide
                from ``self.conv.training``.
        """
        for position, layer in enumerate(self.conv):
            # Exact type match on purpose: subclasses of Conv2d are skipped.
            if type(layer) is not nn.Conv2d:
                continue
            _fuse_modules(
                self.conv,
                [str(position), str(position + 1)],
                is_qat,
                inplace=True,
            )
class QuantizableMobileNetV2(MobileNetV2):
    """MobileNetV2 prepared for eager-mode quantization.

    The float network body is wrapped between a ``QuantStub`` and a
    ``DeQuantStub``, and ``fuse_model`` fuses conv/bn/relu sequences.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """
        MobileNet V2 main class

        Args:
            Same arguments as the floating-point ``MobileNetV2``.
        """
        super().__init__(*args, **kwargs)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x: Tensor) -> Tensor:
        x = self.quant(x)
        # ``_forward_impl`` is inherited from MobileNetV2 (the float body).
        x = self._forward_impl(x)
        x = self.dequant(x)
        return x

    def fuse_model(self, is_qat: bool | None = None) -> None:
        """Fuse fusable module sequences throughout the network, in place.

        Args:
            is_qat: Selects QAT (``True``) or post-training (``False``)
                fusion; ``None`` lets ``_fuse_modules`` decide from each
                module's ``training`` flag. Defaults to ``None`` to match
                ``QuantizableInvertedResidual.fuse_model`` and so that callers
                invoking ``fuse_model()`` with no argument (such as
                torchvision's ``quantize_model``) keep working.
        """
        for m in self.modules():
            if type(m) is ConvNormActivation:
                # Children "0", "1", "2" are presumably conv, norm, activation
                # — confirm against ConvNormActivation's layout.
                _fuse_modules(m, ["0", "1", "2"], is_qat, inplace=True)
            if type(m) is QuantizableInvertedResidual:
                m.fuse_model(is_qat)
def mobilenet_v2(
    pretrained: bool = False,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableMobileNetV2:
    """
    Construct a MobileNetV2 architecture from
    `MobileNetV2: Inverted Residuals and Linear Bottlenecks <https://arxiv.org/abs/1801.04381>`_.

    Note that quantize = True returns a quantized model with 8-bit weights.
    Quantized models only support inference and run on CPUs; GPU inference
    is not yet supported.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet.
        progress (bool): If True, displays a progress bar of the download to stderr.
        quantize (bool): If True, returns a quantized model, else a float model.
        **kwargs: Forwarded to ``QuantizableMobileNetV2``.
    """
    model = QuantizableMobileNetV2(block=QuantizableInvertedResidual, **kwargs)
    _replace_relu(model)

    if not quantize:
        assert pretrained in (True, False)
    else:
        # TODO use pretrained as a string to specify the backend
        backend = "qnnpack"
        quantize_model(model, backend)

    if not pretrained:
        return model

    # Quantized weights must be loaded after quantize_model has converted
    # the model, hence the ordering above.
    if quantize:
        weights_url = quant_model_urls["mobilenet_v2_" + backend]
    else:
        weights_url = model_urls["mobilenet_v2"]
    model.load_state_dict(load_state_dict_from_url(weights_url, progress=progress))
    return model