Relax VM 中带有内存作用域(scope)的朴素内存分配器#
TVM Relax 虚拟机(VM)中带有明确内存作用域指定的存储分配功能。
内存作用域是 TVM 中用于标识不同内存区域的一种机制,可以用于区分全局内存、共享内存、纹理内存等不同类型的存储空间。
import numpy as np
import tvm
import tvm.testing
from tvm import relax
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
@I.ir_module
class Module:
"""
IR 模块定义,包含两个函数:
1. 一个 TIR 原始函数 add,执行矩阵加法操作
2. 一个 Relax 主函数 main,展示存储分配和张量操作
"""
@T.prim_func
def add(
arg0: T.Buffer((2, 2), "float32"),
arg1: T.Buffer((2, 2), "float32"),
output: T.Buffer((2, 2), "float32"),
):
"""
TIR 原始函数,实现二维矩阵加法操作
参数:
arg0: 第一个输入张量(2x2 float32)
arg1: 第二个输入张量(2x2 float32)
output: 输出张量(2x2 float32)
"""
T.func_attr({"operator_name": "relax.add"})
# 双重循环遍历矩阵的每个元素
for ax0 in range(2):
for ax1 in range(2):
with T.block("T_add"):
# 定义空间轴索引
v_ax0 = T.axis.spatial(2, ax0)
v_ax1 = T.axis.spatial(2, ax1)
# 声明读取和写入的内存区域
T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
T.writes(output[v_ax0, v_ax1])
# 执行加法运算
output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]
@R.function(pure=False)
def main(x: R.Tensor((2, 2), dtype="float32")):
"""
Relax 主函数,展示如何使用带有内存作用域的存储分配API
参数:
x: 输入张量(2x2 float32)
返回值:
处理后的张量(2x2 float32)
"""
cls = Module
# 分配存储,指定大小、设备索引、数据类型和内存作用域
# storage_scope="global"表示使用全局内存
storage = R.vm.alloc_storage(
R.shape([2 * 2]), runtime_device_index=0, dtype="float32", storage_scope="global"
)
# 从已分配的存储中创建张量视图,指定偏移量、形状和数据类型
alloc = R.vm.alloc_tensor(storage, offset=0, shape=R.shape([2, 2]), dtype="float32")
# 调用TIR函数执行计算,将结果存入分配的张量中
_: R.Tuple = cls.add(x, x, alloc)
# 将分配的张量设置为输出
out: R.Tensor((2, 2), dtype="float32") = alloc
return out
def test_alloc_storage_with_scope_global():
"""
测试带有'global'内存作用域的存储分配功能
此测试验证了在 Relax VM 中使用带有明确内存作用域指定的存储分配API,
并确保计算结果的正确性。特别测试了使用'naive'内存配置时的行为。
"""
# 生成随机测试数据
arg0 = np.random.uniform(size=(2, 2)).astype(np.float32)
# 计算预期结果(输入矩阵与自身相加)
output_ref = arg0 + arg0
# 使用前面定义的模块
mod = Module
# 设置目标为LLVM(在CPU上运行)
target = "llvm"
# 使用优化级别3构建模块
with tvm.transform.PassContext(opt_level=3):
lib = tvm.relax.build(mod, target=target, exec_mode="compiled")
# 获取CPU设备
dev = tvm.cpu()
# 关键测试点:使用'naive'内存配置创建虚拟机运行时
vm_rt = relax.VirtualMachine(lib, dev, memory_cfg="naive")
# 将NumPy数组转换为TVM NDArray并设置为输入
x = tvm.nd.array(arg0, dev)
vm_rt.set_input("main", x)
# 调用有状态函数执行计算
vm_rt.invoke_stateful("main")
# 获取输出并转换为NumPy数组
output = vm_rt.get_outputs("main").numpy()
# 验证计算结果是否符合预期
tvm.testing.assert_allclose(output_ref, output)