# FFI 函数

import gc
import ctypes
import numpy as np
from tvm import ffi as tvm_ffi

## echo

# Fetch the globally registered "testing.echo" function. The checks that
# follow expect it to return its single argument unchanged.
fecho = tvm_ffi.get_global_func("testing.echo")
assert isinstance(fecho, tvm_ffi.Function)

# First of the per-type round-trip checks: None must come back as None.
none_result = fecho(None)
assert none_result is None

# Booleans: the echoed value must be a real bool and the very same singleton
# (True or False) that was passed in.
for flag in (True, False):
    bool_result = fecho(flag)
    assert isinstance(bool_result, bool)
    assert bool_result is flag

# Numeric scalars: an int and a float must each compare equal to the input.
for number in (1, 1.2):
    assert fecho(number) == number

# Strings: the echoed value must be an instance of both the builtin str and
# tvm_ffi.String, and must compare equal to the input text.
str_result = fecho("hello")
for expected_type in (str, tvm_ffi.String):
    assert isinstance(str_result, expected_type)
assert str_result == "hello"

# Bytes: the echoed value must be an instance of both the builtin bytes and
# tvm_ffi.Bytes, and must compare equal to the input payload.
bytes_result = fecho(b"abc")
for expected_type in (bytes, tvm_ffi.Bytes):
    assert isinstance(bytes_result, expected_type)
assert bytes_result == b"abc"

# Data types: echo a dtype and compare the result against a freshly built
# dtype of the same name.
dtype_name = "float32"
dtype_result = fecho(tvm_ffi.dtype(dtype_name))
assert isinstance(dtype_result, tvm_ffi.dtype)
assert dtype_result == tvm_ffi.dtype(dtype_name)

# Devices: echo a device built from its textual spec, then check its type,
# its index, and both of its string forms.
device_spec = "cuda:1"
device_result = fecho(tvm_ffi.device(device_spec))
assert isinstance(device_result, tvm_ffi.Device)
assert device_result.device_type == tvm_ffi.Device.kDLCUDA
assert device_result.device_id == 1
# str() must reproduce the spec the device was created from ...
assert str(device_result) == device_spec
# ... while repr() uses the longer keyword-style form.
assert repr(device_result) == "device(type='cuda', index=1)"

# Opaque pointers: a ctypes.c_void_p must come back as a c_void_p that holds
# the same address value.
address = 0x12345678
c_void_p_result = fecho(ctypes.c_void_p(address))
assert isinstance(c_void_p_result, ctypes.c_void_p)
assert c_void_p_result.value == address

# Functions are passed as objects: echo a converted Python lambda and check
# that the result is reported by `same_as` as the original and is still
# callable with the expected result.
fadd = tvm_ffi.convert(lambda a, b: a + b)
fadd1 = fecho(fadd)
assert fadd1.same_as(fadd)
assert fadd1(1, 2) == 3

def check_ndarray():
    """Round-trip a 10-element int32 array through ``fecho`` as an NDArray.

    The NumPy array is imported with ``tvm_ffi.from_dlpack``. When the NumPy
    array object has no ``__dlpack__`` attribute the check is skipped and
    nothing is asserted.
    """
    source = np.arange(10, dtype="int32")
    if hasattr(source, "__dlpack__"):
        tensor = tvm_ffi.from_dlpack(source)
        assert isinstance(tensor, tvm_ffi.NDArray)

        # The echoed array must keep its type, shape, dtype and device.
        nd_result = fecho(tensor)
        assert isinstance(nd_result, tvm_ffi.NDArray)
        assert nd_result.shape == (10,)
        assert nd_result.dtype == tvm_ffi.dtype("int32")
        assert nd_result.device.device_type == tvm_ffi.Device.kDLCPU
        assert nd_result.device.device_id == 0

check_ndarray()

## 返回原始字符串与字节

# Each case pairs a zero-argument Python callable with the value its
# FFI-converted form is expected to return. Note that a bytearray result is
# expected to compare equal to the corresponding bytes literal.
for producer, expected in (
    (lambda: "hello", "hello"),
    (lambda: b"hello", b"hello"),
    (lambda: bytearray(b"hello"), b"hello"),
):
    assert tvm_ffi.convert(producer)() == expected

## Python 函数转换

def add(a, b):
    """Return ``a + b`` for any two operands that support ``+``."""
    total = a + b
    return total

# Wrap the plain Python function as an FFI function object and confirm that
# calling the wrapper still produces the sum.
fadd = tvm_ffi.convert(add)
assert isinstance(fadd, tvm_ffi.Function)
sum_via_ffi = fadd(1, 2)
assert sum_via_ffi == 3

def fapply(f, *args):
    """Call ``f`` with the given positional arguments and return its result."""
    result = f(*args)
    return result

# Rebind the name to the FFI-wrapped version, then pass the plain Python
# function `add` through it as an argument together with its operands.
fapply = tvm_ffi.convert(fapply)
applied_sum = fapply(add, 1, 3.3)
assert applied_sum == 4.3

## 注册全局函数

@tvm_ffi.register_func("mytest.echo")
def echo(x):
    """Return the argument unchanged; registered globally as "mytest.echo"."""
    return x

# Look the function up again by its registered name: the result must be
# reported by `same_as` as the decorated `echo` object and must be callable.
registered_name = "mytest.echo"
f = tvm_ffi.get_global_func(registered_name)
assert f.same_as(echo)
assert f(1) == 1

# While registered, the name is listed among the global function names.
names_before_removal = tvm_ffi.registry.list_global_func_names()
assert registered_name in names_before_removal

# After removal the name is no longer listed, and a lookup with
# allow_missing=True returns None.
tvm_ffi.registry.remove_global_func(registered_name)
names_after_removal = tvm_ffi.registry.list_global_func_names()
assert registered_name not in names_after_removal
assert tvm_ffi.get_global_func(registered_name, allow_missing=True) is None

## 右值引用

# Global helper used by the checks in this section to read an object's
# use count.
use_count = tvm_ffi.get_global_func("testing.object_use_count")

def callback(x, expected_count):
    """Assert the use count of ``x`` and hand it back via ``x._move()``.

    Parameters
    ----------
    x
        The FFI object received from the caller.
    expected_count
        The use count ``x`` is expected to have once garbage has been
        collected.

    Returns
    -------
    The result of ``x._move()``.
    """
    # TVM FFI use counts drop inside `ObjectRef.__del__`, i.e. only when the
    # Python wrapper is actually destroyed. The timing of that destruction
    # is not deterministic (even CPython's reference counting is an
    # implementation detail), so force a collection before reading the
    # count; otherwise the assertion below could see a stale value.
    gc.collect()
    observed_count = use_count(x)
    assert expected_count == observed_count
    return x._move()

# FFI-wrapped form of the callback, invoked by the checks in this section.
f = tvm_ffi.convert(callback)

def check0():
    """Check use counts when an object is passed normally and then moved.

    A freshly converted list must start with a use count of 1. Passing it to
    ``f`` normally, the callback asserts a count of 2. Passing ``x._move()``,
    the callback asserts a count of 1, and afterwards ``x`` must no longer
    hold a handle.
    """
    x = tvm_ffi.convert([1, 2])
    assert use_count(x) == 1
    # Normal pass: the returned object is discarded right away.
    f(x, 2)
    # Moved pass: the result stays bound so the returned object outlives the
    # assertion below.
    y = f(x._move(), 1)
    # Identity comparison is the correct way to test for the None singleton.
    assert x.__ctypes_handle__().value is None

def check1():
    """Check use counts when the first call's result is kept alive.

    A freshly converted list must start with a use count of 1. The result of
    the first (normal) call is kept in ``y``, and the callback asserts a
    count of 2 for both the normal call and the later ``x._move()`` call.
    Afterwards ``x`` must no longer hold a handle, while ``y`` still must.
    """
    x = tvm_ffi.convert([1, 2])
    assert use_count(x) == 1
    y = f(x, 2)
    # `y` is still bound here, and the callback is told to expect 2 again.
    z = f(x._move(), 2)
    # Identity comparison is the correct way to test for the None singleton.
    assert x.__ctypes_handle__().value is None
    assert y.__ctypes_handle__().value is not None

check0()
check1()