
# 提供注解的向前兼容
from __future__ import annotations



将模型转换为可量化模型的流程如下:

  1. 用 FloatFunctional 替换加法运算

  2. 使用 fuse_modules() 或者 fuse_modules_qat() 融合如下模块序列:

    • conv, bn

    • conv, bn, relu

    • conv, relu

    • linear, bn

    • linear, relu

  3. 在网络的开头和结尾分别插入 QuantStub 和 DeQuantStub

  4. 将 torch.nn.ReLU6 替换为 torch.nn.ReLU


参考 torchvision/models/quantization/mobilenetv2.py:
from typing import Any

import torch
from torch import Tensor
from torch import nn
from torch.ao.quantization import QuantStub, DeQuantStub

from torchvision._internally_replaced_utils import load_state_dict_from_url
from torchvision.ops.misc import ConvNormActivation
from torchvision.models.quantization.utils import _fuse_modules, _replace_relu, quantize_model



_fuse_modules() 为 fuse_modules() 和 fuse_modules_qat() 提供了统一接口:当 is_qat 为 None 时,根据 model.training 自动选择其中之一。

from torch.ao.quantization import fuse_modules_qat, fuse_modules

def _fuse_modules(
    model: nn.Module, 
    modules_to_fuse: list[str] | list[list[str]], 
    is_qat: bool | None, 
    **kwargs: Any
    if is_qat is None:
        is_qat = model.training
    method = fuse_modules_qat if is_qat else  fuse_modules
    return method(model, modules_to_fuse, **kwargs)

FloatFunctional 的算子比对应的普通 torch 运算多了一步后处理(activation_post_process),比如:

class FloatFunctional(torch.nn.Module):
    """Float-mode container for tensor operations with a post-process hook.

    Each operation computes its result with the plain ``torch`` op and then
    passes it through ``self.activation_post_process``. That attribute is an
    ``nn.Identity`` here, so ``add`` gives the same result as ``torch.add``.
    Assigning a different module to the attribute adds post-processing
    without changing any caller.
    """

    def __init__(self) -> None:
        # nn.Module.__init__ must run before a submodule can be assigned;
        # otherwise nn.Module.__setattr__ raises AttributeError.
        super().__init__()
        self.activation_post_process = torch.nn.Identity()

    def forward(self, x):
        # Calling the module directly is deliberately unsupported; callers
        # must use the named operation methods such as ``add``.
        raise RuntimeError("FloatFunctional is not intended to use the " +
                           "'forward'. Please use the underlying operation")

    def add(self, x: Tensor, y: Tensor) -> Tensor:
        r"""Operation equivalent to ``torch.add(Tensor, Tensor)``"""
        r = torch.add(x, y)
        r = self.activation_post_process(r)
        return r

由于 self.activation_post_process = torch.nn.Identity() 是恒等映射,所以 add() 等价于 torch.add()。


推测 FloatFunctional 提供了为算子添加后处理的官方接口:只需对 self.activation_post_process 赋值,即可为算子添加后处理操作。

特定于 MobileNetV2 的量化

下面以 MobileNetV2 为例,介绍如何将其转换为量化模型 QuantizableMobileNetV2。

from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, model_urls

# Pretrained weights for the quantized model, keyed by "mobilenet_v2_<backend>".
quant_model_urls = {
    "mobilenet_v2_qnnpack": "https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth",
}

class QuantizableInvertedResidual(InvertedResidual):
    """``InvertedResidual`` block whose residual addition is quantizable.

    The skip connection is computed with ``nn.quantized.FloatFunctional``
    instead of a plain ``+``, as eager-mode quantization requires.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        if self.use_res_connect:
            return self.skip_add.add(x, self.conv(x))
        # No residual connection: only the convolution branch contributes.
        return self.conv(x)

    def fuse_model(self, is_qat: bool | None = None) -> None:
        """Fuse each ``Conv2d`` in ``self.conv`` with the module at the next index.

        The module that follows each conv is presumably its BatchNorm.
        TODO confirm against the parent ``InvertedResidual`` layout.

        Args:
            is_qat: Forwarded to ``_fuse_modules``; ``None`` lets it decide
                from ``self.conv.training``.
        """
        # Index-based access is deliberate: fusion replaces entries of
        # ``self.conv`` in place, and indexing always reads the current module.
        for idx in range(len(self.conv)):
            if type(self.conv[idx]) is nn.Conv2d:
                _fuse_modules(self.conv, [str(idx), str(idx + 1)], is_qat, inplace=True)

class QuantizableMobileNetV2(MobileNetV2):
    """MobileNetV2 prepared for eager-mode quantization.

    The float network is wrapped between a ``QuantStub`` and a ``DeQuantStub``.
    ``fuse_model`` fuses its conv/bn/relu sequences.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """MobileNet V2 main class.

        Args:
            *args, **kwargs: Same arguments as the float ``MobileNetV2``.
        """
        super().__init__(*args, **kwargs)
        # Boundaries where float inputs enter and outputs leave the quantized region.
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x: Tensor) -> Tensor:
        x = self.quant(x)
        x = self._forward_impl(x)
        x = self.dequant(x)
        return x

    def fuse_model(self, is_qat: bool | None = None) -> None:
        """Fuse all conv/bn/relu sequences of the network in place.

        ``is_qat`` defaults to ``None`` so that callers may invoke
        ``fuse_model()`` without arguments.

        Args:
            is_qat: ``True`` for QAT fusion, ``False`` for post-training
                fusion, ``None`` to decide from each submodule's ``training`` flag.
        """
        for m in self.modules():
            # NOTE(review): exact type match; newer torchvision versions use
            # Conv2dNormActivation instead — confirm against the installed version.
            if type(m) is ConvNormActivation:
                _fuse_modules(m, ["0", "1", "2"], is_qat, inplace=True)
            if type(m) is QuantizableInvertedResidual:
                m.fuse_model(is_qat)

def mobilenet_v2(
    pretrained: bool = False,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableMobileNetV2:
    """Construct the MobileNetV2 architecture from
    `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.

    Note that ``quantize=True`` returns a quantized model with 8-bit weights.
    Quantized models only support inference and run on CPUs; GPU inference
    is not yet supported.

    Args:
        pretrained (bool): If True, load weights pre-trained on ImageNet.
        progress (bool): If True, display a download progress bar on stderr.
        quantize (bool): If True, return a quantized model, otherwise a float model.
        **kwargs: Forwarded to ``QuantizableMobileNetV2``.

    Returns:
        The constructed (and optionally quantized / pretrained) model.
    """
    model = QuantizableMobileNetV2(block=QuantizableInvertedResidual, **kwargs)
    # Replace ReLU6 with ReLU so the conv/bn/relu fusion patterns can match.
    _replace_relu(model)

    if quantize:
        # TODO use pretrained as a string to specify the backend
        backend = "qnnpack"
        quantize_model(model, backend)
    else:
        assert pretrained in [True, False]

    if pretrained:
        if quantize:
            model_url = quant_model_urls["mobilenet_v2_" + backend]
        else:
            model_url = model_urls["mobilenet_v2"]

        state_dict = load_state_dict_from_url(model_url, progress=progress)
        # Without this call the downloaded weights would be discarded.
        model.load_state_dict(state_dict)
    return model