Understanding FoldScaleAxis

/*!
 * \brief Backward fold axis scaling into weights of conv/dense operators.
 *
 * \return The pass.
 */
TVM_DLL Pass BackwardFoldScaleAxis();

/*!
 * \brief Forward fold axis scaling into weights of conv/dense operators.
 *
 * \return The pass.
 */
TVM_DLL Pass ForwardFoldScaleAxis();

/*!
 * \brief A sequential pass that executes ForwardFoldScaleAxis and
 * BackwardFoldScaleAxis passes.
 *
 * \return The pass.
 */
TVM_DLL Pass FoldScaleAxis();

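These three functions are passes that fold axis scaling into the weights of conv/dense operators:

  1. BackwardFoldScaleAxis(): backward folding. "Backward" refers to the direction in the dataflow graph, not to backpropagation: a per-channel scale applied to the output of a conv/dense operator is folded back into that operator's weights. The separate multiply disappears, which reduces computation without changing the model's output.

  2. ForwardFoldScaleAxis(): forward folding. A per-channel scale applied to the input of a conv/dense operator is folded forward into that operator's weights, again reducing computation without changing the model's output.

  3. FoldScaleAxis(): a sequential pass that runs both BackwardFoldScaleAxis() and ForwardFoldScaleAxis(), so scales on either side of an operator are folded in a single call.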
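The arithmetic behind both directions can be checked by hand. The following is a minimal pure-Python sketch, not TVM code: the `dense` helper, the weight layout (out_features, in_features), and all numeric values are assumptions chosen for illustration.

```python
def dense(x, w):
    """y[j] = sum_i x[i] * w[j][i]; weights laid out as (out_features, in_features)."""
    return [sum(xi * wji for xi, wji in zip(x, row)) for row in w]


x = [1.0, 2.0, 3.0]
w = [[1.0, 0.0, 2.0],
     [0.5, 1.0, -1.0]]

# Backward folding: the scale sits AFTER the operator, one factor per output
# channel. Multiplying row j of the weights by out_scale[j] gives the same result.
out_scale = [2.0, 4.0]
scaled_after = [y * s for y, s in zip(dense(x, w), out_scale)]
w_backward = [[wji * s for wji in row] for row, s in zip(w, out_scale)]
folded_backward = dense(x, w_backward)

# Forward folding: the scale sits BEFORE the operator, one factor per input
# channel. Multiplying column i of the weights by in_scale[i] gives the same result.
in_scale = [2.0, 0.5, 4.0]
scaled_before = dense([xi * s for xi, s in zip(x, in_scale)], w)
w_forward = [[wji * s for wji, s in zip(row, in_scale)] for row in w]
folded_forward = dense(x, w_forward)

print(scaled_after, folded_backward)   # the two lists should match
print(scaled_before, folded_forward)   # the two lists should match
```

Because the weights are constants, the rescaled weights can be computed once at compile time, which is why the folded graph does strictly less work at run time.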

import testing
from tvm.relay.transform import FoldScaleAxis
FoldScaleAxis??
Signature: FoldScaleAxis()
Source:   
def FoldScaleAxis():
    """Fold the scaling of axis into weights of conv2d/dense. This pass will
    invoke both forward and backward scale folding.

    Returns
    -------
    ret : tvm.transform.Pass
        The registered pass to fold expressions.

    Note
    ----
    Internally, we will call backward_fold_scale_axis before using
    forward_fold_scale_axis as backward folding targets the common conv->bn
    pattern.
    """
    return _ffi_api.FoldScaleAxis()
File:      /media/pc/data/lxw/ai/tvm/python/tvm/relay/transform/transform.py
Type:      function

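The Note in the function's docstring explains how it works: internally, backward_fold_scale_axis is called before forward_fold_scale_axis, because backward folding targets the common conv->bn pattern (a convolution followed by batch normalization).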
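To see why conv->bn is a backward-folding target, note that batch normalization at inference time is a per-channel scale plus a shift applied to the operator's output. The sketch below is pure Python rather than TVM code, and rests on assumptions: `dense` stands in for a convolution, and the helper names and all numeric values are made up for illustration. In TVM the rewrite of batch_norm into scale-and-shift form is done by a separate pass (SimplifyInference) before scale folding can apply.

```python
import math


def dense(x, w):
    """Stand-in for conv: y[j] = sum_i x[i] * w[j][i]."""
    return [sum(xi * wji for xi, wji in zip(x, row)) for row in w]


def batch_norm_inference(z, gamma, beta, mean, var, eps):
    """Per-channel batch normalization using fixed (inference-time) statistics."""
    return [g * (zi - m) / math.sqrt(v + eps) + b
            for zi, g, b, m, v in zip(z, gamma, beta, mean, var)]


x = [1.0, 2.0, 3.0]
w = [[1.0, 0.0, 2.0],
     [0.5, 1.0, -1.0]]
gamma, beta = [3.0, 0.5], [1.0, -1.0]
mean, var, eps = [3.0, 0.5], [4.0, 0.0625], 1e-5

# Reference: operator followed by batch normalization.
reference = batch_norm_inference(dense(x, w), gamma, beta, mean, var, eps)

# Rewrite BN as z * scale + shift, then fold `scale` backward into the weights.
scale = [g / math.sqrt(v + eps) for g, v in zip(gamma, var)]
shift = [b - m * s for b, m, s in zip(beta, mean, scale)]
w_folded = [[wji * s for wji in row] for row, s in zip(w, scale)]
folded = [y + t for y, t in zip(dense(x, w_folded), shift)]

max_abs_diff = max(abs(a - b) for a, b in zip(reference, folded))
print(max_abs_diff)  # should be at floating-point rounding level
```

Only the multiplicative part is absorbed by the weights; the shift remains as a per-channel add after the operator.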