# Source code for torch.ao.quantization.stubs


from torch import nn

class QuantStub(nn.Module):
    r"""Quantize stub module.

    Before calibration this behaves the same as an observer; it will be
    swapped for `nnq.Quantize` in `convert`.

    Args:
        qconfig: quantization configuration for the tensor. If it is not
            provided, the qconfig is taken from the parent modules.
    """

    def __init__(self, qconfig=None):
        super().__init__()
        # Only set the attribute when a configuration was actually supplied,
        # so that an unset stub has no `qconfig` attribute at all.
        if qconfig:
            self.qconfig = qconfig

    def forward(self, x):
        """Return ``x`` unchanged."""
        return x
class DeQuantStub(nn.Module):
    r"""Dequantize stub module.

    Before calibration this behaves the same as identity; it will be
    swapped for `nnq.DeQuantize` in `convert`.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` unchanged."""
        return x
class QuantWrapper(nn.Module):
    r"""A wrapper class that wraps the input module with quant/dequant stubs.

    Adds a `QuantStub` and a `DeQuantStub` and surrounds the call to the
    wrapped module with calls to the quant and dequant modules.

    This is used by the `quantization` utility functions to add the quant and
    dequant modules. Before `convert`, `QuantStub` is just an observer that
    observes the input tensor; after `convert`, `QuantStub` is swapped for
    `nnq.Quantize`, which does the actual quantization. Similarly for
    `DeQuantStub`.

    Args:
        module: the module to wrap. Its `qconfig` attribute, when present,
            is passed on to the `QuantStub`; its `training` flag is applied
            to the wrapper.
    """
    quant: QuantStub
    dequant: DeQuantStub
    module: nn.Module

    def __init__(self, module):
        super().__init__()
        # A wrapped module without a `qconfig` attribute yields None here.
        qconfig = getattr(module, 'qconfig', None)
        self.add_module('quant', QuantStub(qconfig))
        self.add_module('dequant', DeQuantStub())
        self.add_module('module', module)
        # Mirror the wrapped module's train/eval mode on the wrapper.
        self.train(module.training)

    def forward(self, X):
        """Run ``X`` through quant, the wrapped module, then dequant."""
        X = self.quant(X)
        X = self.module(X)
        return self.dequant(X)