表达式#
BaseExprNode
和 BaseExpr
#
BaseExprNode
是所有表达式的基本类型。该类继承自 Object
类,并声明了一些公共成员变量和方法:
span
是可变的Span
类型,指向原始源代码,用于保留调试信息。静态常量
_type_key
被设置为"BaseExpr"
,表示这个类的类型键值。静态常量
_type_has_method_sequal_reduce
和_type_has_method_shash_reduce
分别被设置为true
,表示这个类具有相等归约和哈希归约的方法。静态常量
_type_child_slots
被设置为 62,表示这个类的子节点槽数为 62。TVM_DECLARE_BASE_OBJECT_INFO
宏用于声明BaseExprNode
类的基本信息。
BaseExpr
类是一个托管引用(Managed Reference),它继承自 ObjectRef
类。该类通过宏 TVM_DEFINE_OBJECT_REF_METHODS
定义了对 BaseExprNode
的引用方法。
总的来说,这段代码定义了表达式的基本类型和引用方式,用于在后续的代码中进行表达式的操作和处理。
PrimExprNode
和 PrimExpr
#
PrimExprNode
是所有原语表达式(primitive expression)的基本节点,它继承自BaseExprNode
类。具有以下成员:dtype
:表示原语表达式的运行时数据类型。在编译时和运行时,runtime::DataType(dtype)
提供了粗糙的类型信息。它在PrimExpr
表达式构造期间被立即构建,并且可以用于快速类型检查。dtype
足以在原始表达式对应于像i32
这样的POD值类型时决定其类型。当dtype
为DataType::Handle()
时,表达式可能对应于更细粒度的类型,可以通过运行延迟类型推断来获取类型。此外,还定义了一些静态常量和对象打印相关的宏。
该类主要用于低级代码优化和整数分析,并在编译时和运行时提供粗糙的类型信息。
PrimExpr
是对PrimExprNode
的引用,它继承自BaseExpr
类。该类中声明了从整数和浮点数构造表达式的方法,并提供了获取表达式数据类型的成员函数dtype()
。此外,还使用宏TVM_DEFINE_OBJECT_REF_METHODS
定义了对象引用的方法。PrimExpr
类表示原语表达式,它支持各种算术运算符(如加法、减法、乘法、除法等)以及位运算符(如按位与、按位或、按位异或等)。这些运算符都支持常量折叠(eager constant folding),即在编译时尽可能将常量表达式计算出来,以减少运行时的计算量。此外,PrimExpr
类还提供了一些特化的运算符重载,例如operator==
、operator!=
、operator&&
等,用于支持布尔表达式的计算。
总的来说,这段代码提供了一种方便的方式来定义原语表达式的基本节点和对象引用,并支持类型信息和对象引用操作。
PrimExpr
的子类#
下面列出几个 PrimExpr
的子类,用于表示不同类型的原语表达式。
IntImmNode
类表示程序中的常量整数字面量。它继承自PrimExprNode
类,并包含一个整数值value
。该类还实现了一些访问器方法、相等比较方法和哈希方法。IntImm
类是一个托管引用类,用于管理IntImmNode
对象。它提供了构造函数以及对象引用方法的定义。FloatImmNode
类表示程序中的常量浮点数字面量。它也继承自PrimExprNode
类,并包含一个双精度浮点数值value
。该类同样实现了一些访问器方法、相等比较方法和哈希方法。FloatImm
类是一个托管引用类,用于管理FloatImmNode
对象。它提供了构造函数以及对象引用方法的定义。Bool
类表示布尔常量。它是一个托管引用类,继承自IntImm
类,并重载了一些运算符。
除了这些类之外,代码还定义了一些算子的重载,以确保我们使用最细粒度的类型进行运算。
这些类和算子重载提供了一种灵活的方式来表示和操作不同类型的原始表达式,以便在编译时进行优化和类型检查。
RelayExprNode
和 RelayExpr
#
RelayExprNode
是所有非原语表达式的基本节点。它继承自BaseExprNode
,表示一个表达式节点。这个类主要包含以下成员:checked_type_
:存储类型推断(类型检查)的结果。在类型推断之前可能是未定义的,在序列化期间会被丢弃。struct_info_
:存储表达式的结构信息,包括静态形状和运行时信息,如形状。virtual_device_
:该节点的虚拟设备(VirtualDevice),用于描述评估表达式结果应该存储在哪里。对于一阶表达式(非函数),它描述了结果应该存储在哪里。对于函数类型的表达式,虚拟设备描述了调用函数或闭包结果的存储位置(而不是函数本身的存储位置)。其他成员函数包括:
checked_type()
:返回已检查类型的引用。type_as()
:返回指定类型的 TTypeNode 指针。virtual_device()
:返回虚拟设备(VirtualDevice)。
类的成员变量还包括一些常量和元数据信息。
RelayExpr
是托管引用到RelayExprNode
的类。它继承自BaseExpr
,表示可管理的表达式。这个类主要包含以下成员:使用
TVM_DEFINE_OBJECT_REF_METHODS
宏定义了与BaseExpr
类的引用方法。
这些类提供了对表达式节点和表达式的管理和操作功能,使得可以对表达式进行类型推断、结构信息获取、虚拟设备设置等操作。
RelayExprNode
成员函数#
checked_type()
函数返回常量引用,表示该节点的已检查类型。它首先使用ICHECK
宏进行一些内部错误检查,确保checked_type_
字段已经被定义。然后返回checked_type_
字段的值。type_as()
函数是一个模板函数,用于将TTypeNode
类型的指针转换为RelayExprNode
类型的指针。它首先使用static_assert
进行编译时类型检查,确保TTypeNode
是TypeNode
的派生类。然后再次进行内部错误检查,确保checked_type_
字段已经被定义。接下来,它尝试将checked_type_
转换为TTypeNode
类型,并将结果存储在node
指针中。最后,它再次进行内部错误检查,确保node
指针不为空,并返回node
指针。
这些函数的目的是提供对节点类型的安全访问和类型转换的支持。
GlobalVarNode
和 GlobalVar
#
GlobalVarNode
类继承自RelayExprNode
,表示全局变量节点。这个类主要包含以下成员:name_hint
:表示变量名的提示,仅作为提示使用。VisitAttrs
方法:用于访问节点的属性,接受AttrVisitor
指针参数,并调用其Visit
方法来处理各个属性。SEqualReduce
方法:用于比较两个GlobalVarNode
对象是否相等,根据变量名进行比较,并调用FreeVarEqualImpl
方法进行其他属性的比较。SHashReduce
方法:用于计算节点的哈希值,调用HashReduce
方法进行计算。_type_key
:静态常量字符指针,用于标识该节点的类型为"GlobalVar"
。TVM_DECLARE_FINAL_OBJECT_INFO
宏:用于声明该类的最终对象信息。
GlobalVar
类继承自RelayExpr
,表示全局变量的托管引用。这个类主要包含以下成员:构造函数:接受字符串类型的
name_hint
参数作为变量名的提示,以及可选的Type
类型参数和Span
类型参数。TVM_DEFINE_OBJECT_REF_METHODS
宏:用于定义该类的引用方法,包括RelayExpr
类的引用方法和GlobalVarNode
类的引用方法。TVM_DEFINE_OBJECT_REF_COW_METHOD
宏:用于定义该类的可变引用方法,即复制引用方法。
这段代码的作用是定义了表示全局变量的类和节点,并提供了一些方法和属性来操作和管理这些全局变量。
表示范围和整数字面量#
Integer
类是一个容器类,用于存储和自动化类型检查属性,这些属性必须是常量整数。它继承自IntImm
类,并提供了不同的构造函数和赋值运算符重载。该类还实现了一些比较运算符,以支持范围比较。RangeNode
类表示一个范围节点,包含范围的最小值、范围大小和位置信息。它提供了访问器方法VisitAttrs
,用于在属性访问时进行处理。它还实现了相等比较方法和哈希方法,以便在编译时进行优化。Range
类是一个范围容器,用于表示一系列连续的整数。它提供了构造函数,可以通过范围的开始和结束值来创建范围对象。此外,它还提供了一个静态方法FromMinExtent
,用于通过最小值和范围大小来创建范围对象。
这些类可以用于表示程序中的常量整数和范围,并在编译时进行类型检查和优化。
PackedFuncValueConverter
#
三个模板特化:PackedFuncValueConverter<PrimExpr>
,PackedFuncValueConverter<tvm::Integer>
和 PackedFuncValueConverter<tvm::Bool>
用于处理在运行时的数据类型转换。
PackedFuncValueConverter<PrimExpr>
:这个模板特化处理的是原始表达式(PrimExpr)类型的数据。如果输入的TVMValue
是nullptr
,那么返回表示nullptr
的PrimExpr
对象。如果输入的TVMValue
是整数类型,那么根据其值的大小返回相应的IntImm
或Int32Imm
对象。如果输入的TVMValue
是浮点数类型,那么返回相应的FloatImm
对象。否则,将TVMValue
转换为ObjectRef
并调用PrimExpr::FromObject_
方法。PackedFuncValueConverter<tvm::Integer>
:这个模板特化处理的是整数类型(Integer
)的数据。如果输入的TVMValue
是nullptr
,那么返回表示nullptr
的Integer
对象。如果输入的TVMValue
是整数类型,那么直接返回该整数。否则,将TVMValue
转换为Integer
对象。PackedFuncValueConverter<tvm::Bool>
:这个模板特化处理的是布尔类型(Bool
)的数据。如果输入的TVMValue
是nullptr
,那么返回表示nullptr
的Bool
对象。如果输入的TVMValue
是整数类型,那么将其转换为bool
类型并返回相应的Bool
对象。否则,将TVMValue
转换为Bool
对象。