解读 FoldScaleAxis#
/*!
* \brief Backward fold axis scaling into weights of conv/dense operators.
*
* \return The pass.
*/
TVM_DLL Pass BackwardFoldScaleAxis();
/*!
* \brief Forward fold axis scaling into weights of conv/dense operators.
*
* \return The pass.
*/
TVM_DLL Pass ForwardFoldScaleAxis();
/*!
* \brief A sequential pass that executes ForwardFoldScaleAxis and
* BackwardFoldScaleAxis passes.
*
* \return The pass.
*/
TVM_DLL Pass FoldScaleAxis();
这三个函数是用于在卷积/dense运算中折叠轴缩放的传递。它们分别是:
BackwardFoldScaleAxis()
:后向折叠轴缩放。这个函数的目的是在反向传播过程中,将轴缩放运算折叠到卷积/dense运算的权重中。这样可以减少计算量,提高性能。ForwardFoldScaleAxis()
:前向折叠轴缩放。这个函数的目的是在前向传播过程中,将轴缩放运算折叠到卷积/dense运算的权重中。这样可以在不改变模型输出的情况下,减少计算量,提高性能。FoldScaleAxis()
:这是一个顺序传递,它会依次执行ForwardFoldScaleAxis()
和BackwardFoldScaleAxis()
两个传递。这样可以同时优化前向和反向传播过程。
import testing
from tvm.relay.transform import FoldScaleAxis
FoldScaleAxis??
Signature: FoldScaleAxis()
Source:
def FoldScaleAxis():
"""Fold the scaling of axis into weights of conv2d/dense. This pass will
invoke both forward and backward scale folding.
Returns
-------
ret : tvm.transform.Pass
The registered pass to fold expressions.
Note
----
Internally, we will call backward_fold_scale_axis before using
forward_fold_scale_axis as backward folding targets the common conv->bn
pattern.
"""
return _ffi_api.FoldScaleAxis()
File: /media/pc/data/lxw/ai/tvm/python/tvm/relay/transform/transform.py
Type: function
函数的内部注释说明了其工作原理:在内部,先调用 backward_fold_scale_axis
,然后再使用 forward_fold_scale_axis
。这是因为后向折叠针对的是常见的卷积->批量标准化(Conv->BN)模式。