测试回调函数#
测试 TVM Relax 虚拟机(VM)中 Python 回调函数与 Relax 函数之间的交互机制,具体包括:
双向数据传递:测试 Relax 函数向Python回调函数传递张量,以及Python回调函数向Relax函数返回张量的能力。
异常传播:测试Python回调函数中抛出的异常是否能够正确传播到调用方,并且保留完整的堆栈跟踪信息。
多种执行模式:所有测试都在两种执行模式下运行 - bytecode(字节码)模式和compiled(编译)模式,以验证VM在不同执行策略下的行为一致性。
文件使用TVM的测试框架,通过参数化测试方式自动在不同配置下运行测试用例,确保Relax VM的回调机制能够正常工作。这些测试对于保证TVM Relax与Python之间的交互可靠性非常重要,特别是在需要混合执行Python代码和编译后代码的场景中。
import tvm
import tvm.testing
from tvm.script import relax as R
import numpy as np
# 定义执行模式参数化列表,测试将在字节码模式和编译模式下分别运行
EXEC_MODE = ["bytecode", "compiled"]
测试将张量从 Relax 函数传递给 Python 回调函数#
# 定义Relax函数,该函数接受一个张量和一个回调函数作为参数
@R.function
def relax_func(
A: R.Tensor([16], "int32"),
callback: R.Callable([R.Tensor([16], "int32")], R.Tuple([])),
):
# 将输入张量A乘以2
B = R.multiply(A, R.const(2))
# 调用回调函数,并传入计算结果B
_ = callback(B)
# 返回空元组
return R.tuple()
def test(exec_mode, target="llvm", dev=tvm.cpu()):
# 构建Relax模块,根据指定的执行模式(bytecode或compiled)
ex = tvm.relax.build(
tvm.IRModule.from_expr(relax_func),
target=target,
exec_mode=exec_mode,
)
# 创建虚拟机实例来执行编译后的模块
vm = tvm.relax.VirtualMachine(ex, dev)
# 全局变量,用于存储回调函数接收到的张量
from_callback = None
# 定义自定义的Python回调函数
def custom_callback(arr):
# 非局部变量声明,允许在内部函数中修改外部函数的变量
nonlocal from_callback
# 保存接收到的张量
from_callback = arr
# 创建测试数据
np_A = np.arange(16, dtype="int32")
# 将numpy数组转换为TVM NDArray
tvm_A = tvm.nd.array(np_A)
# 在虚拟机上执行Relax函数,传入张量和回调函数
vm["relax_func"](tvm_A, custom_callback)
# 验证回调函数是否被正确调用,并且接收到了正确的结果
assert from_callback is not None
np.testing.assert_array_equal(np_A * 2, from_callback.numpy())
for exec_mode in EXEC_MODE:
test(exec_mode, target="llvm")
测试从Python回调函数生成张量并传递给Relax函数#
# 定义Relax函数,该函数接受一个无参数但返回张量的回调函数
@R.function
def relax_func(
callback: R.Callable([], R.Tensor([16], "int32")),
):
# 调用回调函数获取张量
A = callback()
# 将获取的张量乘以2
B = R.multiply(A, R.const(2))
# 返回计算结果
return B
target = "llvm"
dev = tvm.cpu(0)
for exec_mode in EXEC_MODE:
# 构建Relax模块
ex = tvm.relax.build(
tvm.IRModule.from_expr(relax_func),
target=target,
exec_mode=exec_mode,
)
# 创建虚拟机实例
vm = tvm.relax.VirtualMachine(ex, dev)
# 创建测试数据
np_A = np.arange(16, dtype="int32")
# 定义自定义回调函数,返回测试数据的TVM NDArray
def custom_callback():
return tvm.nd.array(np_A)
# 在虚拟机上执行Relax函数,传入回调函数,并获取输出结果
output = vm["relax_func"](custom_callback)
# 验证输出结果是否正确
np.testing.assert_array_equal(np_A * 2, output.numpy())
测试Python回调函数中的异常是否能正确传播,并保留完整的堆栈跟踪#
# 定义Relax函数,接受一个返回张量的回调函数
@R.function
def relax_func(
callback: R.Callable([], R.Tensor([16], "int32")),
):
# 调用回调函数获取张量
A = callback()
# 返回获取的张量
return A
target = "llvm"
dev = tvm.cpu(0)
for exec_mode in EXEC_MODE:
# 构建Relax模块
ex = tvm.relax.build(
tvm.IRModule.from_expr(relax_func),
target=target,
exec_mode=exec_mode,
)
# 创建虚拟机实例
vm = tvm.relax.VirtualMachine(ex, dev)
# 定义一个会抛出异常的自定义回调函数
def custom_callback():
# 定义一个局部变量,用于测试堆栈中的局部变量信息
local_var = 42
# 抛出运行时错误
raise RuntimeError("Error thrown from callback")
try:
# 在虚拟机上执行Relax函数,传入会抛出异常的回调函数
vm["relax_func"](custom_callback)
except RuntimeError as err:
# 获取异常的堆栈跟踪
stack = err.__traceback__
# 遍历堆栈,找到最内层的栈帧(回调函数所在的栈帧)
while stack.tb_next is not None:
stack = stack.tb_next
frame = stack.tb_frame
# 验证最内层的栈帧是否来自Python回调函数
assert (
frame.f_code.co_filename.find("test_vm_callback_function.py") != -1
), "Inner-most stack frame should be from Python callback"
else:
# 如果没有捕获到异常,抛出错误
raise RuntimeError("Exception thrown in callback was not propagated to calling scope")