Relax Python 模块设计概述#
参考:Relax Python 模块设计 & pull: 18229
随着机器学习模型——尤其是大型语言模型——的规模持续增长,对 ML 编译器运行时与 Python 生态系统深度集成的需求日益增加。像 PyTorch 这样的基于 Python 的框架提供了丰富的算子库,包括通过 torch.distributed 进行分布式通信等功能,这些功能可以在 GPU 和节点之间高效扩展。这些资源已经被广泛采用并得到良好支持,使其成为编译器运行时中重用的理想候选。
在 TVM 中,计算图使用 Relax 在 IRModules 中描述。虽然 TVMScript 允许使用类似 Python 的语法表达 Relax 函数,但这些函数不能直接在 Python 中执行。要运行 Relax 函数,必须编译整个 IRModule,并通过虚拟机(VM)加载生成的可执行文件。
为了更好地利用 Python 的运行时环境并丰富 TVM 的灵活性,在支持 Python 的平台上的 IRModules 和 TVMScript 中添加对 Python 函数的原生支持。这些 Python 函数——用 @py_func 装饰器标记的——可以直接在 Python 中执行,使用标准的 PyTorch 张量作为输入和输出。类似于 Relax 函数,它们表示计算图,但额外的好处是可以直接、逐步地用 Python 执行。与需要在运行前编译的 Relax 函数不同,Python 函数将不会编译,而是可以直接在 Python 环境中运行。
除了重用 Python 和 PyTorch 实现,在 TVMScript 中支持 Python 函数可以显著提升调试体验。传统编译器将计算图视为单一实体,难以检查中间张量值。随着模型复杂性的增加,这一限制变得更加明显。通过 Python 函数,调试就像插入一条 print 语句一样简单。用户还可以快速手动编辑 Python 函数并立即观察结果——极大地改进了开发和调试工作流程。
关键设计#
跨层级调用#
Relax 中的 Python 函数设计为跨层级,意味着它们可以与 Relax 函数、TIR 函数和 TVM 打包函数互操作。这种双向互操作性允许:
调用 Relax/TIR/打包函数的 Python 函数。
通过
R.call_py_func调用 Python 函数的 Relax 函数。
为了支持这一点,使用 DLPack 实现 TVM NDArrays 和 PyTorch 张量之间的无缝转换,使数据能够在不同的运行环境中流动,并最小化开销。
即时编译(JIT)#
如果 IRModule 包含任何 Python 函数,会使用 JIT 编译延迟 TIR 和 Relax 函数的编译。这意味着:
TVMScript 解析时不会编译 TIR 和 Relax 函数。
编译仅在实例化 IRModule 时发生,此时:
TIR 函数会被编译并存储在实例化的模块中。
会创建 Relax 虚拟机来执行编译后的 Relax 函数。
这种 JIT 策略允许更灵活的后期修改和与 Python 运行时的集成。
Relax 函数与 Python 函数之间的转换#
由于 Relax 函数和 Python 函数都描述计算图,引入了一种新的 IRModule 打印器,将 Relax 函数转换为 Python 函数。这允许用户:
避免手动编写 Python 函数。
将 Relax IR 转换为可读和可执行的 Python 代码。
直接在 Python/PyTorch 中调试或部署中间阶段的 Relax 程序。
在此转换过程中:
高级 Relax 算子(例如
R.nn.relu)映射到相应的 PyTorch API(例如F.relu)。call_tir和 Relax 函数调用通过将 PyTorch 张量转换为/从 DLPack 格式并传递给编译函数来处理。call_dps_packed通过通过tvm.get_global_func检索压缩函数并使用 DLPack 包装的张量调用它来执行。
这一关键特性是这种转换可以在编译过程的任何阶段发生。例如:
在早期阶段,用户可以将 Relax 函数转换为 Python,以测试 PyTorch 的实现。
在后期阶段,当模块的大部分内容被降低到 TIR 时,相同的转换允许使用 PyTorch 运行时进行测试或部署。
未来,可能还会使用一些 PyTorch 基础设施(如 FX 或导出的程序)将 Python 函数跟踪回 Relax 函数。
具体实现#
通过 @I.pyfunc 装饰器和 BasePyModule ,在 TVM Relax 中实现了原生 Python 函数支持,这使 TVM 的编译流程与 Python/PyTorch 运行时环境之间能够无缝集成。这一增强功能允许用户直接在 TVMScript 中编写 Python 函数,这些函数可以与 Relax 和 TIR 函数互操作,从而提供增强的调试能力并利用现有的 PyTorch 算子库。
TVMScript 解析器增强
@I.pyfunc装饰器:用于将 Python 函数标记为 IRModule 的集成目标双重存储格式:既存储原始字符串表示形式(用于 TVMScript 打印),也捕获 PackedFunc(用于运行时执行)
ExternFunc表示:每个 Python 函数都表示为一个 ExternFunc 节点,节点属性存储源代码和运行时包装器
完整的 BasePyModule 实现
基于 DLPack 的张量转换:PyTorch 张量和 TVM NDArrays 之间的无缝转换
跨函数互操作性:Python 函数可以调用 Relax/TIR 函数,反之亦然
JIT 编译:延迟模块实例化时的编译,以支持灵活的后期修改
动态函数注册:支持运行时添加 Python 函数
示例#
# 导入 TVM 核心模块
import tvm
from tvm import relax, tir
# 导入 BasePyModule,这是支持 Python 函数的 IRModule 基类
from tvm.relax.base_py_module import BasePyModule
# 导入 TVM script 相关模块,用于编写 IR、Relax 和 TIR 代码
from tvm.script import ir as I, relax as R, tir as T
# 导入设备相关模块
from tvm.runtime import Device
# 导入 PyTorch,用于演示跨框架数据转换
import torch
# 使用 @I.ir_module 装饰器定义一个 IR 模块,该模块继承自 BasePyModule
@I.ir_module
class IRModuleWithPyFunc(BasePyModule):
"""示例 IRModule 包含 Python 函数支持。
基类 BasePyModule 实现了 Python 中的跨函数调用和 JIT 编译逻辑。
只有继承自 BasePyModule 的 IRModule 才允许包含 Python 函数。
"""
# 使用 @I.pyfunc 装饰器定义一个可以在 Relax 函数中调用的 Python 函数
@I.pyfunc
def python_add(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""可以从 Relax 函数调用的 Python 函数。"""
# 通过 DLPack 将 PyTorch 张量转换为 TVM NDArray
x_tvm = self._convert_pytorch_to_tvm(x)
y_tvm = self._convert_pytorch_to_tvm(y)
# 调用编译后的 TIR 函数执行加法运算
result = self.call_tir(self.add_tir, [x_tvm, y_tvm],
out_sinfo=R.Tensor((5,), "float32"))
# 将结果转换回原始格式(PyTorch 张量)
return self._convert_tvm_to_pytorch(result)
# 使用 @T.prim_func 装饰器定义一个 TIR 原语函数
@T.prim_func
def add_tir(
var_x: T.handle,
var_y: T.handle,
var_out: T.handle,
):
# 匹配缓冲区,将原始句柄绑定到具体的缓冲区描述
x = T.match_buffer(var_x, (5,), "float32")
y = T.match_buffer(var_y, (5,), "float32")
out = T.match_buffer(var_out, (5,), "float32")
# 实现向量加法运算
for i in range(5):
out[i] = x[i] + y[i]
# 使用 @R.function 装饰器定义一个 Relax 函数
@R.function
def main_relax(x: R.Tensor((5,), "float32"),
y: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
# 直接使用 Relax 的内置加法操作
return R.add(x, y)
def main():
"""展示带有 Python 函数支持的 IRModule 的主函数。"""
# 创建 IRModuleWithPyFunc 实例
module = IRModuleWithPyFunc()
# 生成用于测试的随机 PyTorch 张量
x_torch = torch.randn(5, dtype=torch.float32)
y_torch = torch.randn(5, dtype=torch.float32)
# 通过 DLPack 将 PyTorch 张量转换为 TVM NDArray
x_tvm = module._convert_pytorch_to_tvm(x_torch)
y_tvm = module._convert_pytorch_to_tvm(y_torch)
# 将 TVM NDArray 转换回 PyTorch 张量
x_back = module._convert_tvm_to_pytorch(x_tvm)
y_back = module._convert_tvm_to_pytorch(y_tvm)
# 执行跨函数调用测试
# 1. 调用 TIR 函数
tir_result = module.call_tir("add_tir", [x_torch, y_torch],
out_sinfo=R.Tensor((5,), "float32"))
# 2. 调用 Relax 函数
relax_result = module.main_relax(x_torch, y_torch)
# 3. 调用 Python 函数
python_result = module.python_add(x_torch, y_torch)
# 返回模块实例、DLPack 转换结果和跨函数调用结果
return module, (x_torch, y_torch, x_tvm, y_tvm, x_back, y_back), (tir_result, relax_result, python_result)
# 当脚本直接运行时执行主函数
if __name__ == "__main__":
main()
# 示例用法与验证代码(当前被注释掉)
# result = main()
# assert result is not None, "函数应返回结果"
# module, dlpack_results, cross_call_results = result
# assert len(dlpack_results) == 6, "DLPack 结果应包含 6 个元素"
# assert len(cross_call_results) == 3, "跨调用结果应包含 3 个元素"