解读 StructuralEqual
#
源码:tvm/node/structural_equal.h
BaseValueEqual
#
/*!
* \brief Equality definition of base value class.
*/
class BaseValueEqual {
public:
bool operator()(const double& lhs, const double& rhs) const {
// fuzzy float pt comparison
constexpr double atol = 1e-9;
if (lhs == rhs) return true;
double diff = lhs - rhs;
return diff > -atol && diff < atol;
}
bool operator()(const int64_t& lhs, const int64_t& rhs) const { return lhs == rhs; }
bool operator()(const uint64_t& lhs, const uint64_t& rhs) const { return lhs == rhs; }
bool operator()(const int& lhs, const int& rhs) const { return lhs == rhs; }
bool operator()(const bool& lhs, const bool& rhs) const { return lhs == rhs; }
bool operator()(const std::string& lhs, const std::string& rhs) const { return lhs == rhs; }
bool operator()(const DataType& lhs, const DataType& rhs) const { return lhs == rhs; }
template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
bool operator()(const ENum& lhs, const ENum& rhs) const {
return lhs == rhs;
}
};
BaseValueEqual
类,用于比较不同类型的基本值是否相等。该类包含多个重载的 operator()
函数,分别用于比较不同类型的参数。
具体来说,该类实现了以下类型的比较:
对于
double
类型,使用模糊浮点数比较,允许一定的误差范围(1e-9)。对于
int64_t
、uint64_t
、int
和bool
类型,直接使用等于运算符进行比较。对于
std::string
类型,也使用等于运算符进行比较。对于
DataType
类型,同样使用等于运算符进行比较。对于枚举类型,也使用等于运算符进行比较。
通过这些重载的 operator()
函数,可以方便地对不同类型的基本值进行相等性判断。
ObjectPathPairNode
和 ObjectPathPair
#
/*!
* \brief Pair of `ObjectPath`s, one for each object being tested for structural equality.
*/
class ObjectPathPairNode : public Object {
public:
ObjectPath lhs_path;
ObjectPath rhs_path;
ObjectPathPairNode(ObjectPath lhs_path, ObjectPath rhs_path);
static constexpr const char* _type_key = "ObjectPathPair";
TVM_DECLARE_FINAL_OBJECT_INFO(ObjectPathPairNode, Object);
};
class ObjectPathPair : public ObjectRef {
public:
ObjectPathPair(ObjectPath lhs_path, ObjectPath rhs_path);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectPathPair, ObjectRef, ObjectPathPairNode);
};
ObjectPathPairNode
继承自 Object
类,它表示两个对象路径的对(pair
),用于测试这两个对象的结构相等性。该类包含两个成员变量 lhs_path
和 rhs_path
,分别表示左操作数和右操作数的对象路径。构造函数接受两个 ObjectPath
类型的参数,用于初始化这两个成员变量。此外,该类还定义了静态常量字符串 _type_key
,用于标识该类的类型。
ObjectPathPair
继承自 ObjectRef
类,它表示指向 ObjectPathPairNode
对象的引用。该类包含构造函数,接受两个 ObjectPath
类型的参数,用于创建 ObjectPathPairNode
对象,并将其封装为 ObjectRef
类型。此外,该类还定义了一些方法,用于操作 ObjectPathPairNode
对象。
StructuralEqual
#
/*!
* \brief Content-aware structural equality comparator for objects.
*
* The structural equality is recursively defined in the DAG of IR nodes via SEqual.
* There are two kinds of nodes:
*
* - Graph node: a graph node in lhs can only be mapped as equal to
* one and only one graph node in rhs.
* - Normal node: equality is recursively defined without the restriction
* of graph nodes.
*
* Vars(tir::Var, TypeVar) and non-constant relay expression nodes are graph nodes.
* For example, it means that `%1 = %x + %y; %1 + %1` is not structurally equal
* to `%1 = %x + %y; %2 = %x + %y; %1 + %2` in relay.
*
* A var-type node(e.g. tir::Var, TypeVar) can be mapped as equal to another var
* with the same type if one of the following condition holds:
*
* - They appear in a same definition point(e.g. function argument).
* - They points to the same VarNode via the same_as relation.
* - They appear in a same usage point, and map_free_vars is set to be True.
*/
class StructuralEqual : public BaseValueEqual {
public:
// inheritate operator()
using BaseValueEqual::operator();
/*!
* \brief Compare objects via strutural equal.
* \param lhs The left operand.
* \param rhs The right operand.
* \param map_free_params Whether or not to map free variables.
* \return The comparison result.
*/
TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs,
const bool map_free_params = false) const;
};
StructuralEqual
类,继承自 BaseValueEqual
类。该类用于比较两个对象是否在结构上相等。
结构相等性是通过 IR 节点的 DAG 递归定义的,有两种类型的节点:
图节点:在
lhs
中的图节点只能映射为rhs
中的一个且仅有一个图节点。普通节点:在没有图节点限制的情况下递归定义相等性。
Vars(tir::Var, TypeVar)
和非常量 relay 表达式节点是图节点。例如,这意味着在 relay 中,%1 = %x + %y; %1 + %1
不能被结构地等于 %1 = %x + %y; %2 = %x + %y; %1 + %2
。
如果满足以下条件之一,则可以将类型相同的 var
映射为相等:
它们出现在相同的定义点(例如函数参数)。
它们通过相同的关系指向相同的
VarNode
。它们出现在相同的使用点,并且将
map_free_vars
设置为True
。
该类还重载了 operator()
方法,用于比较两个对象是否在结构上相等。该方法接受三个参数:左操作数、右操作数和布尔值 map_free_params
,表示是否映射自由变量。返回值为比较结果。
SEqualReducer
#
/*!
* \brief A Reducer class to reduce the structural equality result of two objects.
*
* The reducer will call the SEqualReduce function of each objects recursively.
* Importantly, the reducer may not directly use recursive calls to resolve the
* equality checking. Instead, it can store the necessary equality conditions
* and check later via an internally managed stack.
*/
class SEqualReducer {
private:
struct PathTracingData;
public:
/*! \brief Internal handler that defines custom behaviors.. */
class Handler {
public:
/*!
* \brief Reduce condition to equality of lhs and rhs.
*
* \param lhs The left operand.
* \param rhs The right operand.
* \param map_free_vars Whether do we allow remap variables if possible.
* \param current_paths Optional paths to `lhs` and `rhs` objects, for error traceability.
*
* \return false if there is an immediate failure, true otherwise.
* \note This function may save the equality condition of (lhs == rhs) in an internal
* stack and try to resolve later.
*/
virtual bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
const Optional<ObjectPathPair>& current_paths) = 0;
/*!
* \brief Mark the comparison as failed, but don't fail immediately.
*
* This is useful for producing better error messages when comparing containers.
* For example, if two array sizes mismatch, it's better to mark the comparison as failed
* but compare array elements anyway, so that we could find the true first mismatch.
*/
virtual void DeferFail(const ObjectPathPair& mismatch_paths) = 0;
/*!
* \brief Check if fail defferal is enabled.
*
* \return false if the fail deferral is not enabled, true otherwise.
*/
virtual bool IsFailDeferralEnabled() = 0;
/*!
* \brief Lookup the graph node equal map for vars that are already mapped.
*
* This is an auxiliary method to check the Map<Var, Value> equality.
* \param lhs an lhs value.
*
* \return The corresponding rhs value if any, nullptr if not available.
*/
virtual ObjectRef MapLhsToRhs(const ObjectRef& lhs) = 0;
/*!
* \brief Mark current comparison as graph node equal comparison.
*/
virtual void MarkGraphNode() = 0;
protected:
using PathTracingData = SEqualReducer::PathTracingData;
};
/*! \brief default constructor */
SEqualReducer() = default;
/*!
* \brief Constructor with a specific handler.
* \param handler The equal handler for objects.
* \param tracing_data Optional pointer to the path tracing data.
* \param map_free_vars Whether or not to map free variables.
*/
explicit SEqualReducer(Handler* handler, const PathTracingData* tracing_data, bool map_free_vars)
: handler_(handler), tracing_data_(tracing_data), map_free_vars_(map_free_vars) {}
/*!
* \brief Reduce condition to comparison of two attribute values.
*
* \param lhs The left operand.
*
* \param rhs The right operand.
*
* \param paths The paths to the LHS and RHS operands. If
* unspecified, will attempt to identify the attribute's address
* within the most recent ObjectRef. In general, the paths only
* require explicit handling for computed parameters
* (e.g. `array.size()`)
*
* \return the immediate check result.
*/
bool operator()(const double& lhs, const double& rhs,
Optional<ObjectPathPair> paths = NullOpt) const;
bool operator()(const int64_t& lhs, const int64_t& rhs,
Optional<ObjectPathPair> paths = NullOpt) const;
bool operator()(const uint64_t& lhs, const uint64_t& rhs,
Optional<ObjectPathPair> paths = NullOpt) const;
bool operator()(const int& lhs, const int& rhs, Optional<ObjectPathPair> paths = NullOpt) const;
bool operator()(const bool& lhs, const bool& rhs, Optional<ObjectPathPair> paths = NullOpt) const;
bool operator()(const std::string& lhs, const std::string& rhs,
Optional<ObjectPathPair> paths = NullOpt) const;
bool operator()(const DataType& lhs, const DataType& rhs,
Optional<ObjectPathPair> paths = NullOpt) const;
template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
bool operator()(const ENum& lhs, const ENum& rhs,
Optional<ObjectPathPair> paths = NullOpt) const {
using Underlying = typename std::underlying_type<ENum>::type;
static_assert(std::is_same<Underlying, int>::value,
"Enum must have `int` as the underlying type");
return EnumAttrsEqual(static_cast<int>(lhs), static_cast<int>(rhs), &lhs, &rhs, paths);
}
template <typename T, typename Callable,
typename = std::enable_if_t<
std::is_same_v<std::invoke_result_t<Callable, const ObjectPath&>, ObjectPath>>>
bool operator()(const T& lhs, const T& rhs, const Callable& callable) {
if (IsPathTracingEnabled()) {
ObjectPathPair current_paths = GetCurrentObjectPaths();
ObjectPathPair new_paths = {callable(current_paths->lhs_path),
callable(current_paths->rhs_path)};
return (*this)(lhs, rhs, new_paths);
} else {
return (*this)(lhs, rhs);
}
}
/*!
* \brief Reduce condition to comparison of two objects.
* \param lhs The left operand.
* \param rhs The right operand.
* \return the immediate check result.
*/
bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const;
/*!
* \brief Reduce condition to comparison of two objects.
*
* Like `operator()`, but with an additional `paths` parameter that specifies explicit object
* paths for `lhs` and `rhs`. This is useful for implementing SEqualReduce() methods for container
* objects like Array and Map, or other custom objects that store nested objects that are not
* simply attributes.
*
* Can only be called when `IsPathTracingEnabled()` is `true`.
*
* \param lhs The left operand.
* \param rhs The right operand.
* \param paths Object paths for `lhs` and `rhs`.
* \return the immediate check result.
*/
bool operator()(const ObjectRef& lhs, const ObjectRef& rhs, const ObjectPathPair& paths) const {
ICHECK(IsPathTracingEnabled()) << "Path tracing must be enabled when calling this function";
return ObjectAttrsEqual(lhs, rhs, map_free_vars_, &paths);
}
/*!
* \brief Reduce condition to comparison of two definitions,
* where free vars can be mapped.
*
* Call this function to compare definition points such as function params
* and var in a let-binding.
*
* \param lhs The left operand.
* \param rhs The right operand.
* \return the immediate check result.
*/
bool DefEqual(const ObjectRef& lhs, const ObjectRef& rhs);
/*!
* \brief Reduce condition to comparison of two arrays.
* \param lhs The left operand.
* \param rhs The right operand.
* \return the immediate check result.
*/
template <typename T>
bool operator()(const Array<T>& lhs, const Array<T>& rhs) const {
if (tracing_data_ == nullptr) {
// quick specialization for Array to reduce amount of recursion
// depth as array comparison is pretty common.
if (lhs.size() != rhs.size()) return false;
for (size_t i = 0; i < lhs.size(); ++i) {
if (!(operator()(lhs[i], rhs[i]))) return false;
}
return true;
}
// If tracing is enabled, fall back to the regular path
const ObjectRef& lhs_obj = lhs;
const ObjectRef& rhs_obj = rhs;
return (*this)(lhs_obj, rhs_obj);
}
/*!
* \brief Implementation for equality rule of var type objects(e.g. TypeVar, tir::Var).
* \param lhs The left operand.
* \param rhs The right operand.
* \return the result.
*/
bool FreeVarEqualImpl(const runtime::Object* lhs, const runtime::Object* rhs) const {
// var need to be remapped, so it belongs to graph node.
handler_->MarkGraphNode();
// We only map free vars if they corresponds to the same address
// or map free_var option is set to be true.
return lhs == rhs || map_free_vars_;
}
/*! \return Get the internal handler. */
Handler* operator->() const { return handler_; }
/*! \brief Check if this reducer is tracing paths to the first mismatch. */
bool IsPathTracingEnabled() const { return tracing_data_ != nullptr; }
/*!
* \brief Get the paths of the currently compared objects.
*
* Can only be called when `IsPathTracingEnabled()` is true.
*/
const ObjectPathPair& GetCurrentObjectPaths() const;
/*!
* \brief Specify the object paths of a detected mismatch.
*
* Can only be called when `IsPathTracingEnabled()` is true.
*/
void RecordMismatchPaths(const ObjectPathPair& paths) const;
private:
bool EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, const void* rhs_address,
Optional<ObjectPathPair> paths = NullOpt) const;
bool ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
const ObjectPathPair* paths) const;
static void GetPathsFromAttrAddressesAndStoreMismatch(const void* lhs_address,
const void* rhs_address,
const PathTracingData* tracing_data);
template <typename T>
static bool CompareAttributeValues(const T& lhs, const T& rhs,
const PathTracingData* tracing_data,
Optional<ObjectPathPair> paths = NullOpt);
/*! \brief Internal class pointer. */
Handler* handler_ = nullptr;
/*! \brief Pointer to the current path tracing context, or nullptr if path tracing is disabled. */
const PathTracingData* tracing_data_ = nullptr;
/*! \brief Whether or not to map free vars. */
bool map_free_vars_ = false;
};
SEqualReducer
类,它是结构相等性结果的规约器。这个类的主要目的是通过调用每个对象的 SEqualReduce
函数来递归地比较两个对象是否相等。
SEqualReducer
类包含内部处理器 Handler
,它定义了一些自定义行为的方法,如 SEqualReduce
、DeferFail
、IsFailDeferralEnabled
、MapLhsToRhs
和 MarkGraphNode
。这些方法需要在派生类中实现。
SEqualReducer
类还包含一些重载的运算符,用于比较不同类型的值,如 double
、int64_t
、uint64_t
、int
、bool
、std::string
、DataType
和枚举类型。这些运算符将调用相应的比较方法,如 EnumAttrsEqual
或 ObjectAttrsEqual
。
此外,SEqualReducer
类还提供了一些其他方法,如 operator()
重载,用于比较两个对象是否相等;DefEqual
方法,用于比较两个定义点是否相等,其中自由变量可以被映射;以及一些辅助方法,如 IsPathTracingEnabled
和 GetCurrentObjectPaths
。
SEqualHandlerDefault
#
/*! \brief The default handler for equality testing.
*
* Users can derive from this class and override the DispatchSEqualReduce method,
* to customize equality testing.
*/
class SEqualHandlerDefault : public SEqualReducer::Handler {
public:
SEqualHandlerDefault(bool assert_mode, Optional<ObjectPathPair>* first_mismatch,
bool defer_fails);
virtual ~SEqualHandlerDefault();
bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
const Optional<ObjectPathPair>& current_paths) override;
void DeferFail(const ObjectPathPair& mismatch_paths) override;
bool IsFailDeferralEnabled() override;
ObjectRef MapLhsToRhs(const ObjectRef& lhs) override;
void MarkGraphNode() override;
/*!
* \brief The entry point for equality testing
* \param lhs The left operand.
* \param rhs The right operand.
* \param map_free_vars Whether or not to remap variables if possible.
* \return The equality result.
*/
virtual bool Equal(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars);
protected:
/*!
* \brief The dispatcher for equality testing of intermediate objects
* \param lhs The left operand.
* \param rhs The right operand.
* \param map_free_vars Whether or not to remap variables if possible.
* \param current_paths Optional paths to `lhs` and `rhs` objects, for error traceability.
* \return The equality result.
*/
virtual bool DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
const Optional<ObjectPathPair>& current_paths);
private:
class Impl;
Impl* impl;
};
这段代码定义了一个名为SEqualHandlerDefault
的类,它是SEqualReducer::Handler
类的默认处理器。这个类用于处理对象之间的相等性测试。
该类具有以下成员函数:
SEqualHandlerDefault(bool assert_mode, Optional<ObjectPathPair>* first_mismatch, bool defer_fails)
:构造函数,接受三个参数:assert_mode
表示是否使用断言模式,first_mismatch
是一个指向可选的对象路径对的指针,defer_fails
表示是否延迟失败。~SEqualHandlerDefault()
:析构函数。SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, const Optional<ObjectPathPair>& current_paths)
:重载的SEqualReduce
方法,用于比较两个对象是否相等。它接受四个参数:左操作数lhs
、右操作数rhs
、是否映射自由变量map_free_vars
以及当前路径对current_paths
。DeferFail(const ObjectPathPair& mismatch_paths)
:延迟失败的方法,接受一个对象路径对作为参数。IsFailDeferralEnabled()
:检查是否启用了失败延迟。MapLhsToRhs(const ObjectRef& lhs)
:将左操作数映射到右操作数的方法,接受一个对象引用作为参数。MarkGraphNode()
:标记图形节点的方法。Equal(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars)
:入口点方法,用于进行相等性测试。它接受三个参数:左操作数lhs
、右操作数rhs
以及是否映射自由变量map_free_vars
。DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, const Optional<ObjectPathPair>& current_paths)
:中间对象的相等性测试调度器方法,接受四个参数:左操作数lhs
、右操作数rhs
、是否映射自由变量map_free_vars
以及当前路径对current_paths
。
此外,该类还包含一个私有成员变量impl
,它是一个指向内部实现类的指针。
Python 接口:structural_equal
#
def structural_equal(lhs, rhs, map_free_vars=False):
"""Check structural equality of lhs and rhs.
The structural equality is recursively defined in the DAG of IRNodes.
There are two kinds of nodes:
- Graph node: a graph node in lhs can only be mapped as equal to
one and only one graph node in rhs.
- Normal node: equality is recursively defined without the restriction
of graph nodes.
Vars(tir::Var, TypeVar) and non-constant relay expression nodes are graph nodes.
For example, it means that `%1 = %x + %y; %1 + %1` is not structurally equal
to `%1 = %x + %y; %2 = %x + %y; %1 + %2` in relay.
A var-type node(e.g. tir::Var, TypeVar) can be mapped as equal to another var
with the same type if one of the following condition holds:
- They appear in a same definition point(e.g. function argument).
- They points to the same VarNode via the same_as relation.
- They appear in a same usage point, and map_free_vars is set to be True.
The rules for var are used to remap variables occurs in function
arguments and let-bindings.
Parameters
----------
lhs : Object
The left operand.
rhs : Object
The left operand.
map_free_vars : bool
Whether free variables (i.e. variables without a definition site) should be mapped
as equal to each other.
Return
------
result : bool
The comparison result.
See Also
--------
structural_hash
assert_strucural_equal
"""
lhs = tvm.runtime.convert(lhs)
rhs = tvm.runtime.convert(rhs)
return bool(_ffi_node_api.StructuralEqual(lhs, rhs, False, map_free_vars)) # type: ignore # pylint: disable=no-member
structural_equal
函数,用于检查两个对象(lhs
和 rhs
)在结构上是否相等。
该函数接受三个参数:
lhs
:左操作数,表示要比较的第一个对象。rhs
:右操作数,表示要比较的第二个对象。map_free_vars
:布尔值,表示是否将自由变量(即没有定义site
的变量)映射为相等。
函数首先使用 tvm.runtime.convert
将 lhs
和 rhs
转换为相应的类型。然后,它调用 _ffi_node_api.StructuralEqual
函数来执行实际的结构相等性检查,并将结果转换为布尔值返回。
该函数的实现考虑了两种类型的节点:图节点和非图节点。图节点是指那些在 lhs
中只能映射为与 rhs
中的一个且只有一个图节点相等的节点。非图节点则通过递归定义等价性,不受到图节点的限制。
对于图节点(如 tir::Var
和 TypeVar
),如果满足以下条件之一,它们可以映射为相同类型的另一个变量相等:
它们出现在相同的定义点(例如函数参数)。
它们通过相同的关系指向相同的
VarNode
。它们出现在相同的使用点,并且
map_free_vars
设置为True
。
这些规则用于重新映射函数参数和 let
绑定中的变量。
最后,函数返回布尔值,表示比较的结果。