# FFI 张量 (FFI tensors)

import pytest

try:
    import torch
except ImportError:
    torch = None

from tvm import ffi as tvm_ffi
import numpy as np

# 张量属性 (tensor attributes)

# Import a NumPy array into tvm_ffi through DLPack, check the resulting
# NDArray's metadata, then export it back to NumPy and compare contents.
data = np.zeros((10, 8, 4, 2), dtype="int16")
# Both directions of the DLPack exchange are exercised below: the producer
# protocol on the array (``__dlpack__``) and the consumer ``np.from_dlpack``.
# A bare ``raise`` here would have no active exception to re-raise and would
# only produce "No active exception to reraise"; fail with a clear message.
if not (hasattr(data, "__dlpack__") and hasattr(np, "from_dlpack")):
    raise RuntimeError(
        "this NumPy installation lacks DLPack support "
        "(ndarray.__dlpack__ / numpy.from_dlpack); upgrade NumPy to run "
        "the tensor attribute checks"
    )
x = tvm_ffi.from_dlpack(data)
assert isinstance(x, tvm_ffi.NDArray)
# Shape and dtype must carry over unchanged from the source array.
assert x.shape == (10, 8, 4, 2)
assert x.dtype == tvm_ffi.dtype("int16")
# A NumPy array lives in host memory, so the imported tensor reports CPU:0.
assert x.device.device_type == tvm_ffi.Device.kDLCPU
assert x.device.device_id == 0
# Round-trip back into NumPy and verify the element values are preserved.
x2 = np.from_dlpack(x)
np.testing.assert_equal(x2, data)

# 张量形状 (tensor shape)

shape = tvm_ffi.Shape((10, 8, 4, 2))
assert isinstance(shape, tvm_ffi.Shape)
assert shape == (10, 8, 4, 2)

fecho = tvm_ffi.convert(lambda x: x)
shape2 = fecho(shape)
assert shape2.__tvm_ffi_object__.same_as(shape.__tvm_ffi_object__)
assert isinstance(shape2, tvm_ffi.Shape)
assert isinstance(shape2, tuple)

shape3 = tvm_ffi.convert(shape)
assert shape3.__tvm_ffi_object__.same_as(shape.__tvm_ffi_object__)
assert isinstance(shape3, tvm_ffi.Shape)

# auto_dlpack (automatic DLPack conversion of torch tensors)

def check(x, y):
    """Assert that ``y`` is the tvm_ffi counterpart of the torch tensor ``x``.

    ``y`` must be a ``tvm_ffi.NDArray`` of shape ``(128,)`` and dtype int64
    on CPU device 0. Importing it back into torch via ``torch.from_dlpack``
    must give the same element values as ``x``.

    Args:
        x: Reference torch tensor. ``.numpy()`` is called on it, so it is
            presumably a CPU tensor (TODO confirm for any new callers).
        y: Value returned by the FFI call under test.

    Raises:
        AssertionError: if any metadata check or the value comparison fails.
    """
    assert isinstance(y, tvm_ffi.NDArray)
    assert y.shape == (128,)
    assert y.dtype == tvm_ffi.dtype("int64")

    dev = y.device
    assert dev.device_type == tvm_ffi.Device.kDLCPU
    assert dev.device_id == 0

    # Convert back to torch through DLPack, then compare element-wise in NumPy.
    roundtrip = torch.from_dlpack(y).numpy()
    np.testing.assert_equal(roundtrip, x.numpy())

if torch is None:
    # torch is optional: its import at the top of the file falls back to
    # ``None``. Without this guard, ``torch.arange`` below would crash with
    # an AttributeError on ``None`` instead of skipping the torch demo.
    print("torch is not installed; skipping the auto_dlpack example")
else:
    # A torch tensor passed straight to an FFI global function should come
    # back as a tvm_ffi.NDArray (check() asserts this) without any explicit
    # from_dlpack call. "testing.echo" presumably returns its argument.
    x = torch.arange(128)
    fecho = tvm_ffi.get_global_func("testing.echo")
    y = fecho(x)
    check(x, y)

    # pass in list of tensors: tensors nested in a list argument should be
    # converted as well.
    y = fecho([x])
    check(x, y[0])