TVMScript @I.pyfunc 装饰器#

本测试验证以下内容:

  1. @I.pyfunc 装饰器的正确工作方式

  2. Python 函数如何正确集成到 IRModule 中

  3. BasePyModule 继承关系是否正确处理

  4. 为 Python 函数创建 ExternFunc 节点的功能

# 导入测试框架和必要的库
import pytest
import torch
import tvm
from tvm import relax
from tvm.script import ir as I, relax as R, tir as T
from tvm.relax import BasePyModule
import numpy as np


@I.ir_module
class TestPyFuncModule(BasePyModule):
    """使用 @I.pyfunc 装饰器的 Python 函数测试模块。"""

    @I.pyfunc
    def pytorch_processor(x: torch.Tensor) -> torch.Tensor:
        """处理 PyTorch 张量的 Python 函数。"""
        return torch.nn.functional.relu(x) * 2.0

    @I.pyfunc
    def pytorch_adder(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """将两个 PyTorch 张量相加的 Python 函数。"""
        return x + y

    @I.pyfunc
    def pytorch_complex_ops(x: torch.Tensor) -> torch.Tensor:
        """复杂的 PyTorch 操作。"""
        result = torch.nn.functional.softmax(x, dim=0)
        result = torch.nn.functional.dropout(result, p=0.1, training=False)
        return result * 10.0

    @T.prim_func
    def simple_tir_func(
        var_A: T.handle,
        var_B: T.handle,
    ):
        """简单的 TIR 函数,用于测试 Python 函数与 TIR 函数的集成。"""
        T.func_attr({"tir.noalias": True})
        n = T.int32()
        A = T.match_buffer(var_A, (n,), "float32")
        B = T.match_buffer(var_B, (n,), "float32")

        for i in T.grid(n):
            with T.block("copy"):
                vi = T.axis.remap("S", [i])
                B[vi] = A[vi]

测试 @I.pyfunc 装饰器是否正确创建 pyfuncs 属性#

module = TestPyFuncModule

# 验证模块包含 pyfuncs 属性
assert hasattr(module, "pyfuncs"), "模块应该包含 pyfuncs 属性"

pyfuncs = module.pyfuncs
assert isinstance(pyfuncs, dict), "pyfuncs 应该是一个字典"

# 验证所有期望的函数都在 pyfuncs 中
expected_functions = ["pytorch_processor", "pytorch_adder", "pytorch_complex_ops"]
for func_name in expected_functions:
    assert func_name in pyfuncs, f"函数 {func_name} 应该在 pyfuncs 中"

测试 pyfuncs 中的 Python 函数是否可调用#

module = TestPyFuncModule
pyfuncs = module.pyfuncs

# 测试 pytorch_processor 函数
processor_func = pyfuncs["pytorch_processor"]
assert callable(processor_func), "pytorch_processor 应该是可调用的"

# 测试 pytorch_adder 函数
adder_func = pyfuncs["pytorch_adder"]
assert callable(adder_func), "pytorch_adder 应该是可调用的"

# 测试 pytorch_complex_ops 函数
complex_func = pyfuncs["pytorch_complex_ops"]
assert callable(complex_func), "pytorch_complex_ops 应该是可调用的"

测试 Python 函数是否正确执行#

module = TestPyFuncModule
pyfuncs = module.pyfuncs

# 创建测试数据
x = torch.tensor([1.0, -2.0, 3.0, -4.0, 5.0], dtype=torch.float32)
y = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5], dtype=torch.float32)

# 测试 pytorch_processor 函数
processor_func = pyfuncs["pytorch_processor"]
processor_result = processor_func(x)

assert isinstance(processor_result, torch.Tensor)
expected = torch.nn.functional.relu(x) * 2.0
assert torch.allclose(processor_result, expected, atol=1e-5)

# 测试 pytorch_adder 函数
adder_func = pyfuncs["pytorch_adder"]
adder_result = adder_func(x, y)

assert isinstance(adder_result, torch.Tensor)
expected = x + y
assert torch.allclose(adder_result, expected, atol=1e-5)

# 测试 pytorch_complex_ops 函数
complex_func = pyfuncs["pytorch_complex_ops"]
complex_result = complex_func(x)

assert isinstance(complex_result, torch.Tensor)
# 注意:dropout 是非确定性的,所以我们只检查形状和类型
assert complex_result.shape == x.shape
assert complex_result.dtype == x.dtype

测试模块是否包含用于 IRModule 操作的 functions 属性#

module = TestPyFuncModule

# 检查 functions 属性是否存在
assert hasattr(module, "functions"), "模块应该包含 functions 属性"

functions = module.functions
# TVM IRModule.functions 不是标准字典,但具有类似字典的行为
assert hasattr(functions, "__getitem__"), "functions 应该支持类似字典的访问"
assert hasattr(functions, "__iter__"), "functions 应该是可迭代的"

测试模块是否有用于 TVMScript 输出的 script() 方法#

module = TestPyFuncModule

# 检查 script 方法是否存在
assert hasattr(module, "script"), "模块应该包含 script 方法"

# 测试 script 方法的执行
script_output = module.script()
assert isinstance(script_output, str), "script() 应该返回一个字符串"
assert len(script_output) > 0, "script() 应该返回非空字符串"

测试模块是否具有 BasePyModule 继承标志#

module = TestPyFuncModule

# 检查继承标志是否存在(这可能在所有实现中都不设置)
if hasattr(module, "_base_py_module_inherited"):
    assert module._base_py_module_inherited, "继承标志应该为 True"
else:
    # 替代方法:检查模块是否支持 Python 函数
    assert hasattr(module, "pyfuncs"), "模块应该支持 Python 函数"

# 检查原始类是否被保留(这可能在所有实现中都不设置)
if hasattr(module, "_original_class"):
    assert module._original_class is not None, "原始类应该被保留"
else:
    # 替代方法:检查模块是否可调用(ModuleFactory)
    assert hasattr(module, "__call__"), "模块应该是可调用的(ModuleFactory)"

测试模块的创建和执行功能#

module = TestPyFuncModule

assert hasattr(module, "__call__"), "模块应该是可调用的"

device = tvm.cpu(0)
instance = module(device)

assert isinstance(instance, BasePyModule), "实例应该是 BasePyModule 类型"
assert hasattr(instance, "pyfuncs"), "实例应该包含 pyfuncs 属性"

x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
result = instance.pytorch_processor(x)

assert isinstance(result, torch.Tensor)
expected = torch.nn.functional.relu(x) * 2.0
assert torch.allclose(result, expected, atol=1e-5)
Warning: Failed to compile Relax VM: 'NoneType' object has no attribute 'kind'

测试模块在 GPU 设备上的创建和执行功能#

module = TestPyFuncModule

if tvm.cuda().exist:
    device = tvm.cuda(0)
    instance = module(device)

    assert isinstance(instance, BasePyModule), "实例应该是 BasePyModule 类型"
    assert hasattr(instance, "pyfuncs"), "实例应该包含 pyfuncs 属性"

    x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32, device="cuda")
    result = instance.pytorch_processor(x)

    assert isinstance(result, torch.Tensor)
    assert result.device.type == "cuda"
    expected = torch.nn.functional.relu(x) * 2.0
    assert torch.allclose(result, expected, atol=1e-5)
else:
    pytest.skip("CUDA 不可用")
Warning: Failed to compile one or more TIR functions: Memory verification failed with the following errors:
    Variable `B` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `A` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
  Did you forget to bind?
# from tvm.script import tir as T

@T.prim_func
def simple_tir_func(var_A: T.handle, var_B: T.handle):
    T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mtriple": "x86_64-unknown-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": True})
    n = T.int32()
    A = T.match_buffer(var_A, (n,))
    B = T.match_buffer(var_B, (n,))
    for i in range(n):
        B_1 = T.Buffer((n,), data=B.data)
        A_1 = T.Buffer((n,), data=A.data)
        B_1[i] = A_1[i]
Warning: Failed to compile Relax VM: Memory verification failed with the following errors:
    Variable `B` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `A` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
  Did you forget to bind?
# from tvm.script import tir as T

@T.prim_func
def simple_tir_func(var_A: T.handle, var_B: T.handle):
    T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mtriple": "x86_64-unknown-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": True})
    n = T.int32()
    A = T.match_buffer(var_A, (n,))
    B = T.match_buffer(var_B, (n,))
    for i in range(n):
        B_1 = T.Buffer((n,), data=B.data)
        A_1 = T.Buffer((n,), data=A.data)
        B_1[i] = A_1[i]

测试 Python 函数如何与 TIR 函数协作#

module = TestPyFuncModule

# 创建实例
device = tvm.cpu(0)
instance = module(device)

# 测试 TIR 函数执行
n = 5
input_tensor = torch.randn(n, dtype=torch.float32)

# 调用 TIR 函数 - 它需要 3 个参数:输入、输出和大小
# 但 call_tir 会处理输出缓冲区的创建,所以我们只传递输入和大小
# 注意:TIR 函数期望 TVM 类型,而不是 Python 类型
result = instance.call_tir(
    instance.simple_tir_func,
    [input_tensor],  # 只传递输入张量,让 call_tir 处理其余部分
    R.Tensor((n,), "float32"),
)

# 验证结果
assert isinstance(result, torch.Tensor)
assert result.shape == (n,)
assert torch.allclose(result, input_tensor, atol=1e-5)
Warning: Failed to compile Relax VM: 'NoneType' object has no attribute 'kind'

测试 @I.pyfunc 装饰器是否保留函数签名#

module = TestPyFuncModule
pyfuncs = module.pyfuncs

# 检查函数签名
import inspect

# pytorch_processor 签名
processor_func = pyfuncs["pytorch_processor"]
sig = inspect.signature(processor_func)
params = list(sig.parameters.keys())
assert len(params) == 1, "pytorch_processor 应该有 1 个参数"
assert params[0] == "x", "第一个参数应该是 'x'"

# pytorch_adder 签名
adder_func = pyfuncs["pytorch_adder"]
sig = inspect.signature(adder_func)
params = list(sig.parameters.keys())
assert len(params) == 2, "pytorch_adder 应该有 2 个参数"
assert params[0] == "x", "第一个参数应该是 'x'"
assert params[1] == "y", "第二个参数应该是 'y'"