注册插件

注册插件#

参考:将 TVM 集成到您的项目中

tvm.register_extension() 用于将自定义类注册为 TVM(Tensor Virtual Machine)扩展类的函数。通过注册这个类可以作为 TVM 生成的函数的参数直接传递。以下是详细解读:

  • 目的:将自定义类注册为 TVM 的扩展类,使其能够作为 TVM 生成的函数的参数。

  • 核心要求:注册的类必须包含名为 _tvm_handle 的属性,用于返回表示句柄地址的整数值。

import tvm
from tvm import te
import numpy as np

@tvm.register_extension
class MyTensorView:
    _tvm_tcode = tvm._ffi.runtime_ctypes.ArgTypeCode.DLTENSOR_HANDLE

    def __init__(self, arr):
        self.arr = arr

    @property
    def _tvm_handle(self):
        return self.arr._tvm_handle

也可这样:

from dataclasses import dataclass
from typing import Sequence
import tvm
from tvm import te
import numpy as np

@tvm.register_extension
@dataclass
class MyTensorView:
    arr: Sequence
    _tvm_tcode: int = tvm._ffi.runtime_ctypes.ArgTypeCode.DLTENSOR_HANDLE

    @property
    def _tvm_handle(self):
        return self.arr._tvm_handle

DLTensor 兼容性#

DLTensor 兼容性 是指一个类或数据结构能够与 TVM 中的 DLTensor 类型无缝交互。DLTensor 是 TVM 中用于表示张量(Tensor)的核心数据结构,它与深度学习框架(如 PyTorch、TensorFlow)中的张量类似,但具有更高的灵活性和跨平台支持。

dtype = "int64"
n = te.var("n")
Ab = tvm.tir.decl_buffer((n,), dtype)
i = te.var("i")
ib = tvm.tir.ir_builder.create()
A = ib.buffer_ptr(Ab)
with ib.for_range(0, n - 1, "i") as i:
    A[i + 1] = A[i] + 1
stmt = ib.get()

mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "arange"))
f = tvm.build(mod, target="stackvm")
a = tvm.nd.array(np.zeros(10, dtype=dtype))
aview = MyTensorView(a)
f(aview)
np.testing.assert_equal(a.numpy(), np.arange(a.shape[0]))
aview, type(a)
(MyTensorView(arr=<tvm.nd.NDArray shape=(10,), cpu(0)>
 array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), _tvm_tcode=7),
 tvm.runtime.ndarray.NDArray)

此时 MyTensorView 接受 tvm.nd.NDArray 作为输入,返回 tvm.nd.NDArray 作为输出。但是:

b = MyTensorView(np.zeros(10, dtype=dtype))
b, type(b)
(MyTensorView(arr=array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), _tvm_tcode=7),
 __main__.MyTensorView)