解读 DFPatternCallback

解读 DFPatternCallback#

class DFPatternCallback;
/*!
 * \brief Base type of all dataflow pattern callbacks.
 * \sa DFPatternCallback
 */
class DFPatternCallbackNode : public Object {
 public:
  /*! \brief Pattern this callback matches */
  DFPattern pattern;
  /*! \brief Function to call when finding a matched expression */
  PackedFunc function;
  /*! \brief Require InferType to be run before the callback */
  bool require_type;
  /*! \brief Run the callback only once */
  bool rewrite_once;

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("pattern", &pattern);
    v->Visit("require_type", &require_type);
    v->Visit("rewrite_once", &rewrite_once);
  }

  static constexpr const char* _type_key = "DFPatternCallbackNode";
  TVM_DECLARE_BASE_OBJECT_INFO(DFPatternCallbackNode, Object);
};

DFPatternCallbackNode 类,它是所有数据流模式回调的基础类型。这个类继承自 Object 类。它包含以下成员变量:

  1. DFPattern pattern:匹配的模式。

  2. PackedFunc function:找到匹配表达式时要调用的函数。 3. bool require_type:在回调之前是否需要运行 InferType

  3. bool rewrite_once:是否只运行一次回调。

此外,这个类还包含名为 VisitAttrs 的成员函数,用于访问这些属性。最后,它还定义了静态常量字符串 _type_key,用于表示这个类的类型,以及 TVM_DECLARE_BASE_OBJECT_INFO 宏,用于声明这个类的信息。

from testing import viz_expr # 可视化 relay
import tvm
tvm.relay.dataflow_pattern.DFPatternCallback??
Init signature:
tvm.relay.dataflow_pattern.DFPatternCallback(
    require_type=False,
    rewrite_once=False,
)
Source:        
class DFPatternCallback:
    """A Callback for Pattern Rewriting.

    When rewrite is called on this DFPatternCallback, the backend will find matches for the
    pattern, call the callback function, and replace the matched expression with whatever
    the callback returns.

    Users are expect to inherit from this class and provide a "self.pattern" to match

    Parameters
    ----------
    require_type: bool
        Whether InferType is required to be run before the callback.
    rewrite_once: bool
        If True, run the callback only once.
    """

    def __init__(self, require_type=False, rewrite_once=False):
        self.pattern = None
        self.require_type = require_type
        self.rewrite_once = rewrite_once

    def rewrite(self, expr: Expr) -> Expr:
        """
        Rewrite expression with this callback

        Parameters
        ----------
        expr : tvm.relay.Expr
            The expression to rewrite.

        Returns
        -------
        result : tvm.relay.Expr
            The Expression with matched subgraphs rewritten by the callbacks.
        """
        return rewrite(self, expr)

    def callback(self, pre: Expr, post: Expr, node_map: tvm.ir.container.Map) -> Expr:
        """
        Callback function to use when we found a match to the pattern

        Parameters
        ----------
        pre : tvm.relay.Expr
            The matching expression from the original graph.
        post : tvm.relay.Expr
            The matching expression with rewritten inputs
        node_map : tvm.ir.container.Map[DFPattern, List[Expr]]
            The map between patterns and matched expressions

        Returns
        -------
        result : tvm.relay.Expr
            The Expression with matched subgraph rewritten by the callback
        """
        raise NotImplementedError()
File:           /media/pc/data/lxw/ai/tvm/python/tvm/relay/dataflow_pattern/__init__.py
Type:           type
Subclasses:     LayerNormRewrite, DenseReshapeBiasGeluRewrite, ResNetV1Rewrite, LegalizeQnnOpForDnnl, MulticlassNMSRewrite, PostNMSTopKRewrite, ScatterRewrite, qdistilbert_rewrite, remove_empty_pad_callback, simplify_qnn_concat_in_func, ...

当在 DFPatternCallback 上调用 rewrite 时,后端将找到与模式匹配的部分,调用回调函数,并将匹配的表达式替换为回调返回的内容。

用户需要继承这个类并提供 "self.pattern" 来匹配。

参数:

  • require_type: bool 类型,表示是否需要在回调之前运行 InferType

  • rewrite_once: bool 类型,如果为 True,则只运行一次回调。

callback() 函数,用于在找到模式匹配时使用。

参数:

  • pre: tvm.relay.Expr 类型,表示原始图中的匹配表达式。

  • post: tvm.relay.Expr 类型,表示重写输入后的匹配表达式。

  • node_map: tvm.ir.container.Map[DFPattern, List[Expr]] 类型,表示模式和匹配表达式之间的映射关系。

返回值:

  • result: tvm.relay.Expr 类型,表示通过回调重写匹配子图的表达式。