量化神经网络

量化神经网络#

QNN(Quantized Neural Network)算子模块提供了模拟量化和反量化算子的函数,这些函数可以在保持原始数据类型的情况下模拟量化神经网络的行为,支持动态数据类型选择和通道级别的量化参数设置。

import tvm
from tvm import te, tir, topi
# 量化状态常量定义
SQNN_DISABLE = 0  # 禁用量化,直接通过原始值
SQNN_INT8 = 1     # 模拟int8量化
SQNN_UINT8 = 2    # 模拟uint8量化
SQNN_INT32 = 3    # 模拟int32量化

# 数据类型字符串到枚举值的映射
SQNN_DTYPE_TO_CODE = {
    "disable": SQNN_DISABLE,
    "int8": SQNN_INT8,
    "uint8": SQNN_UINT8,
    "int32": SQNN_INT32,
}

# 枚举值到数据类型字符串的反向映射
SQNN_CODE_TO_DTYPE = {v: k for k, v in SQNN_DTYPE_TO_CODE.items()}
@tvm.te.tag_scope(tag=topi.tag.ELEMWISE)  # 标记为逐元素操作,便于优化
def simulated_quantize(data, out_dtype, output_scale=None, output_zero_point=None, axis=-1):
    """
    模拟QNN量化操作,在不改变数据类型的情况下模拟量化输出
    
    此操作相比真实QNN量化的优势在于:
    1. 支持动态数据类型选择
    2. 可以处理通道级和标量级的scale和zero_point参数
    3. 而真实QNN量化要求这些参数在编译时固定
    
    量化数学公式:
    Q_output = clip(round(input_tensor/output_scale) + output_zero_point, out_dtype::min, out_dtype::max)

    参数
    ----------
    data: tvm.te.Tensor
        输入张量,可以是任意维度

    out_dtype: tvm.te.Tensor
        一个标量变量,指示要模拟的量化数据类型。使用SQNN_DTYPE_TO_CODE将字符串转换为对应的值

    output_scale: tvm.te.Tensor, optional
        用于量化到整数类型的缩放因子张量。
        当包含多个值时,N必须与数据中的通道数匹配

    output_zero_point: tvm.te.Tensor, optional
        用于量化到整数类型的零点张量。
        当包含多个值时,N必须与数据中的通道数匹配

    axis: int, optional
        量化的通道轴。默认值为-1,表示最后一个轴

    返回
    -------
    tvm.te.Tensor
        形状与输入相同的张量,包含模拟量化后的值
    """

    # 当禁用量化时,直接返回输入值
    def _compute_pass_through(value, *indices):
        return value[indices]

    # 模拟任意整数类型的量化计算
    # 计算式:Q_output = clip(round(input_tensor/output_scale) + output_zero_point, min, max)
    def _compute_intn(dtype, value, *indices):
        assert output_scale is not None and output_zero_point is not None
        const_min = tvm.tir.min_value(dtype)  # 获取目标数据类型的最小值
        const_max = tvm.tir.max_value(dtype)  # 获取目标数据类型的最大值
        
        # 使用indexmod处理标量和通道级的QNN参数
        # 当参数是标量时,indexmod确保始终访问第一个元素
        scale_idx = tir.indexmod(indices[axis], topi.shape(output_scale)[0])
        zp_idx = tir.indexmod(indices[axis], topi.shape(output_zero_point)[0])
        
        # 执行量化计算:取整 -> 加上零点 -> 裁剪到目标类型范围内
        return te.max(
            te.min(
                te.round(value[indices] / output_scale[scale_idx]) + output_zero_point[zp_idx],
                const_max,
            ),
            const_min,
        )

    # 使用条件链动态选择适当的量化方法
    # 这种设计允许一次编译但使用不同的量化方法,只需要通过变量数据类型输入控制
    def _dispatch_sim_quantize(value):
        # 创建直通计算(禁用量化时使用)
        pass_through_value = te.compute(
            data.shape, lambda *indices: _compute_pass_through(value, *indices)
        )
        
        # 处理int8量化
        int8_value = te.compute(
            data.shape,
            lambda *indices: tir.if_then_else(
                out_dtype.equal(SQNN_DTYPE_TO_CODE["int8"]),
                _compute_intn("int8", value, *indices),
                pass_through_value[indices],
            ),
        )
        
        # 处理uint8量化
        uint8_value = te.compute(
            data.shape,
            lambda *indices: tir.if_then_else(
                out_dtype.equal(SQNN_DTYPE_TO_CODE["uint8"]),
                _compute_intn("uint8", value, *indices),
                int8_value[indices],
            ),
        )
        
        # 处理int32量化
        int32_value = te.compute(
            data.shape,
            lambda *indices: tir.if_then_else(
                out_dtype.equal(SQNN_DTYPE_TO_CODE["int32"]),
                _compute_intn("int32", value, *indices),
                uint8_value[indices],
            ),
        )

        return int32_value

    # 构建最终的计算结果
    return te.compute(data.shape, lambda *indices: _dispatch_sim_quantize(data)[indices])

@tvm.te.tag_scope(tag=topi.tag.ELEMWISE)  # 标记为逐元素操作,便于优化
def simulated_dequantize(data, in_dtype, input_scale=None, input_zero_point=None, axis=-1):
    """
    模拟QNN反量化操作,在不改变数据类型的情况下模拟反量化输出
    
    此操作相比真实QNN反量化的优势在于:
    1. 支持动态数据类型选择
    2. 可以处理通道级和标量级的scale和zero_point参数
    3. 而真实QNN反量化要求这些参数在编译时固定
    
    反量化数学公式:
    DQ_output = (input - zero_point) * scale

    参数
    ----------
    data: tvm.te.Tensor
        输入张量,可以是任意维度

    in_dtype: tvm.te.Tensor
        一个标量变量,指示要模拟的反量化数据类型。使用SQNN_DTYPE_TO_CODE将字符串转换为对应的值

    input_scale: tvm.te.Tensor, optional
        用于从整数类型反量化的缩放因子张量。
        当包含多个值时,N必须与数据中的通道数匹配

    input_zero_point: tvm.te.Tensor, optional
        用于从整数类型反量化的零点张量。
        当包含多个值时,N必须与数据中的通道数匹配

    axis: int, optional
        量化的通道轴。默认值为-1,表示最后一个轴

    返回
    -------
    tvm.te.Tensor
        形状与输入相同的张量,包含模拟反量化后的值
    """

    # 当禁用反量化时,直接返回输入张量
    def _compute_pass_through(value, *indices):
        return value[indices]

    # 模拟任意整数类型的反量化计算
    # 计算式:DQ_output = (input - zero_point) * scale
    def _compute_intn(value, *indices):
        assert input_scale is not None and input_zero_point is not None
        
        # 使用indexmod处理标量和通道级的QNN参数
        scale_idx = tir.indexmod(indices[axis], topi.shape(input_scale)[0])
        zp_idx = tir.indexmod(indices[axis], topi.shape(input_zero_point)[0])
        
        # 执行反量化计算:减去零点 -> 乘以缩放因子
        return (value[indices] - input_zero_point[zp_idx]) * input_scale[scale_idx]

    # 使用条件链动态选择适当的反量化方法
    # 这种设计允许一次编译但使用不同的反量化方法,只需要通过变量数据类型输入控制
    def _dispatch_sim_dequantize(value):
        # 创建直通计算(禁用反量化时使用)
        pass_through_value = te.compute(
            data.shape, lambda *indices: _compute_pass_through(value, *indices)
        )
        
        # 创建一个条件,检查是否为任意整数类型的量化
        intn_condition = tvm.te.any(
            in_dtype.equal(SQNN_DTYPE_TO_CODE["int8"]),
            in_dtype.equal(SQNN_DTYPE_TO_CODE["uint8"]),
            in_dtype.equal(SQNN_DTYPE_TO_CODE["int32"]),
        )
        
        # 根据条件选择是否执行反量化计算
        intn_value = te.compute(
            data.shape,
            lambda *indices: tir.if_then_else(
                intn_condition,
                _compute_intn(value, *indices),
                pass_through_value[indices],
            ),
        )

        return intn_value

    # 构建最终的计算结果
    return te.compute(data.shape, lambda *indices: _dispatch_sim_dequantize(data)[indices])