# FFI 张量 (FFI tensors)
import pytest
# PyTorch is optional. Pre-bind ``torch`` to None and only replace it when the
# import succeeds, so later code (the auto_dlpack example) can test
# ``torch is None``.
torch = None
try:
    import torch
except ImportError:
    pass
from tvm import ffi as tvm_ffi
import numpy as np
## 张量属性 (tensor attributes)
# Wrap a NumPy array as an FFI NDArray through DLPack, check the metadata the
# FFI side reports, then export it back to NumPy and compare the contents.
data = np.zeros((10, 8, 4, 2), dtype="int16")
if not hasattr(data, "__dlpack__"):
    # Nothing below works without DLPack support on ndarray. There is no
    # active exception at this point, so a bare ``raise`` would only produce
    # "No active exception to reraise"; raise an explicit error instead.
    raise RuntimeError(
        "numpy.ndarray does not implement __dlpack__; "
        "a NumPy version with DLPack support is required"
    )
x = tvm_ffi.from_dlpack(data)
assert isinstance(x, tvm_ffi.NDArray)
assert x.shape == (10, 8, 4, 2)
assert x.dtype == tvm_ffi.dtype("int16")
# np.zeros allocates on the host, so the FFI view should report CPU device 0.
assert x.device.device_type == tvm_ffi.Device.kDLCPU
assert x.device.device_id == 0
# Round-trip back to NumPy via DLPack; the values must be unchanged.
x2 = np.from_dlpack(x)
np.testing.assert_equal(x2, data)
## 张量形状 (tensor shapes)
# An FFI Shape built from a Python tuple must compare equal to that tuple.
shape = tvm_ffi.Shape((10, 8, 4, 2))
assert isinstance(shape, tvm_ffi.Shape) and shape == (10, 8, 4, 2)

# Send the Shape through an FFI-converted Python identity function: the test
# expects the very same underlying FFI object back, still exposed as a Shape
# that is also a tuple.
fecho = tvm_ffi.convert(lambda x: x)
shape2 = fecho(shape)
assert shape2.__tvm_ffi_object__.same_as(shape.__tvm_ffi_object__)
assert all(isinstance(shape2, kind) for kind in (tvm_ffi.Shape, tuple))

# Converting an existing Shape is expected to keep the same underlying object.
shape3 = tvm_ffi.convert(shape)
assert shape3.__tvm_ffi_object__.same_as(shape.__tvm_ffi_object__)
assert isinstance(shape3, tvm_ffi.Shape)
## auto_dlpack (automatic DLPack conversion of torch tensors)
def check(x, y):
    """Assert that ``y`` is the FFI-side echo of the torch tensor ``x``.

    ``y`` must be a ``tvm_ffi.NDArray`` of shape ``(128,)`` and dtype int64,
    located on CPU device 0. Importing it back into torch through DLPack
    must reproduce the values of ``x``.
    """
    assert isinstance(y, tvm_ffi.NDArray)
    assert y.shape == (128,)
    assert y.dtype == tvm_ffi.dtype("int64")
    dev = y.device
    assert dev.device_type == tvm_ffi.Device.kDLCPU
    assert dev.device_id == 0
    # Round-trip into torch via DLPack and compare element-wise with the input.
    np.testing.assert_equal(torch.from_dlpack(y).numpy(), x.numpy())
if torch is None:
    # The optional import at the top binds ``torch`` to None when PyTorch is
    # missing; calling torch.arange would then fail with AttributeError, so
    # skip this example explicitly instead.
    print("PyTorch is not installed; skipping the auto_dlpack example.")
else:
    # Pass a torch tensor straight to the global "testing.echo" FFI function
    # (presumably returns its argument unchanged). ``check`` asserts the
    # result comes back as an FFI NDArray with the same values.
    x = torch.arange(128)
    fecho = tvm_ffi.get_global_func("testing.echo")
    y = fecho(x)
    check(x, y)

    # pass in list of tensors: a tensor nested in a list must be converted too
    y = fecho([x])
    check(x, y[0])