测试回调函数#

测试 TVM Relax 虚拟机(VM)中 Python 回调函数与 Relax 函数之间的交互机制,具体包括:

  1. 双向数据传递:测试 Relax 函数向Python回调函数传递张量,以及Python回调函数向Relax函数返回张量的能力。

  2. 异常传播:测试Python回调函数中抛出的异常是否能够正确传播到调用方,并且保留完整的堆栈跟踪信息。

  3. 多种执行模式:所有测试都在两种执行模式下运行 - bytecode(字节码)模式和compiled(编译)模式,以验证VM在不同执行策略下的行为一致性。

文件使用TVM的测试框架,通过参数化测试方式自动在不同配置下运行测试用例,确保Relax VM的回调机制能够正常工作。这些测试对于保证TVM Relax与Python之间的交互可靠性非常重要,特别是在需要混合执行Python代码和编译后代码的场景中。

import tvm
import tvm.testing

from tvm.script import relax as R
import numpy as np

# 定义执行模式参数化列表,测试将在字节码模式和编译模式下分别运行
EXEC_MODE = ["bytecode", "compiled"]

测试将张量从 Relax 函数传递给 Python 回调函数#

# 定义Relax函数,该函数接受一个张量和一个回调函数作为参数
@R.function
def relax_func(
    A: R.Tensor([16], "int32"),
    callback: R.Callable([R.Tensor([16], "int32")], R.Tuple([])),
):
    # 将输入张量A乘以2
    B = R.multiply(A, R.const(2))
    # 调用回调函数,并传入计算结果B
    _ = callback(B)
    # 返回空元组
    return R.tuple()
def test(exec_mode, target="llvm", dev=tvm.cpu()):
    # 构建Relax模块,根据指定的执行模式(bytecode或compiled)
    ex = tvm.relax.build(
        tvm.IRModule.from_expr(relax_func),
        target=target,
        exec_mode=exec_mode,
    )
    # 创建虚拟机实例来执行编译后的模块
    vm = tvm.relax.VirtualMachine(ex, dev)

    # 全局变量,用于存储回调函数接收到的张量
    from_callback = None

    # 定义自定义的Python回调函数
    def custom_callback(arr):
        # 非局部变量声明,允许在内部函数中修改外部函数的变量
        nonlocal from_callback
        # 保存接收到的张量
        from_callback = arr

    # 创建测试数据
    np_A = np.arange(16, dtype="int32")
    # 将numpy数组转换为TVM NDArray
    tvm_A = tvm.nd.array(np_A)

    # 在虚拟机上执行Relax函数,传入张量和回调函数
    vm["relax_func"](tvm_A, custom_callback)

    # 验证回调函数是否被正确调用,并且接收到了正确的结果
    assert from_callback is not None
    np.testing.assert_array_equal(np_A * 2, from_callback.numpy())
for exec_mode in EXEC_MODE:
    test(exec_mode, target="llvm")

测试从Python回调函数生成张量并传递给Relax函数#

# 定义Relax函数,该函数接受一个无参数但返回张量的回调函数
@R.function
def relax_func(
    callback: R.Callable([], R.Tensor([16], "int32")),
):
    # 调用回调函数获取张量
    A = callback()
    # 将获取的张量乘以2
    B = R.multiply(A, R.const(2))
    # 返回计算结果
    return B
target = "llvm"
dev = tvm.cpu(0)
for exec_mode in EXEC_MODE:
    # 构建Relax模块
    ex = tvm.relax.build(
        tvm.IRModule.from_expr(relax_func),
        target=target,
        exec_mode=exec_mode,
    )
    # 创建虚拟机实例
    vm = tvm.relax.VirtualMachine(ex, dev)

    # 创建测试数据
    np_A = np.arange(16, dtype="int32")

    # 定义自定义回调函数,返回测试数据的TVM NDArray
    def custom_callback():
        return tvm.nd.array(np_A)

    # 在虚拟机上执行Relax函数,传入回调函数,并获取输出结果
    output = vm["relax_func"](custom_callback)

    # 验证输出结果是否正确
    np.testing.assert_array_equal(np_A * 2, output.numpy())

测试Python回调函数中的异常是否能正确传播,并保留完整的堆栈跟踪#

# 定义Relax函数,接受一个返回张量的回调函数
@R.function
def relax_func(
    callback: R.Callable([], R.Tensor([16], "int32")),
):
    # 调用回调函数获取张量
    A = callback()
    # 返回获取的张量
    return A
target = "llvm"
dev = tvm.cpu(0)
for exec_mode in EXEC_MODE:
    # 构建Relax模块
    ex = tvm.relax.build(
        tvm.IRModule.from_expr(relax_func),
        target=target,
        exec_mode=exec_mode,
    )
    # 创建虚拟机实例
    vm = tvm.relax.VirtualMachine(ex, dev)

    # 定义一个会抛出异常的自定义回调函数
    def custom_callback():
        # 定义一个局部变量,用于测试堆栈中的局部变量信息
        local_var = 42
        # 抛出运行时错误
        raise RuntimeError("Error thrown from callback")

    try:
        # 在虚拟机上执行Relax函数,传入会抛出异常的回调函数
        vm["relax_func"](custom_callback)
    except RuntimeError as err:
        # 获取异常的堆栈跟踪
        stack = err.__traceback__
        # 遍历堆栈,找到最内层的栈帧(回调函数所在的栈帧)
        while stack.tb_next is not None:
            stack = stack.tb_next
        frame = stack.tb_frame
        # 验证最内层的栈帧是否来自Python回调函数
        assert (
            frame.f_code.co_filename.find("test_vm_callback_function.py") != -1
        ), "Inner-most stack frame should be from Python callback"

    else:
        # 如果没有捕获到异常,抛出错误
        raise RuntimeError("Exception thrown in callback was not propagated to calling scope")