解读 DFPatternCallback
#
class DFPatternCallback;
/*!
* \brief Base type of all dataflow pattern callbacks.
* \sa DFPatternCallback
*/
class DFPatternCallbackNode : public Object {
public:
/*! \brief Pattern this callback matches */
DFPattern pattern;
/*! \brief Function to call when finding a matched expression */
PackedFunc function;
/*! \brief Require InferType to be run before the callback */
bool require_type;
/*! \brief Run the callback only once */
bool rewrite_once;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("pattern", &pattern);
v->Visit("require_type", &require_type);
v->Visit("rewrite_once", &rewrite_once);
}
static constexpr const char* _type_key = "DFPatternCallbackNode";
TVM_DECLARE_BASE_OBJECT_INFO(DFPatternCallbackNode, Object);
};
DFPatternCallbackNode
类,它是所有数据流模式回调的基础类型。这个类继承自 Object
类。它包含以下成员变量:
DFPattern pattern
:匹配的模式。PackedFunc function
:找到匹配表达式时要调用的函数。 3.bool require_type
:在回调之前是否需要运行InferType
。bool rewrite_once
:是否只运行一次回调。
此外,这个类还包含名为 VisitAttrs
的成员函数,用于访问这些属性。最后,它还定义了静态常量字符串 _type_key
,用于表示这个类的类型,以及 TVM_DECLARE_BASE_OBJECT_INFO
宏,用于声明这个类的信息。
from testing import viz_expr # 可视化 relay
import tvm
tvm.relay.dataflow_pattern.DFPatternCallback??
Init signature:
tvm.relay.dataflow_pattern.DFPatternCallback(
require_type=False,
rewrite_once=False,
)
Source:
class DFPatternCallback:
"""A Callback for Pattern Rewriting.
When rewrite is called on this DFPatternCallback, the backend will find matches for the
pattern, call the callback function, and replace the matched expression with whatever
the callback returns.
Users are expect to inherit from this class and provide a "self.pattern" to match
Parameters
----------
require_type: bool
Whether InferType is required to be run before the callback.
rewrite_once: bool
If True, run the callback only once.
"""
def __init__(self, require_type=False, rewrite_once=False):
self.pattern = None
self.require_type = require_type
self.rewrite_once = rewrite_once
def rewrite(self, expr: Expr) -> Expr:
"""
Rewrite expression with this callback
Parameters
----------
expr : tvm.relay.Expr
The expression to rewrite.
Returns
-------
result : tvm.relay.Expr
The Expression with matched subgraphs rewritten by the callbacks.
"""
return rewrite(self, expr)
def callback(self, pre: Expr, post: Expr, node_map: tvm.ir.container.Map) -> Expr:
"""
Callback function to use when we found a match to the pattern
Parameters
----------
pre : tvm.relay.Expr
The matching expression from the original graph.
post : tvm.relay.Expr
The matching expression with rewritten inputs
node_map : tvm.ir.container.Map[DFPattern, List[Expr]]
The map between patterns and matched expressions
Returns
-------
result : tvm.relay.Expr
The Expression with matched subgraph rewritten by the callback
"""
raise NotImplementedError()
File: /media/pc/data/lxw/ai/tvm/python/tvm/relay/dataflow_pattern/__init__.py
Type: type
Subclasses: LayerNormRewrite, DenseReshapeBiasGeluRewrite, ResNetV1Rewrite, LegalizeQnnOpForDnnl, MulticlassNMSRewrite, PostNMSTopKRewrite, ScatterRewrite, qdistilbert_rewrite, remove_empty_pad_callback, simplify_qnn_concat_in_func, ...
当在 DFPatternCallback
上调用 rewrite
时,后端将找到与模式匹配的部分,调用回调函数,并将匹配的表达式替换为回调返回的内容。
用户需要继承这个类并提供 "self.pattern"
来匹配。
参数:
require_type
:bool
类型,表示是否需要在回调之前运行InferType
。rewrite_once
:bool
类型,如果为True
,则只运行一次回调。
callback()
函数,用于在找到模式匹配时使用。
参数:
pre
:tvm.relay.Expr
类型,表示原始图中的匹配表达式。post
:tvm.relay.Expr
类型,表示重写输入后的匹配表达式。node_map
:tvm.ir.container.Map[DFPattern, List[Expr]]
类型,表示模式和匹配表达式之间的映射关系。
返回值:
result
:tvm.relay.Expr
类型,表示通过回调重写匹配子图的表达式。