Quantized Tensors#

Reference: Quantized Tensor

Creating Quantized Tensors#

  • Get a quantized tensor by quantizing an unquantized float tensor

import torch

float_tensor = torch.randn(2, 2, 3)

scale, zero_point = 1e-4, 2
dtype = torch.qint32
q_per_tensor = torch.quantize_per_tensor(float_tensor, scale, zero_point, dtype)
q_per_tensor
tensor([[[ 0.1231, -1.9974,  0.2806],
         [ 0.6392, -0.2118,  0.6879]],

        [[-0.1315, -0.4067,  0.5414],
         [-0.4595, -1.7321, -0.5273]]], size=(2, 2, 3), dtype=torch.qint32,
       quantization_scheme=torch.per_tensor_affine, scale=0.0001, zero_point=2)
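Here scale and zero_point define the affine mapping q_int = round(x / scale) + zero_point (clamped to the range of the chosen dtype), and dequantize() inverts it. A quick round-trip check (a sketch; the exact value depends on the random input):

# with qint32 nothing clamps here, so the per-element error is bounded by about scale / 2
(q_per_tensor.dequantize() - float_tensor).abs().max()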

Per-channel quantization is also supported:

scales = torch.tensor([1e-1, 1e-2, 1e-3])
zero_points = torch.tensor([-1, 0, 1])
channel_axis = 2
q_per_channel = torch.quantize_per_channel(float_tensor,
                                           scales,
                                           zero_points,
                                           axis=channel_axis,
                                           dtype=dtype)
q_per_channel
tensor([[[ 0.1000, -2.0000,  0.2810],
         [ 0.6000, -0.2100,  0.6880]],

        [[-0.1000, -0.4100,  0.5410],
         [-0.5000, -1.7300, -0.5270]]], size=(2, 2, 3), dtype=torch.qint32,
       quantization_scheme=torch.per_channel_affine,
       scale=tensor([0.1000, 0.0100, 0.0010], dtype=torch.float64),
       zero_point=tensor([-1,  0,  1]), axis=2)
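The per-channel parameters can be read back from the tensor itself (a sketch using the per-channel accessor methods):

# each slice along `axis` has its own scale and zero_point
q_per_channel.q_per_channel_scales()
q_per_channel.q_per_channel_zero_points()
q_per_channel.q_per_channel_axis()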
  • Create a quantized tensor directly from the empty_quantized function

Note that _empty_affine_quantized is a private API; we will replace it with a torch-level equivalent, something like empty_quantized_tensor(sizes, quantizer), in the future.

q = torch._empty_affine_quantized([10],
                                  scale=scale,
                                  zero_point=zero_point,
                                  dtype=dtype)
q
tensor([-0.0002, -0.0002, -0.0002, -0.0002,  0.0062, -0.0002,  0.0078, -0.0002,
        -0.0002, -0.0002], size=(10,), dtype=torch.qint32,
       quantization_scheme=torch.per_tensor_affine, scale=0.0001, zero_point=2)
  • Create a quantized tensor by combining an int tensor with quantization parameters

Note

Note that _per_tensor_affine_qtensor is a private API; we will replace it with something torch-like such as torch.from_tensor(int_tensor, quantizer).

int_tensor = torch.randint(0, 100, size=(10,), dtype=torch.uint8)

The data type here is torch.quint8, the quantized counterpart of torch.uint8. The torch int types map to torch quantized int types as follows:

  • torch.uint8 -> torch.quint8

  • torch.int8 -> torch.qint8

  • torch.int32 -> torch.qint32

q = torch._make_per_tensor_quantized_tensor(int_tensor, scale, zero_point)  # Note no `dtype`
q 
tensor([ 6.4000e-03,  9.3000e-03,  3.7000e-03,  2.3000e-03, -1.0000e-04,
         6.9000e-03,  9.2000e-03,  4.1000e-03,  1.1000e-03,  4.6000e-03],
       size=(10,), dtype=torch.quint8,
       quantization_scheme=torch.per_tensor_affine, scale=0.0001, zero_point=2)
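Following the mapping above, the quantized dtype is inferred from the dtype of the int tensor, which is why no dtype argument is passed. For example, an int8 source should give a qint8 quantized tensor (a sketch; int8_tensor and q_from_int8 are illustrative names):

int8_tensor = torch.randint(-50, 50, size=(10,), dtype=torch.int8)
q_from_int8 = torch._make_per_tensor_quantized_tensor(int8_tensor, scale, zero_point)
q_from_int8.dtype  # expected: torch.qint8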

In the current API we have to dedicate a separate function to each quantization scheme; for example, to quantize a tensor we have quantize_per_tensor and quantize_per_channel, and similarly for q_scale and q_zero_point. Instead, we should have a single quantize function that takes a Quantizer as an argument. To inspect the quantization parameters, the quantized tensor should return a Quantizer object, so that we can inspect the parameters on the Quantizer object rather than putting everything into the tensor API. The current infrastructure is not yet ready for this support, and it is under development.

Operations on Quantized Tensors#

Dequantization

dequantized_tensor = q.dequantize()
dequantized_tensor
tensor([ 6.4000e-03,  9.3000e-03,  3.7000e-03,  2.3000e-03, -1.0000e-04,
         6.9000e-03,  9.2000e-03,  4.1000e-03,  1.1000e-03,  4.6000e-03])

Slicing

Quantized tensors support slicing like usual tensors:

s = q[2]
s
tensor(0.0037, size=(), dtype=torch.quint8,
       quantization_scheme=torch.per_tensor_affine, scale=0.0001, zero_point=2)

Note

The slice is a quantized tensor with the same scale and zero_point; it contains the same values as the corresponding slice of the original quantized tensor, e.g. q_made_per_tensor[2, :].
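To confirm, the quantization parameters of the slice can be checked directly (a sketch; q_scale and q_zero_point are covered in more detail below):

# the slice shares q's quantization parameters
s.q_scale(), s.q_zero_point()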

Assignment

q[0] = 3.5  # quantize 3.5 and store the int value in the quantized tensor

Copy

We can copy from a quantized tensor into another one of the same size and dtype but with a different scale and zero_point:

scale1, zero_point1 = 1e-1, 0
scale2, zero_point2 = 1, -1
q1 = torch._empty_affine_quantized([2, 3],
                                   scale=scale1,
                                   zero_point=zero_point1,
                                   dtype=torch.qint8)
q2 = torch._empty_affine_quantized([2, 3],
                                   scale=scale2,
                                   zero_point=zero_point2,
                                   dtype=torch.qint8)
q2.copy_(q1)
tensor([[-1.6000,  4.0000,  7.0000],
        [ 3.3000,  6.6000,  8.6000]], size=(2, 3), dtype=torch.qint8,
       quantization_scheme=torch.per_tensor_affine, scale=0.1, zero_point=0)

Permutation

q1.transpose(0, 1)  # see https://pytorch.org/docs/stable/torch.html#torch.transpose
q1.permute([1, 0])  # https://pytorch.org/docs/stable/tensors.html#torch.Tensor.permute
q1.contiguous()  # Convert to contiguous Tensor
tensor([[-1.6000,  4.0000,  7.0000],
        [ 3.3000,  6.6000,  8.6000]], size=(2, 3), dtype=torch.qint8,
       quantization_scheme=torch.per_tensor_affine, scale=0.1, zero_point=0)

Serialization and Deserialization

import tempfile
with tempfile.NamedTemporaryFile() as f:
    torch.save(q2, f)
    f.seek(0)
    q3 = torch.load(f)
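A quick check that loading round-trips the quantized representation (a sketch):

# the quantization parameters and the stored integers should survive save/load
q3.q_scale() == q2.q_scale(), q3.q_zero_point() == q2.q_zero_point()
torch.equal(q3.int_repr(), q2.int_repr())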

Inspecting Quantized Tensors#

# Check size of Tensor
q.numel(), q.size()
(10, torch.Size([10]))
# Check whether the tensor is quantized
q.is_quantized
True
# Get the scale of the quantized Tensor, only works for affine quantized tensor
q.q_scale()
0.0001
# Get the zero_point of quantized Tensor
q.q_zero_point()
2
# get the underlying integer representation of the quantized Tensor
# int_repr() returns a Tensor of the corresponding data type of the quantized data type
# e.g. for a quint8 Tensor it returns a uint8 Tensor while preserving the MemoryFormat when possible
q.int_repr()
tensor([66, 95, 39, 25,  1, 71, 94, 43, 13, 48], dtype=torch.uint8)
# If a quantized Tensor is a scalar we can print the value:
# item() will dequantize the current tensor and return a Scalar of float
q[0].item()
0
# printing
print(q)
tensor([ 6.4000e-03,  9.3000e-03,  3.7000e-03,  2.3000e-03, -1.0000e-04,
         6.9000e-03,  9.2000e-03,  4.1000e-03,  1.1000e-03,  4.6000e-03],
       size=(10,), dtype=torch.quint8,
       quantization_scheme=torch.per_tensor_affine, scale=0.0001, zero_point=2)
# indexing
print(q[0]) # q[0] is a quantized Tensor with one value
tensor(0.0064, size=(), dtype=torch.quint8,
       quantization_scheme=torch.per_tensor_affine, scale=0.0001, zero_point=2)

Quantized Operators/Kernels#

We are also working on quantized operators such as QRelu, QAdd, QCat, QLinear, and QConv. We either provide a straightforward implementation of an operator or wrap the fbgemm implementation inside it. All operators are registered in C10, and they currently run only on CPU. We also have instructions on how to write quantized operators/kernels.
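For example, the quantized add kernel takes two quantized tensors together with the output scale and zero_point (a sketch; the op is called through the quantized namespace registered in C10, and the exact signature may differ between versions):

qa = torch.quantize_per_tensor(torch.randn(2, 2), 0.1, 0, torch.quint8)
qb = torch.quantize_per_tensor(torch.randn(2, 2), 0.1, 0, torch.quint8)
# the output quantization parameters must be supplied explicitly
qc = torch.ops.quantized.add(qa, qb, 0.2, 0)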

Quantized Models#

We also have quantized modules that wrap these kernel implementations. They live in the torch.nn.quantized namespace and will be used in model development. We will provide utility functions that replace torch.nn.Module with torch.nn.quantized.Module, but users are also free to use them directly. We try to match the API of the quantized modules with the corresponding APIs in torch.nn.Module.

torch.nn.qat
<module 'torch.nn.qat' from '/home/pc/xinet/anaconda3/envs/torchx/lib/python3.10/site-packages/torch/nn/qat/__init__.py'>
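As a small illustration of the module-level API, the Quantize and DeQuantize modules in torch.nn.quantized wrap the corresponding kernels and can be used like ordinary torch.nn modules (a minimal sketch, assuming the constructor arguments shown here):

import torch.nn.quantized as nnq

quant = nnq.Quantize(scale=1e-4, zero_point=2, dtype=torch.quint8)
dequant = nnq.DeQuantize()
x = torch.randn(2, 3)
qx = quant(x)        # quantized tensor (torch.quint8)
x_hat = dequant(qx)  # float tensor, equal to x up to quantization error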