Relax Python 模块设计概述#

参考:Relax Python 模块设计 & pull: 18229

随着机器学习模型——尤其是大型语言模型——的规模持续增长,对 ML 编译器运行时与 Python 生态系统深度集成的需求日益增加。像 PyTorch 这样的基于 Python 的框架提供了丰富的算子库,包括通过 torch.distributed 进行分布式通信等功能,这些功能可以在 GPU 和节点之间高效扩展。这些资源已经被广泛采用并得到良好支持,使其成为编译器运行时中重用的理想候选。

在 TVM 中,计算图使用 Relax 在 IRModules 中描述。虽然 TVMScript 允许使用类似 Python 的语法表达 Relax 函数,但这些函数不能直接在 Python 中执行。要运行 Relax 函数,必须编译整个 IRModule,并通过虚拟机(VM)加载生成的可执行文件。

为了更好地利用 Python 的运行时环境并丰富 TVM 的灵活性,在支持 Python 的平台上的 IRModules 和 TVMScript 中添加对 Python 函数的原生支持。这些 Python 函数——用 @py_func 装饰器标记的——可以直接在 Python 中执行,使用标准的 PyTorch 张量作为输入和输出。类似于 Relax 函数,它们表示计算图,但额外的好处是可以直接、逐步地用 Python 执行。与需要在运行前编译的 Relax 函数不同,Python 函数将不会编译,而是可以直接在 Python 环境中运行。

除了重用 Python 和 PyTorch 实现,在 TVMScript 中支持 Python 函数可以显著提升调试体验。传统编译器将计算图视为单一实体,难以检查中间张量值。随着模型复杂性的增加,这一限制变得更加明显。通过 Python 函数,调试就像插入一条 print 语句一样简单。用户还可以快速手动编辑 Python 函数并立即观察结果——极大地改进了开发和调试工作流程。

关键设计#

跨层级调用#

Relax 中的 Python 函数设计为跨层级,意味着它们可以与 Relax 函数、TIR 函数和 TVM 打包函数互操作。这种双向互操作性允许:

  • 调用 Relax/TIR/打包函数的 Python 函数。

  • 通过 R.call_py_func 调用 Python 函数的 Relax 函数。

为了支持这一点,使用 DLPack 实现 TVM NDArrays 和 PyTorch 张量之间的无缝转换,使数据能够在不同的运行环境中流动,并最小化开销。

即时编译(JIT)#

如果 IRModule 包含任何 Python 函数,会使用 JIT 编译延迟 TIR 和 Relax 函数的编译。这意味着:

  • TVMScript 解析时不会编译 TIR 和 Relax 函数。

  • 编译仅在实例化 IRModule 时发生,此时:

    • TIR 函数会被编译并存储在实例化的模块中。

    • 会创建 Relax 虚拟机来执行编译后的 Relax 函数。

这种 JIT 策略允许更灵活的后期修改和与 Python 运行时的集成。

Relax 函数与 Python 函数之间的转换#

由于 Relax 函数和 Python 函数都描述计算图,引入了一种新的 IRModule 打印器,将 Relax 函数转换为 Python 函数。这允许用户:

  • 避免手动编写 Python 函数。

  • 将 Relax IR 转换为可读和可执行的 Python 代码。

  • 直接在 Python/PyTorch 中调试或部署中间阶段的 Relax 程序。

在此转换过程中:

  • 高级 Relax 算子(例如 R.nn.relu )映射到相应的 PyTorch API(例如 F.relu )。

  • call_tir 和 Relax 函数调用通过将 PyTorch 张量转换为/从 DLPack 格式并传递给编译函数来处理。

  • call_dps_packed 通过通过 tvm.get_global_func 检索压缩函数并使用 DLPack 包装的张量调用它来执行。

这一关键特性是这种转换可以在编译过程的任何阶段发生。例如:

  • 在早期阶段,用户可以将 Relax 函数转换为 Python,以测试 PyTorch 的实现。

  • 在后期阶段,当模块的大部分内容被降低到 TIR 时,相同的转换允许使用 PyTorch 运行时进行测试或部署。

未来,可能还会使用一些 PyTorch 基础设施(如 FX 或导出的程序)将 Python 函数跟踪回 Relax 函数。

具体实现#

通过 @I.pyfunc 装饰器和 BasePyModule ,在 TVM Relax 中实现了原生 Python 函数支持,这使 TVM 的编译流程与 Python/PyTorch 运行时环境之间能够无缝集成。这一增强功能允许用户直接在 TVMScript 中编写 Python 函数,这些函数可以与 Relax 和 TIR 函数互操作,从而提供增强的调试能力并利用现有的 PyTorch 算子库。

TVMScript 解析器增强

  • @I.pyfunc 装饰器:用于将 Python 函数标记为 IRModule 的集成目标

  • 双重存储格式:既存储原始字符串表示形式(用于 TVMScript 打印),也捕获 PackedFunc(用于运行时执行)

  • ExternFunc 表示:每个 Python 函数都表示为一个 ExternFunc 节点,节点属性存储源代码和运行时包装器

完整的 BasePyModule 实现

  • 基于 DLPack 的张量转换:PyTorch 张量和 TVM NDArrays 之间的无缝转换

  • 跨函数互操作性:Python 函数可以调用 Relax/TIR 函数,反之亦然

  • JIT 编译:延迟模块实例化时的编译,以支持灵活的后期修改

  • 动态函数注册:支持运行时添加 Python 函数

示例#

# 导入 TVM 核心模块
import tvm
from tvm import relax, tir
# 导入 BasePyModule,这是支持 Python 函数的 IRModule 基类
from tvm.relax.base_py_module import BasePyModule
# 导入 TVM script 相关模块,用于编写 IR、Relax 和 TIR 代码
from tvm.script import ir as I, relax as R, tir as T
# 导入设备相关模块
from tvm.runtime import Device
# 导入 PyTorch,用于演示跨框架数据转换
import torch


# 使用 @I.ir_module 装饰器定义一个 IR 模块,该模块继承自 BasePyModule
@I.ir_module
class IRModuleWithPyFunc(BasePyModule):
    """示例 IRModule 包含 Python 函数支持。
    基类 BasePyModule 实现了 Python 中的跨函数调用和 JIT 编译逻辑。
    只有继承自 BasePyModule 的 IRModule 才允许包含 Python 函数。
    """

    # 使用 @I.pyfunc 装饰器定义一个可以在 Relax 函数中调用的 Python 函数
    @I.pyfunc
    def python_add(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """可以从 Relax 函数调用的 Python 函数。"""
        # 通过 DLPack 将 PyTorch 张量转换为 TVM NDArray
        x_tvm = self._convert_pytorch_to_tvm(x)
        y_tvm = self._convert_pytorch_to_tvm(y)
        
        # 调用编译后的 TIR 函数执行加法运算
        result = self.call_tir(self.add_tir, [x_tvm, y_tvm], 
                             out_sinfo=R.Tensor((5,), "float32"))
        
        # 将结果转换回原始格式(PyTorch 张量)
        return self._convert_tvm_to_pytorch(result)

    # 使用 @T.prim_func 装饰器定义一个 TIR 原语函数
    @T.prim_func
    def add_tir(
        var_x: T.handle,
        var_y: T.handle,
        var_out: T.handle,
    ):
        # 匹配缓冲区,将原始句柄绑定到具体的缓冲区描述
        x = T.match_buffer(var_x, (5,), "float32")
        y = T.match_buffer(var_y, (5,), "float32")
        out = T.match_buffer(var_out, (5,), "float32")
        
        # 实现向量加法运算
        for i in range(5):
            out[i] = x[i] + y[i]

    # 使用 @R.function 装饰器定义一个 Relax 函数
    @R.function
    def main_relax(x: R.Tensor((5,), "float32"), 
                   y: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
        # 直接使用 Relax 的内置加法操作
        return R.add(x, y)


def main():
    """展示带有 Python 函数支持的 IRModule 的主函数。"""
    # 创建 IRModuleWithPyFunc 实例
    module = IRModuleWithPyFunc()
    
    # 生成用于测试的随机 PyTorch 张量
    x_torch = torch.randn(5, dtype=torch.float32)
    y_torch = torch.randn(5, dtype=torch.float32)
    
    # 通过 DLPack 将 PyTorch 张量转换为 TVM NDArray
    x_tvm = module._convert_pytorch_to_tvm(x_torch)
    y_tvm = module._convert_pytorch_to_tvm(y_torch)
    
    # 将 TVM NDArray 转换回 PyTorch 张量
    x_back = module._convert_tvm_to_pytorch(x_tvm)
    y_back = module._convert_tvm_to_pytorch(y_tvm)
    
    # 执行跨函数调用测试
    # 1. 调用 TIR 函数
    tir_result = module.call_tir("add_tir", [x_torch, y_torch], 
                                out_sinfo=R.Tensor((5,), "float32"))
    # 2. 调用 Relax 函数
    relax_result = module.main_relax(x_torch, y_torch)
    # 3. 调用 Python 函数
    python_result = module.python_add(x_torch, y_torch)
    
    # 返回模块实例、DLPack 转换结果和跨函数调用结果
    return module, (x_torch, y_torch, x_tvm, y_tvm, x_back, y_back), (tir_result, relax_result, python_result)


# 当脚本直接运行时执行主函数
if __name__ == "__main__":
    main()



# 示例用法与验证代码(当前被注释掉)
# result = main()
# assert result is not None, "函数应返回结果"
# module, dlpack_results, cross_call_results = result
# assert len(dlpack_results) == 6, "DLPack 结果应包含 6 个元素"
# assert len(cross_call_results) == 3, "跨调用结果应包含 3 个元素"