注册插件#
tvm.register_extension()
用于将自定义类注册为 TVM(Tensor Virtual Machine)扩展类的函数。通过注册这个类可以作为 TVM 生成的函数的参数直接传递。以下是详细解读:
目的:将自定义类注册为 TVM 的扩展类,使其能够作为 TVM 生成的函数的参数。
核心要求:注册的类必须包含名为
_tvm_handle
的属性,用于返回表示句柄地址的整数值。
import tvm
from tvm import te
import numpy as np
@tvm.register_extension
class MyTensorView:
_tvm_tcode = tvm._ffi.runtime_ctypes.ArgTypeCode.DLTENSOR_HANDLE
def __init__(self, arr):
self.arr = arr
@property
def _tvm_handle(self):
return self.arr._tvm_handle
也可这样:
from dataclasses import dataclass
from typing import Sequence
import tvm
from tvm import te
import numpy as np
@tvm.register_extension
@dataclass
class MyTensorView:
arr: Sequence
_tvm_tcode: int = tvm._ffi.runtime_ctypes.ArgTypeCode.DLTENSOR_HANDLE
@property
def _tvm_handle(self):
return self.arr._tvm_handle
DLTensor 兼容性#
DLTensor 兼容性 是指一个类或数据结构能够与 TVM 中的 DLTensor 类型无缝交互。DLTensor 是 TVM 中用于表示张量(Tensor)的核心数据结构,它与深度学习框架(如 PyTorch、TensorFlow)中的张量类似,但具有更高的灵活性和跨平台支持。
dtype = "int64"
n = te.var("n")
Ab = tvm.tir.decl_buffer((n,), dtype)
i = te.var("i")
ib = tvm.tir.ir_builder.create()
A = ib.buffer_ptr(Ab)
with ib.for_range(0, n - 1, "i") as i:
A[i + 1] = A[i] + 1
stmt = ib.get()
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "arange"))
f = tvm.build(mod, target="stackvm")
a = tvm.nd.array(np.zeros(10, dtype=dtype))
aview = MyTensorView(a)
f(aview)
np.testing.assert_equal(a.numpy(), np.arange(a.shape[0]))
aview, type(a)
(MyTensorView(arr=<tvm.nd.NDArray shape=(10,), cpu(0)>
array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), _tvm_tcode=7),
tvm.runtime.ndarray.NDArray)
此时 MyTensorView
接受 tvm.nd.NDArray
作为输入,返回 tvm.nd.NDArray
作为输出。但是:
b = MyTensorView(np.zeros(10, dtype=dtype))
b, type(b)
(MyTensorView(arr=array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), _tvm_tcode=7),
__main__.MyTensorView)