表达式#

BaseExprNodeBaseExpr#

BaseExprNode 是所有表达式的基本类型。该类继承自 Object 类,并声明了一些公共成员变量和方法:

  • span 是可变的 Span 类型,指向原始源代码,用于保留调试信息。

  • 静态常量 _type_key 被设置为 "BaseExpr",表示这个类的类型键值。

  • 静态常量 _type_has_method_sequal_reduce_type_has_method_shash_reduce 分别被设置为 true,表示这个类具有相等归约和哈希归约的方法。

  • 静态常量 _type_child_slots 被设置为 62,表示这个类的子节点槽数为 62。

  • TVM_DECLARE_BASE_OBJECT_INFO 宏用于声明 BaseExprNode 类的基本信息。

BaseExpr类是一个托管引用(Managed Reference),它继承自 ObjectRef 类。该类通过宏 TVM_DEFINE_OBJECT_REF_METHODS 定义了对 BaseExprNode 的引用方法。

总的来说,这段代码定义了表达式的基本类型和引用方式,用于在后续的代码中进行表达式的操作和处理。

PrimExprNodePrimExpr#

  • PrimExprNode 是所有原语表达式(primitive expression)的基本节点,它继承自 BaseExprNode 类。具有以下成员:

    • dtype:表示原语表达式的运行时数据类型。在编译时和运行时,runtime::DataType(dtype) 提供了粗糙的类型信息。它在 PrimExpr 表达式构造期间被立即构建,并且可以用于快速类型检查。dtype 足以在原始表达式对应于像i32这样的POD值类型时决定其类型。当 dtypeDataType::Handle() 时,表达式可能对应于更细粒度的类型,可以通过运行延迟类型推断来获取类型。

    • 此外,还定义了一些静态常量和对象打印相关的宏。

    • 该类主要用于低级代码优化和整数分析,并在编译时和运行时提供粗糙的类型信息。

  • PrimExpr 是对 PrimExprNode 的引用,它继承自 BaseExpr 类。该类中声明了从整数和浮点数构造表达式的方法,并提供了获取表达式数据类型的成员函数 dtype()。此外,还使用宏 TVM_DEFINE_OBJECT_REF_METHODS 定义了对象引用的方法。PrimExpr 类表示原语表达式,它支持各种算术运算符(如加法、减法、乘法、除法等)以及位运算符(如按位与、按位或、按位异或等)。这些运算符都支持常量折叠(eager constant folding),即在编译时尽可能将常量表达式计算出来,以减少运行时的计算量。此外,PrimExpr 类还提供了一些特化的运算符重载,例如 operator==operator!=operator&& 等,用于支持布尔表达式的计算。

总的来说,这段代码提供了一种方便的方式来定义原语表达式的基本节点和对象引用,并支持类型信息和对象引用操作。

PrimExpr 的子类#

下面列出几个 PrimExpr 的子类,用于表示不同类型的原语表达式。

  1. IntImmNode 类表示程序中的常量整数字面量。它继承自 PrimExprNode 类,并包含一个整数值 value。该类还实现了一些访问器方法、相等比较方法和哈希方法。

  2. IntImm 类是一个托管引用类,用于管理 IntImmNode 对象。它提供了构造函数以及对象引用方法的定义。

  3. FloatImmNode 类表示程序中的常量浮点数字面量。它也继承自 PrimExprNode 类,并包含一个双精度浮点数值 value。该类同样实现了一些访问器方法、相等比较方法和哈希方法。

  4. FloatImm 类是一个托管引用类,用于管理 FloatImmNode 对象。它提供了构造函数以及对象引用方法的定义。

  5. Bool 类表示布尔常量。它是一个托管引用类,继承自 IntImm 类,并重载了一些运算符。

除了这些类之外,代码还定义了一些算子的重载,以确保我们使用最细粒度的类型进行运算。

这些类和算子重载提供了一种灵活的方式来表示和操作不同类型的原始表达式,以便在编译时进行优化和类型检查。

RelayExprNodeRelayExpr#

  1. RelayExprNode 是所有非原语表达式的基本节点。它继承自 BaseExprNode,表示一个表达式节点。这个类主要包含以下成员:

    • checked_type_:存储类型推断(类型检查)的结果。在类型推断之前可能是未定义的,在序列化期间会被丢弃。

    • struct_info_:存储表达式的结构信息,包括静态形状和运行时信息,如形状。

    • virtual_device_:该节点的虚拟设备(VirtualDevice),用于描述评估表达式结果应该存储在哪里。对于一阶表达式(非函数),它描述了结果应该存储在哪里。对于函数类型的表达式,虚拟设备描述了调用函数或闭包结果的存储位置(而不是函数本身的存储位置)。

    • 其他成员函数包括:

      • checked_type():返回已检查类型的引用。

      • type_as():返回指定类型的 TTypeNode 指针。

      • virtual_device():返回虚拟设备(VirtualDevice)。

    • 类的成员变量还包括一些常量和元数据信息。

  2. RelayExpr 是托管引用到 RelayExprNode 的类。它继承自 BaseExpr,表示可管理的表达式。这个类主要包含以下成员:

    • 使用 TVM_DEFINE_OBJECT_REF_METHODS 宏定义了与 BaseExpr 类的引用方法。

这些类提供了对表达式节点和表达式的管理和操作功能,使得可以对表达式进行类型推断、结构信息获取、虚拟设备设置等操作。

RelayExprNode 成员函数#

  1. checked_type() 函数返回常量引用,表示该节点的已检查类型。它首先使用 ICHECK 宏进行一些内部错误检查,确保 checked_type_ 字段已经被定义。然后返回 checked_type_ 字段的值。

  2. type_as() 函数是一个模板函数,用于将 TTypeNode 类型的指针转换为 RelayExprNode 类型的指针。它首先使用 static_assert 进行编译时类型检查,确保 TTypeNodeTypeNode 的派生类。然后再次进行内部错误检查,确保 checked_type_ 字段已经被定义。接下来,它尝试将 checked_type_ 转换为 TTypeNode 类型,并将结果存储在 node 指针中。最后,它再次进行内部错误检查,确保 node 指针不为空,并返回 node 指针。

这些函数的目的是提供对节点类型的安全访问和类型转换的支持。

GlobalVarNodeGlobalVar#

  1. GlobalVarNode 类继承自 RelayExprNode,表示全局变量节点。这个类主要包含以下成员:

    • name_hint:表示变量名的提示,仅作为提示使用。

    • VisitAttrs 方法:用于访问节点的属性,接受 AttrVisitor 指针参数,并调用其 Visit 方法来处理各个属性。

    • SEqualReduce 方法:用于比较两个 GlobalVarNode 对象是否相等,根据变量名进行比较,并调用 FreeVarEqualImpl 方法进行其他属性的比较。

    • SHashReduce 方法:用于计算节点的哈希值,调用 HashReduce 方法进行计算。

    • _type_key:静态常量字符指针,用于标识该节点的类型为 "GlobalVar"

    • TVM_DECLARE_FINAL_OBJECT_INFO 宏:用于声明该类的最终对象信息。

  2. GlobalVar 类继承自 RelayExpr,表示全局变量的托管引用。这个类主要包含以下成员:

    • 构造函数:接受字符串类型的 name_hint 参数作为变量名的提示,以及可选的 Type 类型参数和 Span 类型参数。

    • TVM_DEFINE_OBJECT_REF_METHODS 宏:用于定义该类的引用方法,包括 RelayExpr 类的引用方法和 GlobalVarNode 类的引用方法。

    • TVM_DEFINE_OBJECT_REF_COW_METHOD 宏:用于定义该类的可变引用方法,即复制引用方法。

这段代码的作用是定义了表示全局变量的类和节点,并提供了一些方法和属性来操作和管理这些全局变量。

表示范围和整数字面量#

  1. Integer 类是一个容器类,用于存储和自动化类型检查属性,这些属性必须是常量整数。它继承自 IntImm 类,并提供了不同的构造函数和赋值运算符重载。该类还实现了一些比较运算符,以支持范围比较。

  2. RangeNode 类表示一个范围节点,包含范围的最小值、范围大小和位置信息。它提供了访问器方法 VisitAttrs,用于在属性访问时进行处理。它还实现了相等比较方法和哈希方法,以便在编译时进行优化。

  3. Range 类是一个范围容器,用于表示一系列连续的整数。它提供了构造函数,可以通过范围的开始和结束值来创建范围对象。此外,它还提供了一个静态方法 FromMinExtent,用于通过最小值和范围大小来创建范围对象。

这些类可以用于表示程序中的常量整数和范围,并在编译时进行类型检查和优化。

PackedFuncValueConverter#

三个模板特化:PackedFuncValueConverter<PrimExpr>PackedFuncValueConverter<tvm::Integer>PackedFuncValueConverter<tvm::Bool> 用于处理在运行时的数据类型转换。

  1. PackedFuncValueConverter<PrimExpr>:这个模板特化处理的是原始表达式(PrimExpr)类型的数据。如果输入的 TVMValuenullptr,那么返回表示 nullptrPrimExpr 对象。如果输入的 TVMValue 是整数类型,那么根据其值的大小返回相应的 IntImmInt32Imm 对象。如果输入的 TVMValue 是浮点数类型,那么返回相应的 FloatImm 对象。否则,将 TVMValue 转换为 ObjectRef 并调用 PrimExpr::FromObject_ 方法。

  2. PackedFuncValueConverter<tvm::Integer>:这个模板特化处理的是整数类型(Integer)的数据。如果输入的 TVMValuenullptr,那么返回表示 nullptrInteger 对象。如果输入的 TVMValue 是整数类型,那么直接返回该整数。否则,将 TVMValue 转换为 Integer 对象。

  3. PackedFuncValueConverter<tvm::Bool>:这个模板特化处理的是布尔类型(Bool)的数据。如果输入的 TVMValuenullptr,那么返回表示 nullptrBool 对象。如果输入的 TVMValue 是整数类型,那么将其转换为 bool 类型并返回相应的 Bool 对象。否则,将 TVMValue 转换为 Bool 对象。