Relax 量化/反量化算子模块#

此模块提供了 Relax 框架中的量化(quantize)和反量化(dequantize)算子的 Python 前端接口。这些算子用于神经网络模型量化过程中的张量类型转换,支持通道级别的量化参数设置。量化可以将高精度浮点数模型转换为低精度整数模型,以减少模型大小和加速推理。

qdq 在 TVM 的量化生态中扮演着关键角色,主要负责:

  1. 提供量化转换接口:作为 Python 前端 API,连接用户代码与 TVM 底层量化实现

  2. 支持模型压缩与加速:通过高精度到低精度的转换,实现模型大小减小和推理速度提升

  3. 保持量化精度:实现了标准的量化/反量化数学公式,保证转换过程中的精度损失最小化

# 导入必要的模块和类
import tvm
from tvm import relax, tir
from tvm.ir import Op
from tvm.script import relax as R

主要功能与数学原理#

该模块提供两个核心函数:

1. quantize 函数#

  • 功能:将浮点类型张量(如 float32)转换为整数类型张量(如 int8

  • 量化数学公式:Q_output = clamp(round(input_tensor/scale) + zero_point, out_dtype::min, out_dtype::max) 其中:

    • input_tensor/scale: 将输入数据归一化到零点附近

    • round(): 四舍五入到最近的整数

    • + zero_point: 添加零点偏移,使零点(通常是0)能够精确表示

    • clamp(): 将结果限制在目标数据类型的最小和最大值之间

  • 过程:归一化 → 四舍五入 → 添加零点偏移 → 值域裁剪

  • 典型应用:模型训练后的量化、量化感知训练中的量化运算

备注

该算子接收输入张量,并生成具有相同形状的量化输出。输入张量可以是任意形状。量化过程会将浮点数值映射到整数域,同时保留原始数据的相对关系。

relax.op.quantize?

Hide code cell output

Signature:
relax.op.quantize(
    data: tvm.ir.expr.RelaxExpr,
    scale: tvm.ir.expr.RelaxExpr,
    zero_point: tvm.ir.expr.RelaxExpr,
    axis: int = -1,
    out_dtype: str = 'int8',
)
Docstring:
Quantize op
This operator takes input and produces quantized output. The input tensor can be of any shape.
The output shape is the same as input shape.

Q_output = clamp((round(input_tensor/scale) + zero_point), out_dtype::min, out_dtype::max)

Parameters
----------
data : tvm.relax.Expr
    The input tensor to be quantized.

scale : tvm.relax.Expr
    The output scale.

zero_point : tvm.relax.Expr
    The output zero_point.

axis : int
    The channel axis for quantization. Default value is -1 which corresponds to the last axis.

out_dtype : str, optional
    The data type of the output tensor.

Returns
-------
result : tvm.relax.Expr
    The computed result.
File:      /media/pc/data/lxw/ai/tvm/python/tvm/relax/op/qdq.py
Type:      function

参数描述:

  • data : tvm.relax.Expr 需要被量化的输入张量,通常为浮点类型

  • scale : tvm.relax.Expr 量化缩放因子,用于控制量化的精度。 当为张量时,可以实现通道级(per-channel)量化

  • zero_point : tvm.relax.Expr 量化零点,使得零点在整数域中能够精确表示。 当为张量时,可以实现通道级(per-channel)量化

  • axis : int, 可选 量化的通道轴。默认值为 -1,表示最后一个轴

  • out_dtype : str, 可选 输出张量的数据类型,默认为 "int8"

返回值:

  • result : tvm.relax.Expr 量化后的整数张量,形状与输入相同

2. dequantize 函数#

  • 功能:将整数类型张量转换回浮点类型张量

  • 数学原理output = clamp(scale * (input_tensor - zero_point), out_dtype::min, out_dtype::max) 其中:

    • input_tensor - zero_point: 减去零点偏移,恢复归一化的值

    • * scale: 乘以缩放因子,恢复原始数据范围

    • clamp(): 将结果限制在目标数据类型的最小和最大值之间

    • 减去零点偏移 → 乘以缩放因子 → 值域裁剪

  • 典型应用:量化模型推理过程中的反量化操作、量化模型与浮点模型的结果比较

备注

该算子接收量化后的输入张量,并生成具有相同形状的反量化输出。反量化过程是量化的逆运算,用于在需要时恢复原始浮点数据的近似值。

参数:

  • data : tvm.relax.Expr 需要被反量化的输入张量,通常为整数类型

  • scale : tvm.relax.Expr 量化缩放因子,必须与量化时使用的scale相同。 当为张量时,可以实现通道级(per-channel)反量化

  • zero_point : tvm.relax.Expr 量化零点,必须与量化时使用的zero_point相同。 当为张量时,可以实现通道级(per-channel)反量化

  • axis : int, 可选 反量化的通道轴。默认值为-1,表示最后一个轴

  • out_dtype : str, 可选 输出张量的数据类型,默认为"float32"

返回值#

  • result : tvm.relax.Expr 反量化后的浮点张量,形状与输入相同

使用场景与优势#

该模块主要应用于以下场景:

  1. 模型压缩:将FP32模型转换为INT8/UINT8模型,减少约75%的存储空间

  2. 推理加速:在支持整数运算的硬件(如GPU、NPU)上显著提升推理性能

  3. 低精度推理:在边缘设备等计算资源受限环境中部署深度学习模型

  4. 量化感知训练:在训练过程中模拟量化效果,提高量化后模型的精度

关键参数与使用方法#

共用参数说明#

  • data:输入张量(量化时为浮点型,反量化时为整型)

  • scale:缩放因子(可以是标量或张量,张量时实现通道级量化)

  • zero_point:零点值(可以是标量或张量)

  • axis:通道轴(默认为 -1,表示最后一维)

  • out_dtype:输出数据类型(量化默认int8,反量化默认float32)

输入输出示例#

量化算子示例#

# 输入:float32张量、scale和zero_point参数
input_tensor = relax.const([1.2, 2.3, -0.5], dtype="float32")
scale = relax.const(0.1, dtype="float32")
zero_point = relax.const(128, dtype="int32")

# 执行量化操作
quantized_tensor = relax.op.quantize(input_tensor, scale, zero_point, out_dtype="int8")
# 输出:int8张量,值约为[140, 151, 123]
quantized_tensor.show()
R.quantize(metadata["relax.expr.Constant"][0], R.const(0.10000000149011612, "float32"), R.const(128, "int32"), out_dtype="int8", axis=-1)
# Metadata omitted. Use show_meta=True in script() method to show it.

反量化算子示例#

# 输入:uint8张量、scale和zero_point参数
input_tensor = relax.const([140, 151, 123], dtype="uint8")
scale = relax.const(0.1, dtype="float32")
zero_point = relax.const(128, dtype="int32")

# 执行反量化操作
float_tensor = relax.op.dequantize(input_tensor, scale, zero_point)
# 输出:float32张量
float_tensor.show()
R.dequantize(metadata["relax.expr.Constant"][0], R.const(0.10000000149011612, "float32"), R.const(128, "int32"), out_dtype="float32", axis=-1)
# Metadata omitted. Use show_meta=True in script() method to show it.

测试量化和反量化算子的正确性#

# 创建输入变量:输入张量、缩放因子和零点
x = relax.Var("x", R.Tensor((2, 3), "float32"))
dx = relax.Var("dx", R.Tensor((2, 3), "uint8"))
s = relax.Var("s", R.Tensor([3], "float32"))
zp = relax.Var("zp", R.Tensor([3], "int8"))
# 验证量化算子是否返回正确的算子
assert relax.op.quantize(x, s, zp, 1, "int8").op == Op.get("relax.quantize")
# 验证反量化操作是否返回正确的算子
assert relax.op.dequantize(dx, s, zp, 1, "float32").op == Op.get("relax.dequantize")
def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo):
    # 辅助函数:检查操作的结构信息推断是否正确
    ret = bb.normalize(call)
    tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo)

测试量化和反量化算子的结构信息推断#

bb = relax.BlockBuilder()
# 创建输入变量
x = relax.Var("x", R.Tensor((2, 3), "float32"))
dx = relax.Var("dx", R.Tensor((2, 3), "uint8"))
s = relax.Var("s", R.Tensor([3], "float32"))
zp = relax.Var("zp", R.Tensor([3], "int8"))

# 检查量化操作的结构信息推断
_check_inference(
    bb, relax.op.quantize(x, s, zp, 1, "int8"), relax.TensorStructInfo((2, 3), "int8")
)
# 检查反量化操作的结构信息推断
_check_inference(
    bb,
    relax.op.dequantize(dx, s, zp, 1, "float32"),
    relax.TensorStructInfo((2, 3), "float32"),
)

测试符号形状输入下量化和反量化算子的结构信息推断#

bb = relax.BlockBuilder()
# 创建符号变量表示维度
n = tir.Var("n", "int64")
# 创建输入变量,其中第一个维度是符号变量
x = relax.Var("x", R.Tensor((n, 3), "float32"))
dx = relax.Var("dx", R.Tensor((n, 3), "int8"))
s = relax.Var("s", R.Tensor([3], "float32"))
zp = relax.Var("zp", R.Tensor([3], "int8"))

# 检查符号形状下量化操作的结构信息推断
_check_inference(
    bb, relax.op.quantize(x, s, zp, 1, "int8"), relax.TensorStructInfo((n, 3), "int8")
)
# 检查符号形状下反量化操作的结构信息推断
_check_inference(
    bb,
    relax.op.dequantize(dx, s, zp, 1, "float32"),
    relax.TensorStructInfo((n, 3), "float32"),
)

测试 float8_e4m3fn 数据类型下量化和反量化算子的结构信息推断#

bb = relax.BlockBuilder()
n = tir.Var("n", "int64")
x = relax.Var("x", R.Tensor((n, 3), "float32"))
dx = relax.Var("dx", R.Tensor((n, 3), "float8_e4m3fn"))
s = relax.Var("s", R.Tensor([3], "float32"))
zp = relax.Var("zp", R.Tensor([3], "float16"))

# 检查 float8_e4m3fn 类型的量化操作结构信息推断
_check_inference(
    bb,
    relax.op.quantize(x, s, zp, 1, "float8_e4m3fn"),
    relax.TensorStructInfo((n, 3), "float8_e4m3fn"),
)
# 检查 float8_e4m3fn 类型的反量化操作结构信息推断
_check_inference(
    bb,
    relax.op.dequantize(dx, s, zp, 1, "float32"),
    relax.TensorStructInfo((n, 3), "float32"),
)

测试 float8_e5m2 数据类型下量化和反量化算子的结构信息推断#

dtype = "float8_e5m2"
bb = relax.BlockBuilder()
n = tir.Var("n", "int64")
x = relax.Var("x", R.Tensor((n, 3), "float32"))
dx = relax.Var("dx", R.Tensor((n, 3), dtype))
s = relax.Var("s", R.Tensor([3], "float32"))
zp = relax.Var("zp", R.Tensor([3], "float16"))

# 检查 float8_e5m2 类型的量化操作结构信息推断
_check_inference(
    bb, relax.op.quantize(x, s, zp, 1, dtype), relax.TensorStructInfo((n, 3), dtype)
)
# 检查 float8_e5m2 类型的反量化操作结构信息推断
_check_inference(
    bb,
    relax.op.dequantize(dx, s, zp, 1, "float32"),
    relax.TensorStructInfo((n, 3), "float32"),
)