tvm::relay::transformForwardRewrite

tvm::relay::transformForwardRewrite#

源码:tvm/include/tvm/relay/transform.h

/*!
 * \brief Apply rewrite rules to rewrite the expr in post DFS order. This
 * function is used as a helper function to rewrtie an expression in a pass.
 *
 * \param expr The expression.
 * \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite
 *                              rule function.
 * \param fcontext Additional callback to provide context argument for each call node.
 * \param fmulti_ref_trigger Transformation function to be called when
 *                           an Expr consumed by multiple callers.
 * \return The rewritten expression.
 */
TVM_DLL Expr ForwardRewrite(const Expr& expr, const String& rewrite_map_attr_name,
                            std::function<ObjectRef(const Call&)> fcontext = nullptr,
                            std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);

ForwardRewrite 函数,它接受四个参数:

  1. expr:表达式节点。

  2. rewrite_map_attr_name:字符串,表示与重写规则函数对应的 Op 的属性名。

  3. fcontext:回调函数,用于为每个调用节点提供上下文参数。默认值为 nullptr

  4. fmulti_ref_trigger:变换函数,当表达式被多个调用者使用时调用。默认值为 nullptr

函数的返回值:重写后的表达式。

函数的主要功能是在后序 DFS 顺序下应用重写规则来重写表达式。这个函数用作辅助函数,在 pass 中重写表达式。

/*!
 * \brief Apply rewrite rules to rewrite the expr in post DFS order. This
 * function is used as a helper function to rewrtie an expression in a pass.
 *
 * \param expr The expression.
 * \param rewrite_func The rewrite func that will apply to all operators.
 * \param fcontext Additional callback to provide context argument for each call node.
 * \param fmulti_ref_trigger Transformation function to be called when
 *                           an Expr consumed by multiple callers.
 *
 * \return The rewritten expression.
 */
TVM_DLL Expr ForwardRewrite(const Expr& expr, const FForwardRewrite& rewrite_func,
                            std::function<ObjectRef(const Call&)> fcontext = nullptr,
                            std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);

ForwardRewrite 的另一个函数,用于在后序 DFS 顺序下应用重写规则来重写表达式。

接受四个参数:

  1. expr:表达式。

  2. rewrite_func:重写函数,将应用于所有算子。

  3. fcontext:回调函数,用于为每个调用节点提供上下文参数。默认值为 nullptr

  4. fmulti_ref_trigger:变换函数,当表达式被多个调用者使用时调用。默认值为 nullptr

两个函数的返回值都是重写后的表达式。

这两个ForwardRewrite函数的主要区别在于它们使用的重写规则的形式。具体分析如下:

  1. 第一个 ForwardRewrite 函数

    • 参数差异:这个函数接受一个字符串参数 rewrite_map_attr_name,它表示与重写规则函数对应的 Op 的属性名。

    • 使用场景:当重写规则与特定的算子属性相关联时,这种形式很有用,例如,根据某些属性值选择不同的重写规则。

  2. 第二个 ForwardRewrite 函数

    • 参数差异:这个函数接受 FForwardRewrite 类型的参数 rewrite_func,这是一个函数对象,将应用于所有算子。

    • 使用场景:当有一系列通用的重写规则需要应用于所有算子时,这种形式更为合适。这提供了一种更灵活的方法来定义和应用程序广泛的重写逻辑。

  3. 功能对比

    • 灵活性对比:使用 FForwardRewrite 对象的版本提供了更大的灵活性,因为它允许传递一个可以处理各种算子的函数对象。

    • 适用性对比:如果重写逻辑依赖于特定算子的属性,则使用属性名的版本可能更直接和方便。

  4. 以下情况使用第一个函数

    • 当重写逻辑与特定算子的属性紧密相关,且这些属性定义了如何进行重写时。

    • 当需要在运行时根据属性值选择不同的重写策略时。

  5. 以下情况使用第二个函数

    • 当有一组通用的重写规则需要统一应用于所有算子时。

    • 当重写逻辑不依赖于特定算子的属性,而是依赖于算子本身或其他上下文信息时。

总的来说,在选择使用哪个 ForwardRewrite 函数时,应考虑重写逻辑的具体需求和应用场景。如果重写规则高度依赖于特定操作的属性,第一个函数可能更合适;如果需要更通用的重写逻辑,第二个函数提供了更大的灵活性。