Understanding StructuralEqual

Source: tvm/node/structural_equal.h

BaseValueEqual

/*!
 * \brief Equality definition of base value class.
 */
class BaseValueEqual {
 public:
  bool operator()(const double& lhs, const double& rhs) const {
    // fuzzy float pt comparison
    constexpr double atol = 1e-9;
    if (lhs == rhs) return true;
    double diff = lhs - rhs;
    return diff > -atol && diff < atol;
  }

  bool operator()(const int64_t& lhs, const int64_t& rhs) const { return lhs == rhs; }
  bool operator()(const uint64_t& lhs, const uint64_t& rhs) const { return lhs == rhs; }
  bool operator()(const int& lhs, const int& rhs) const { return lhs == rhs; }
  bool operator()(const bool& lhs, const bool& rhs) const { return lhs == rhs; }
  bool operator()(const std::string& lhs, const std::string& rhs) const { return lhs == rhs; }
  bool operator()(const DataType& lhs, const DataType& rhs) const { return lhs == rhs; }
  template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
  bool operator()(const ENum& lhs, const ENum& rhs) const {
    return lhs == rhs;
  }
};

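BaseValueEqual is a functor that compares primitive ("base") values for equality. It provides one operator() overload per supported type:

  • double: fuzzy comparison. Two values are equal if they compare equal with == or if their difference lies strictly within an absolute tolerance of 1e-9.

  • int64_t, uint64_t, int, and bool: compared directly with ==.

  • std::string: compared with ==.

  • DataType: compared with ==.

  • Enum types: compared with ==, through a template overload that is enabled only when std::is_enum holds.

Because all overloads share one call syntax, code that compares attributes can invoke the same functor on any of these types.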
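The fuzzy double rule is easy to misread, so here is a minimal Python sketch of it. The names base_value_equal and ATOL are hypothetical and are not part of TVM; the sketch only mirrors the arithmetic of the C++ overload shown above.

```python
ATOL = 1e-9  # same absolute tolerance as the C++ constant `atol`


def base_value_equal(lhs, rhs):
    """Toy model of BaseValueEqual: fuzzy for floats, exact otherwise."""
    if isinstance(lhs, float) and isinstance(rhs, float):
        if lhs == rhs:  # exact match, also covers equal infinities
            return True
        diff = lhs - rhs
        # Strict bounds; a NaN difference fails both comparisons.
        return -ATOL < diff < ATOL
    return lhs == rhs


print(base_value_equal(1.0, 1.0 + 1e-10))  # difference is inside the tolerance
print(base_value_equal(1.0, 1.0 + 1e-8))   # difference is outside the tolerance
print(base_value_equal(1e-10, 2e-10))      # absolute, not relative, tolerance
```

Note that the tolerance is absolute: 1e-10 and 2e-10 differ by a factor of two yet compare equal, because their difference is far below 1e-9.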

ObjectPathPairNode and ObjectPathPair

/*!
 * \brief Pair of `ObjectPath`s, one for each object being tested for structural equality.
 */
class ObjectPathPairNode : public Object {
 public:
  ObjectPath lhs_path;
  ObjectPath rhs_path;

  ObjectPathPairNode(ObjectPath lhs_path, ObjectPath rhs_path);

  static constexpr const char* _type_key = "ObjectPathPair";
  TVM_DECLARE_FINAL_OBJECT_INFO(ObjectPathPairNode, Object);
};

class ObjectPathPair : public ObjectRef {
 public:
  ObjectPathPair(ObjectPath lhs_path, ObjectPath rhs_path);

  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectPathPair, ObjectRef, ObjectPathPairNode);
};

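ObjectPathPairNode derives from Object and holds a pair of object paths, one for each of the two objects being tested for structural equality. Its members lhs_path and rhs_path store the path to the left-hand and right-hand object respectively, and the constructor takes two ObjectPath arguments to initialize them. The static constant _type_key ("ObjectPathPair") identifies the node type in TVM's object system.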

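ObjectPathPair derives from ObjectRef and is the reference type that points to an ObjectPathPairNode. Its constructor takes two ObjectPath arguments, creates an ObjectPathPairNode from them, and wraps the node in the reference. The TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS macro supplies the usual reference boilerplate (such as operator-> for reaching the node) and marks the reference as non-nullable.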

StructuralEqual

/*!
 * \brief Content-aware structural equality comparator for objects.
 *
 *  The structural equality is recursively defined in the DAG of IR nodes via SEqual.
 *  There are two kinds of nodes:
 *
 *  - Graph node: a graph node in lhs can only be mapped as equal to
 *    one and only one graph node in rhs.
 *  - Normal node: equality is recursively defined without the restriction
 *    of graph nodes.
 *
 *  Vars(tir::Var, TypeVar) and non-constant relay expression nodes are graph nodes.
 *  For example, it means that `%1 = %x + %y; %1 + %1` is not structurally equal
 *  to `%1 = %x + %y; %2 = %x + %y; %1 + %2` in relay.
 *
 *  A var-type node(e.g. tir::Var, TypeVar) can be mapped as equal to another var
 *  with the same type if one of the following condition holds:
 *
 *  - They appear in a same definition point(e.g. function argument).
 *  - They points to the same VarNode via the same_as relation.
 *  - They appear in a same usage point, and map_free_vars is set to be True.
 */
class StructuralEqual : public BaseValueEqual {
 public:
  // inherit operator()
  using BaseValueEqual::operator();
  /*!
   * \brief Compare objects via structural equal.
   * \param lhs The left operand.
   * \param rhs The right operand.
   * \param map_free_params Whether or not to map free variables.
   * \return The comparison result.
   */
  TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs,
                          const bool map_free_params = false) const;
};

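StructuralEqual derives from BaseValueEqual and compares two objects for structural equality.

Structural equality is defined recursively over the DAG of IR nodes, which come in two kinds:

  • Graph node: a graph node in lhs can be mapped as equal to one and only one graph node in rhs.

  • Normal node: equality is defined recursively, without the graph-node restriction.

Vars (tir::Var, TypeVar) and non-constant relay expression nodes are graph nodes. For example, in relay, %1 = %x + %y; %1 + %1 is not structurally equal to %1 = %x + %y; %2 = %x + %y; %1 + %2. The first expression uses a single shared node twice, while the second uses two distinct nodes, and one lhs graph node cannot be mapped to two different rhs graph nodes.

A var-type node (e.g. tir::Var, TypeVar) can be mapped as equal to another var of the same type if one of the following conditions holds:

  • They appear at the same definition point (e.g. a function argument).

  • They point to the same VarNode via the same_as relation.

  • They appear at the same usage point, and map_free_vars is set to True.

The class inherits the base-value overloads with a using declaration and adds an operator() overload for ObjectRef. That overload takes three parameters: the left operand, the right operand, and a boolean that controls whether free variables are mapped. The C++ signature names this boolean map_free_params, while the class comment and the Python API call it map_free_vars. The return value is the comparison result.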
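The graph-node rule can be illustrated with a small model. The sketch below is not TVM code: Var, Add, and toy_structural_equal are hypothetical names, and the model assumes that every node is a graph node (as for non-constant relay expressions). It keeps a one-to-one mapping between lhs and rhs nodes and reproduces the %1 + %1 versus %1 + %2 example.

```python
from dataclasses import dataclass


@dataclass(eq=False)  # eq=False keeps identity semantics, like node pointers
class Var:
    name: str


@dataclass(eq=False)
class Add:
    a: object
    b: object


def toy_structural_equal(lhs, rhs, map_free_vars=False):
    lhs_to_rhs = {}    # id(lhs graph node) -> rhs graph node it is mapped to
    rhs_taken = set()  # id() of every rhs graph node that already has a partner

    def visit(l, r):
        if type(l) is not type(r):
            return False
        # Graph-node rule: an existing mapping must be reused exactly.
        if id(l) in lhs_to_rhs:
            return lhs_to_rhs[id(l)] is r
        if id(r) in rhs_taken:
            return False
        if isinstance(l, Var):
            # Free variable: same object, or remapping explicitly allowed.
            if not (l is r or map_free_vars):
                return False
        elif not (visit(l.a, r.a) and visit(l.b, r.b)):
            return False
        lhs_to_rhs[id(l)] = r
        rhs_taken.add(id(r))
        return True

    return visit(lhs, rhs)


x, y = Var("x"), Var("y")
s1 = Add(x, y)            # %1 = %x + %y
s2 = Add(x, y)            # %2 = %x + %y, a distinct node with the same content
shared = Add(s1, s1)      # %1 + %1
duplicated = Add(s1, s2)  # %1 + %2

print(toy_structural_equal(shared, shared))      # True
print(toy_structural_equal(shared, duplicated))  # False
```

In the second call, s1 is first mapped to s1; the second operand then asks for s1 to equal s2, which contradicts the recorded mapping, so the comparison fails even though both sums have the same content.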

SEqualReducer

/*!
 * \brief A Reducer class to reduce the structural equality result of two objects.
 *
 * The reducer will call the SEqualReduce function of each objects recursively.
 * Importantly, the reducer may not directly use recursive calls to resolve the
 * equality checking. Instead, it can store the necessary equality conditions
 * and check later via an internally managed stack.
 */
class SEqualReducer {
 private:
  struct PathTracingData;

 public:
  /*! \brief Internal handler that defines custom behaviors. */
  class Handler {
   public:
    /*!
     * \brief Reduce condition to equality of lhs and rhs.
     *
     * \param lhs The left operand.
     * \param rhs The right operand.
     * \param map_free_vars Whether do we allow remap variables if possible.
     * \param current_paths Optional paths to `lhs` and `rhs` objects, for error traceability.
     *
     * \return false if there is an immediate failure, true otherwise.
     * \note This function may save the equality condition of (lhs == rhs) in an internal
     *       stack and try to resolve later.
     */
    virtual bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
                              const Optional<ObjectPathPair>& current_paths) = 0;

    /*!
     * \brief Mark the comparison as failed, but don't fail immediately.
     *
     * This is useful for producing better error messages when comparing containers.
     * For example, if two array sizes mismatch, it's better to mark the comparison as failed
     * but compare array elements anyway, so that we could find the true first mismatch.
     */
    virtual void DeferFail(const ObjectPathPair& mismatch_paths) = 0;

    /*!
     * \brief Check if fail deferral is enabled.
     *
     * \return false if the fail deferral is not enabled, true otherwise.
     */
    virtual bool IsFailDeferralEnabled() = 0;

    /*!
     * \brief Lookup the graph node equal map for vars that are already mapped.
     *
     *  This is an auxiliary method to check the Map<Var, Value> equality.
     * \param lhs an lhs value.
     *
     * \return The corresponding rhs value if any, nullptr if not available.
     */
    virtual ObjectRef MapLhsToRhs(const ObjectRef& lhs) = 0;
    /*!
     * \brief Mark current comparison as graph node equal comparison.
     */
    virtual void MarkGraphNode() = 0;

   protected:
    using PathTracingData = SEqualReducer::PathTracingData;
  };

  /*! \brief default constructor */
  SEqualReducer() = default;
  /*!
   * \brief Constructor with a specific handler.
   * \param handler The equal handler for objects.
   * \param tracing_data Optional pointer to the path tracing data.
   * \param map_free_vars Whether or not to map free variables.
   */
  explicit SEqualReducer(Handler* handler, const PathTracingData* tracing_data, bool map_free_vars)
      : handler_(handler), tracing_data_(tracing_data), map_free_vars_(map_free_vars) {}

  /*!
   * \brief Reduce condition to comparison of two attribute values.
   *
   * \param lhs The left operand.
   *
   * \param rhs The right operand.
   *
   * \param paths The paths to the LHS and RHS operands.  If
   * unspecified, will attempt to identify the attribute's address
   * within the most recent ObjectRef.  In general, the paths only
   * require explicit handling for computed parameters
   * (e.g. `array.size()`)
   *
   * \return the immediate check result.
   */
  bool operator()(const double& lhs, const double& rhs,
                  Optional<ObjectPathPair> paths = NullOpt) const;
  bool operator()(const int64_t& lhs, const int64_t& rhs,
                  Optional<ObjectPathPair> paths = NullOpt) const;
  bool operator()(const uint64_t& lhs, const uint64_t& rhs,
                  Optional<ObjectPathPair> paths = NullOpt) const;
  bool operator()(const int& lhs, const int& rhs, Optional<ObjectPathPair> paths = NullOpt) const;
  bool operator()(const bool& lhs, const bool& rhs, Optional<ObjectPathPair> paths = NullOpt) const;
  bool operator()(const std::string& lhs, const std::string& rhs,
                  Optional<ObjectPathPair> paths = NullOpt) const;
  bool operator()(const DataType& lhs, const DataType& rhs,
                  Optional<ObjectPathPair> paths = NullOpt) const;

  template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
  bool operator()(const ENum& lhs, const ENum& rhs,
                  Optional<ObjectPathPair> paths = NullOpt) const {
    using Underlying = typename std::underlying_type<ENum>::type;
    static_assert(std::is_same<Underlying, int>::value,
                  "Enum must have `int` as the underlying type");
    return EnumAttrsEqual(static_cast<int>(lhs), static_cast<int>(rhs), &lhs, &rhs, paths);
  }

  template <typename T, typename Callable,
            typename = std::enable_if_t<
                std::is_same_v<std::invoke_result_t<Callable, const ObjectPath&>, ObjectPath>>>
  bool operator()(const T& lhs, const T& rhs, const Callable& callable) {
    if (IsPathTracingEnabled()) {
      ObjectPathPair current_paths = GetCurrentObjectPaths();
      ObjectPathPair new_paths = {callable(current_paths->lhs_path),
                                  callable(current_paths->rhs_path)};
      return (*this)(lhs, rhs, new_paths);
    } else {
      return (*this)(lhs, rhs);
    }
  }

  /*!
   * \brief Reduce condition to comparison of two objects.
   * \param lhs The left operand.
   * \param rhs The right operand.
   * \return the immediate check result.
   */
  bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const;

  /*!
   * \brief Reduce condition to comparison of two objects.
   *
   * Like `operator()`, but with an additional `paths` parameter that specifies explicit object
   * paths for `lhs` and `rhs`. This is useful for implementing SEqualReduce() methods for container
   * objects like Array and Map, or other custom objects that store nested objects that are not
   * simply attributes.
   *
   * Can only be called when `IsPathTracingEnabled()` is `true`.
   *
   * \param lhs The left operand.
   * \param rhs The right operand.
   * \param paths Object paths for `lhs` and `rhs`.
   * \return the immediate check result.
   */
  bool operator()(const ObjectRef& lhs, const ObjectRef& rhs, const ObjectPathPair& paths) const {
    ICHECK(IsPathTracingEnabled()) << "Path tracing must be enabled when calling this function";
    return ObjectAttrsEqual(lhs, rhs, map_free_vars_, &paths);
  }

  /*!
   * \brief Reduce condition to comparison of two definitions,
   *        where free vars can be mapped.
   *
   *  Call this function to compare definition points such as function params
   *  and var in a let-binding.
   *
   * \param lhs The left operand.
   * \param rhs The right operand.
   * \return the immediate check result.
   */
  bool DefEqual(const ObjectRef& lhs, const ObjectRef& rhs);

  /*!
   * \brief Reduce condition to comparison of two arrays.
   * \param lhs The left operand.
   * \param rhs The right operand.
   * \return the immediate check result.
   */
  template <typename T>
  bool operator()(const Array<T>& lhs, const Array<T>& rhs) const {
    if (tracing_data_ == nullptr) {
      // quick specialization for Array to reduce amount of recursion
      // depth as array comparison is pretty common.
      if (lhs.size() != rhs.size()) return false;
      for (size_t i = 0; i < lhs.size(); ++i) {
        if (!(operator()(lhs[i], rhs[i]))) return false;
      }
      return true;
    }

    // If tracing is enabled, fall back to the regular path
    const ObjectRef& lhs_obj = lhs;
    const ObjectRef& rhs_obj = rhs;
    return (*this)(lhs_obj, rhs_obj);
  }
  /*!
   * \brief Implementation for equality rule of var type objects(e.g. TypeVar, tir::Var).
   * \param lhs The left operand.
   * \param rhs The right operand.
   * \return the result.
   */
  bool FreeVarEqualImpl(const runtime::Object* lhs, const runtime::Object* rhs) const {
    // var need to be remapped, so it belongs to graph node.
    handler_->MarkGraphNode();
    // We only map free vars if they corresponds to the same address
    // or map free_var option is set to be true.
    return lhs == rhs || map_free_vars_;
  }

  /*! \return Get the internal handler. */
  Handler* operator->() const { return handler_; }

  /*! \brief Check if this reducer is tracing paths to the first mismatch. */
  bool IsPathTracingEnabled() const { return tracing_data_ != nullptr; }

  /*!
   * \brief Get the paths of the currently compared objects.
   *
   * Can only be called when `IsPathTracingEnabled()` is true.
   */
  const ObjectPathPair& GetCurrentObjectPaths() const;

  /*!
   * \brief Specify the object paths of a detected mismatch.
   *
   * Can only be called when `IsPathTracingEnabled()` is true.
   */
  void RecordMismatchPaths(const ObjectPathPair& paths) const;

 private:
  bool EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, const void* rhs_address,
                      Optional<ObjectPathPair> paths = NullOpt) const;

  bool ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
                        const ObjectPathPair* paths) const;

  static void GetPathsFromAttrAddressesAndStoreMismatch(const void* lhs_address,
                                                        const void* rhs_address,
                                                        const PathTracingData* tracing_data);

  template <typename T>
  static bool CompareAttributeValues(const T& lhs, const T& rhs,
                                     const PathTracingData* tracing_data,
                                     Optional<ObjectPathPair> paths = NullOpt);

  /*! \brief Internal class pointer. */
  Handler* handler_ = nullptr;
  /*! \brief Pointer to the current path tracing context, or nullptr if path tracing is disabled. */
  const PathTracingData* tracing_data_ = nullptr;
  /*! \brief Whether or not to map free vars. */
  bool map_free_vars_ = false;
};

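SEqualReducer reduces the structural equality of two objects to equality conditions on their parts. It does so by calling the SEqualReduce function of each object recursively. As the class comment stresses, the reducer need not resolve these conditions through direct recursive calls: it may store the pending equality conditions and check them later through an internally managed stack.

SEqualReducer contains a nested Handler class that defines the customizable behavior through pure virtual methods: SEqualReduce, DeferFail, IsFailDeferralEnabled, MapLhsToRhs, and MarkGraphNode. Derived classes must implement all of them.

SEqualReducer also provides operator() overloads for attribute values of type double, int64_t, uint64_t, int, bool, std::string, DataType, and enum types. Each takes an optional ObjectPathPair used for path tracing. The enum overload requires int as the underlying type and forwards to the private helper EnumAttrsEqual; the ObjectRef overload that takes explicit paths forwards to ObjectAttrsEqual.

The remaining public interface consists of the operator() overloads for ObjectRef and Array<T>; DefEqual, which compares two definition points (such as function parameters or the variable of a let-binding) where free variables may be mapped; FreeVarEqualImpl, which marks the current comparison as a graph-node comparison and accepts two vars if they are the same object or map_free_vars_ is set; and the path-tracing helpers IsPathTracingEnabled, GetCurrentObjectPaths, and RecordMismatchPaths.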
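To make the "store conditions and check later" idea concrete, here is a minimal Python sketch that compares nested lists and tuples with an explicit stack of pending equality conditions instead of recursion. The function stack_equal is hypothetical and far simpler than the real reducer (no graph nodes, no variable mapping); it only illustrates why an internal stack avoids deep call chains.

```python
def stack_equal(lhs, rhs):
    """Compare nested lists/tuples/scalars without recursive calls."""
    pending = [(lhs, rhs)]  # equality conditions that still need checking
    while pending:
        a, b = pending.pop()
        if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
            if type(a) is not type(b) or len(a) != len(b):
                return False  # immediate failure
            # Store the child conditions instead of recursing into them.
            pending.extend(zip(a, b))
        elif a != b:
            return False
    return True


def nest(depth):
    """Build a list nested `depth` levels deep."""
    value = []
    for _ in range(depth):
        value = [value]
    return value


print(stack_equal([1, [2, (3, 4)]], [1, [2, (3, 4)]]))
print(stack_equal([1, [2, 3]], [1, [2, 4]]))
# Depth far beyond Python's default recursion limit still works.
print(stack_equal(nest(50_000), nest(50_000)))
```

A recursive comparison of the last pair would exceed the default recursion limit, whereas the stack-based loop only grows a list.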

SEqualHandlerDefault

/*! \brief The default handler for equality testing.
 *
 * Users can derive from this class and override the DispatchSEqualReduce method,
 * to customize equality testing.
 */
class SEqualHandlerDefault : public SEqualReducer::Handler {
 public:
  SEqualHandlerDefault(bool assert_mode, Optional<ObjectPathPair>* first_mismatch,
                       bool defer_fails);
  virtual ~SEqualHandlerDefault();

  bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
                    const Optional<ObjectPathPair>& current_paths) override;
  void DeferFail(const ObjectPathPair& mismatch_paths) override;
  bool IsFailDeferralEnabled() override;
  ObjectRef MapLhsToRhs(const ObjectRef& lhs) override;
  void MarkGraphNode() override;

  /*!
   * \brief The entry point for equality testing
   * \param lhs The left operand.
   * \param rhs The right operand.
   * \param map_free_vars Whether or not to remap variables if possible.
   * \return The equality result.
   */
  virtual bool Equal(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars);

 protected:
  /*!
   * \brief The dispatcher for equality testing of intermediate objects
   * \param lhs The left operand.
   * \param rhs The right operand.
   * \param map_free_vars Whether or not to remap variables if possible.
   * \param current_paths Optional paths to `lhs` and `rhs` objects, for error traceability.
   * \return The equality result.
   */
  virtual bool DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
                                    const Optional<ObjectPathPair>& current_paths);

 private:
  class Impl;
  Impl* impl;
};

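SEqualHandlerDefault is the default implementation of SEqualReducer::Handler and drives equality testing between objects. According to its class comment, users can derive from it and override DispatchSEqualReduce to customize equality testing.

The class has the following member functions:

  • SEqualHandlerDefault(bool assert_mode, Optional<ObjectPathPair>* first_mismatch, bool defer_fails): the constructor. assert_mode selects assert mode, first_mismatch is a pointer to an optional ObjectPathPair (presumably where the paths of the first mismatch are reported), and defer_fails controls whether failures are deferred.

  • ~SEqualHandlerDefault(): the virtual destructor.

  • SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, const Optional<ObjectPathPair>& current_paths): overrides Handler::SEqualReduce (an override, not an overload). It reduces the comparison to the equality of lhs and rhs, given whether free variables may be mapped and the optional paths to the current objects.

  • DeferFail(const ObjectPathPair& mismatch_paths): marks the comparison as failed without failing immediately, recording the given mismatch paths.

  • IsFailDeferralEnabled(): reports whether fail deferral is enabled.

  • MapLhsToRhs(const ObjectRef& lhs): looks up the rhs value that lhs has already been mapped to, if any.

  • MarkGraphNode(): marks the current comparison as a graph-node comparison.

  • Equal(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars): the entry point for equality testing. It is virtual, so subclasses can replace it.

  • DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, const Optional<ObjectPathPair>& current_paths): the protected dispatcher for equality testing of intermediate objects, taking the same four parameters as SEqualReduce.

The class also has a private member impl, a pointer to the internal implementation class Impl (the pimpl idiom), which keeps the implementation details out of the header.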
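The purpose of fail deferral, as described in the DeferFail comment, is to report the true first mismatch when containers differ in size. The following Python sketch models that behavior for nested lists. The function first_mismatch and its path strings are hypothetical and greatly simplified; in particular, the real handler works with ObjectPathPair objects rather than strings.

```python
def first_mismatch(lhs, rhs, path="<root>", defer_fails=True):
    """Return the path of the first mismatch, or None if lhs equals rhs."""
    if isinstance(lhs, list) and isinstance(rhs, list):
        size_mismatch = len(lhs) != len(rhs)
        if size_mismatch and not defer_fails:
            return path  # fail immediately on the container itself
        # With deferral, still compare the common prefix element by element.
        for i, (a, b) in enumerate(zip(lhs, rhs)):
            found = first_mismatch(a, b, f"{path}[{i}]", defer_fails)
            if found is not None:
                return found
        # Only now report the deferred size mismatch, if there was one.
        return path if size_mismatch else None
    return None if lhs == rhs else path


print(first_mismatch([1, 2, 3], [1, 9]))                     # <root>[1]
print(first_mismatch([1, 2, 3], [1, 9], defer_fails=False))  # <root>
print(first_mismatch([[1, 2], [3]], [[1, 2], [4]]))          # <root>[1][0]
```

With deferral, the report points at the element that actually differs; without it, the comparison stops at the size check and only names the container.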

Python interface: structural_equal

def structural_equal(lhs, rhs, map_free_vars=False):
    """Check structural equality of lhs and rhs.

    The structural equality is recursively defined in the DAG of IRNodes.
    There are two kinds of nodes:

    - Graph node: a graph node in lhs can only be mapped as equal to
      one and only one graph node in rhs.
    - Normal node: equality is recursively defined without the restriction
      of graph nodes.

    Vars(tir::Var, TypeVar) and non-constant relay expression nodes are graph nodes.
    For example, it means that `%1 = %x + %y; %1 + %1` is not structurally equal
    to `%1 = %x + %y; %2 = %x + %y; %1 + %2` in relay.

    A var-type node(e.g. tir::Var, TypeVar) can be mapped as equal to another var
    with the same type if one of the following condition holds:

    - They appear in a same definition point(e.g. function argument).
    - They points to the same VarNode via the same_as relation.
    - They appear in a same usage point, and map_free_vars is set to be True.

    The rules for var are used to remap variables occurs in function
    arguments and let-bindings.

    Parameters
    ----------
    lhs : Object
        The left operand.

    rhs : Object
        The right operand.

    map_free_vars : bool
        Whether free variables (i.e. variables without a definition site) should be mapped
        as equal to each other.

    Return
    ------
    result : bool
        The comparison result.

    See Also
    --------
    structural_hash
    assert_structural_equal
    """
    lhs = tvm.runtime.convert(lhs)
    rhs = tvm.runtime.convert(rhs)
    return bool(_ffi_node_api.StructuralEqual(lhs, rhs, False, map_free_vars))  # type: ignore # pylint: disable=no-member

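The structural_equal function checks whether two objects, lhs and rhs, are structurally equal.

It takes three parameters:

  • lhs: the left operand, the first object to compare.

  • rhs: the right operand, the second object to compare.

  • map_free_vars: a boolean that controls whether free variables (variables without a definition site) are mapped as equal to each other. It defaults to False.

The function first converts lhs and rhs with tvm.runtime.convert. It then calls _ffi_node_api.StructuralEqual to perform the actual check and converts the result to a Python bool. The literal False passed as the third argument appears to correspond to the assert_mode flag seen in the SEqualHandlerDefault constructor, so a mismatch is returned as a result rather than raised.

As the docstring explains, the comparison distinguishes graph nodes from normal nodes. A graph node in lhs can be mapped as equal to one and only one graph node in rhs. For normal nodes, equality is defined recursively without that restriction.

A var-type node (e.g. tir::Var, TypeVar) can be mapped as equal to another var of the same type if one of the following conditions holds:

  • They appear at the same definition point (e.g. a function argument).

  • They point to the same VarNode via the same_as relation.

  • They appear at the same usage point, and map_free_vars is set to True.

These rules are what allow variables introduced by function arguments and let-bindings to be remapped, so two functions that differ only in the names of their parameters compare equal.

Finally, the function returns a boolean with the comparison result.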
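The difference between definition points and free variables can be sketched with a toy model. This is not the TVM API: Var, Add, Func, and toy_equal_with_defs are hypothetical names. The model always allows remapping at function parameters (the definition points) and remaps free variables only when map_free_vars is True.

```python
from dataclasses import dataclass


@dataclass(eq=False)  # identity semantics, like comparing VarNode pointers
class Var:
    name: str


@dataclass(eq=False)
class Add:
    a: object
    b: object


@dataclass(eq=False)
class Func:
    params: list
    body: object


def toy_equal_with_defs(lhs, rhs, map_free_vars=False):
    var_map = {}        # id(lhs var) -> rhs var it is mapped to
    mapped_rhs = set()  # id() of rhs vars that already have a partner

    def visit(l, r, allow_remap):
        if type(l) is not type(r):
            return False
        if isinstance(l, Var):
            if id(l) in var_map:
                return var_map[id(l)] is r  # must follow the recorded mapping
            if id(r) in mapped_rhs:
                return False
            if l is r or allow_remap:
                var_map[id(l)] = r
                mapped_rhs.add(id(r))
                return True
            return False
        if isinstance(l, Add):
            return visit(l.a, r.a, allow_remap) and visit(l.b, r.b, allow_remap)
        # Func: parameters are definition points, so remapping is allowed there.
        if len(l.params) != len(r.params):
            return False
        params_ok = all(visit(p, q, True) for p, q in zip(l.params, r.params))
        return params_ok and visit(l.body, r.body, allow_remap)

    return visit(lhs, rhs, map_free_vars)


x, y = Var("x"), Var("y")
f = Func([x], Add(x, x))  # fn(x) => x + x
g = Func([y], Add(y, y))  # fn(y) => y + y

print(toy_equal_with_defs(f, g))                                    # True
print(toy_equal_with_defs(Add(x, x), Add(y, y)))                    # False
print(toy_equal_with_defs(Add(x, x), Add(y, y), map_free_vars=True))  # True
```

The two functions compare equal because x and y are bound at the same definition point, while the bare bodies compare equal only when free variables may be mapped.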