TVMScript @I.pyfunc 装饰器#
本测试验证以下内容:
@I.pyfunc装饰器的正确工作方式Python 函数如何正确集成到 IRModule 中
BasePyModule 继承关系是否正确处理
为 Python 函数创建 ExternFunc 节点的功能
# 导入测试框架和必要的库
import pytest
import torch
import tvm
from tvm import relax
from tvm.script import ir as I, relax as R, tir as T
from tvm.relax import BasePyModule
import numpy as np
@I.ir_module
class TestPyFuncModule(BasePyModule):
"""使用 @I.pyfunc 装饰器的 Python 函数测试模块。"""
@I.pyfunc
def pytorch_processor(x: torch.Tensor) -> torch.Tensor:
"""处理 PyTorch 张量的 Python 函数。"""
return torch.nn.functional.relu(x) * 2.0
@I.pyfunc
def pytorch_adder(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""将两个 PyTorch 张量相加的 Python 函数。"""
return x + y
@I.pyfunc
def pytorch_complex_ops(x: torch.Tensor) -> torch.Tensor:
"""复杂的 PyTorch 操作。"""
result = torch.nn.functional.softmax(x, dim=0)
result = torch.nn.functional.dropout(result, p=0.1, training=False)
return result * 10.0
@T.prim_func
def simple_tir_func(
var_A: T.handle,
var_B: T.handle,
):
"""简单的 TIR 函数,用于测试 Python 函数与 TIR 函数的集成。"""
T.func_attr({"tir.noalias": True})
n = T.int32()
A = T.match_buffer(var_A, (n,), "float32")
B = T.match_buffer(var_B, (n,), "float32")
for i in T.grid(n):
with T.block("copy"):
vi = T.axis.remap("S", [i])
B[vi] = A[vi]
测试 @I.pyfunc 装饰器是否正确创建 pyfuncs 属性#
module = TestPyFuncModule
# 验证模块包含 pyfuncs 属性
assert hasattr(module, "pyfuncs"), "模块应该包含 pyfuncs 属性"
pyfuncs = module.pyfuncs
assert isinstance(pyfuncs, dict), "pyfuncs 应该是一个字典"
# 验证所有期望的函数都在 pyfuncs 中
expected_functions = ["pytorch_processor", "pytorch_adder", "pytorch_complex_ops"]
for func_name in expected_functions:
assert func_name in pyfuncs, f"函数 {func_name} 应该在 pyfuncs 中"
测试 pyfuncs 中的 Python 函数是否可调用#
module = TestPyFuncModule
pyfuncs = module.pyfuncs
# 测试 pytorch_processor 函数
processor_func = pyfuncs["pytorch_processor"]
assert callable(processor_func), "pytorch_processor 应该是可调用的"
# 测试 pytorch_adder 函数
adder_func = pyfuncs["pytorch_adder"]
assert callable(adder_func), "pytorch_adder 应该是可调用的"
# 测试 pytorch_complex_ops 函数
complex_func = pyfuncs["pytorch_complex_ops"]
assert callable(complex_func), "pytorch_complex_ops 应该是可调用的"
测试 Python 函数是否正确执行#
module = TestPyFuncModule
pyfuncs = module.pyfuncs
# 创建测试数据
x = torch.tensor([1.0, -2.0, 3.0, -4.0, 5.0], dtype=torch.float32)
y = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5], dtype=torch.float32)
# 测试 pytorch_processor 函数
processor_func = pyfuncs["pytorch_processor"]
processor_result = processor_func(x)
assert isinstance(processor_result, torch.Tensor)
expected = torch.nn.functional.relu(x) * 2.0
assert torch.allclose(processor_result, expected, atol=1e-5)
# 测试 pytorch_adder 函数
adder_func = pyfuncs["pytorch_adder"]
adder_result = adder_func(x, y)
assert isinstance(adder_result, torch.Tensor)
expected = x + y
assert torch.allclose(adder_result, expected, atol=1e-5)
# 测试 pytorch_complex_ops 函数
complex_func = pyfuncs["pytorch_complex_ops"]
complex_result = complex_func(x)
assert isinstance(complex_result, torch.Tensor)
# 注意:dropout 是非确定性的,所以我们只检查形状和类型
assert complex_result.shape == x.shape
assert complex_result.dtype == x.dtype
测试模块是否包含用于 IRModule 操作的 functions 属性#
module = TestPyFuncModule
# 检查 functions 属性是否存在
assert hasattr(module, "functions"), "模块应该包含 functions 属性"
functions = module.functions
# TVM IRModule.functions 不是标准字典,但具有类似字典的行为
assert hasattr(functions, "__getitem__"), "functions 应该支持类似字典的访问"
assert hasattr(functions, "__iter__"), "functions 应该是可迭代的"
测试模块是否有用于 TVMScript 输出的 script() 方法#
module = TestPyFuncModule
# 检查 script 方法是否存在
assert hasattr(module, "script"), "模块应该包含 script 方法"
# 测试 script 方法的执行
script_output = module.script()
assert isinstance(script_output, str), "script() 应该返回一个字符串"
assert len(script_output) > 0, "script() 应该返回非空字符串"
测试模块是否具有 BasePyModule 继承标志#
module = TestPyFuncModule
# 检查继承标志是否存在(这可能在所有实现中都不设置)
if hasattr(module, "_base_py_module_inherited"):
assert module._base_py_module_inherited, "继承标志应该为 True"
else:
# 替代方法:检查模块是否支持 Python 函数
assert hasattr(module, "pyfuncs"), "模块应该支持 Python 函数"
# 检查原始类是否被保留(这可能在所有实现中都不设置)
if hasattr(module, "_original_class"):
assert module._original_class is not None, "原始类应该被保留"
else:
# 替代方法:检查模块是否可调用(ModuleFactory)
assert hasattr(module, "__call__"), "模块应该是可调用的(ModuleFactory)"
测试模块的创建和执行功能#
module = TestPyFuncModule
assert hasattr(module, "__call__"), "模块应该是可调用的"
device = tvm.cpu(0)
instance = module(device)
assert isinstance(instance, BasePyModule), "实例应该是 BasePyModule 类型"
assert hasattr(instance, "pyfuncs"), "实例应该包含 pyfuncs 属性"
x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
result = instance.pytorch_processor(x)
assert isinstance(result, torch.Tensor)
expected = torch.nn.functional.relu(x) * 2.0
assert torch.allclose(result, expected, atol=1e-5)
Warning: Failed to compile Relax VM: 'NoneType' object has no attribute 'kind'
测试模块在 GPU 设备上的创建和执行功能#
module = TestPyFuncModule
if tvm.cuda().exist:
device = tvm.cuda(0)
instance = module(device)
assert isinstance(instance, BasePyModule), "实例应该是 BasePyModule 类型"
assert hasattr(instance, "pyfuncs"), "实例应该包含 pyfuncs 属性"
x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32, device="cuda")
result = instance.pytorch_processor(x)
assert isinstance(result, torch.Tensor)
assert result.device.type == "cuda"
expected = torch.nn.functional.relu(x) * 2.0
assert torch.allclose(result, expected, atol=1e-5)
else:
pytest.skip("CUDA 不可用")
Warning: Failed to compile one or more TIR functions: Memory verification failed with the following errors:
Variable `B` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
Variable `A` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
Did you forget to bind?
# from tvm.script import tir as T
@T.prim_func
def simple_tir_func(var_A: T.handle, var_B: T.handle):
T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mtriple": "x86_64-unknown-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": True})
n = T.int32()
A = T.match_buffer(var_A, (n,))
B = T.match_buffer(var_B, (n,))
for i in range(n):
B_1 = T.Buffer((n,), data=B.data)
A_1 = T.Buffer((n,), data=A.data)
B_1[i] = A_1[i]
Warning: Failed to compile Relax VM: Memory verification failed with the following errors:
Variable `B` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
Variable `A` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
Did you forget to bind?
# from tvm.script import tir as T
@T.prim_func
def simple_tir_func(var_A: T.handle, var_B: T.handle):
T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mtriple": "x86_64-unknown-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": True})
n = T.int32()
A = T.match_buffer(var_A, (n,))
B = T.match_buffer(var_B, (n,))
for i in range(n):
B_1 = T.Buffer((n,), data=B.data)
A_1 = T.Buffer((n,), data=A.data)
B_1[i] = A_1[i]
测试 Python 函数如何与 TIR 函数协作#
module = TestPyFuncModule
# 创建实例
device = tvm.cpu(0)
instance = module(device)
# 测试 TIR 函数执行
n = 5
input_tensor = torch.randn(n, dtype=torch.float32)
# 调用 TIR 函数 - 它需要 3 个参数:输入、输出和大小
# 但 call_tir 会处理输出缓冲区的创建,所以我们只传递输入和大小
# 注意:TIR 函数期望 TVM 类型,而不是 Python 类型
result = instance.call_tir(
instance.simple_tir_func,
[input_tensor], # 只传递输入张量,让 call_tir 处理其余部分
R.Tensor((n,), "float32"),
)
# 验证结果
assert isinstance(result, torch.Tensor)
assert result.shape == (n,)
assert torch.allclose(result, input_tensor, atol=1e-5)
Warning: Failed to compile Relax VM: 'NoneType' object has no attribute 'kind'
测试 @I.pyfunc 装饰器是否保留函数签名#
module = TestPyFuncModule
pyfuncs = module.pyfuncs
# 检查函数签名
import inspect
# pytorch_processor 签名
processor_func = pyfuncs["pytorch_processor"]
sig = inspect.signature(processor_func)
params = list(sig.parameters.keys())
assert len(params) == 1, "pytorch_processor 应该有 1 个参数"
assert params[0] == "x", "第一个参数应该是 'x'"
# pytorch_adder 签名
adder_func = pyfuncs["pytorch_adder"]
sig = inspect.signature(adder_func)
params = list(sig.parameters.keys())
assert len(params) == 2, "pytorch_adder 应该有 2 个参数"
assert params[0] == "x", "第一个参数应该是 'x'"
assert params[1] == "y", "第二个参数应该是 'y'"