register_func(): Registering Global Functions

import tvm._ffi

tvm._ffi.register_func??
Signature: tvm._ffi.register_func(func_name, f=None, override=False)
Source:   
def register_func(func_name, f=None, override=False):
    """Register global function

    Parameters
    ----------
    func_name : str or function
        The function name

    f : function, optional
        The function to be registered.

    override: boolean optional
        Whether override existing entry.

    Returns
    -------
    fregister : function
        Register function if f is not specified.

    Examples
    --------
    The following code registers my_packed_func as global function.
    Note that we simply get it back from global function table to invoke
    it from python side. However, we can also invoke the same function
    from C++ backend, or in the compiled TVM code.

    .. code-block:: python

      targs = (10, 10.0, "hello")
      @tvm.register_func
      def my_packed_func(*args):
          assert(tuple(args) == targs)
          return 10
      # Get it out from global function table
      f = tvm.get_global_func("my_packed_func")
      assert isinstance(f, tvm.PackedFunc)
      y = f(*targs)
      assert y == 10
    """
    if callable(func_name):
        f = func_name
        func_name = f.__name__

    if not isinstance(func_name, str):
        raise ValueError("expect string function name")

    ioverride = ctypes.c_int(override)

    def register(myf):
        """internal register function"""
        if not isinstance(myf, PackedFuncBase):
            myf = convert_to_tvm_func(myf)
        check_call(_LIB.TVMFuncRegisterGlobal(c_str(func_name), myf.handle, ioverride))
        return myf

    if f:
        return register(f)
    return register
File:      /media/pc/data/lxw/ai/tvm/python/tvm/_ffi/registry.py
Type:      function
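The control flow in the source above does three things. It accepts a bare function, using its `__name__` as the registered name. It rejects non-string names. It then either registers `f` immediately or returns a decorator. The sketch below is a hypothetical pure-Python model of that logic, not TVM code. `toy_register_func`, `toy_get_global_func` and the dict `_GLOBAL_TABLE` are illustrative stand-ins for `register_func`, `get_global_func` and the C++ global table. The duplicate-name error is an assumption about how the C++ registry treats `override=False`.

```python
from typing import Callable, Dict, Optional, Union

# Hypothetical stand-in for TVM's C++ global function table.
_GLOBAL_TABLE: Dict[str, Callable] = {}


def toy_register_func(
    func_name: Union[str, Callable],
    f: Optional[Callable] = None,
    override: bool = False,
):
    """Toy model of register_func's argument handling (not TVM code)."""
    # Bare-decorator form: @toy_register_func passes the function itself,
    # so its __name__ becomes the registered name.
    if callable(func_name):
        f = func_name
        func_name = f.__name__

    if not isinstance(func_name, str):
        raise ValueError("expect string function name")

    def register(myf: Callable) -> Callable:
        # Assumption: mimic the registry refusing duplicates unless override=True.
        if func_name in _GLOBAL_TABLE and not override:
            raise RuntimeError(f"Global function {func_name} is already registered")
        _GLOBAL_TABLE[func_name] = myf
        return myf

    # Direct-call form registers now; otherwise hand back a decorator.
    if f:
        return register(f)
    return register


def toy_get_global_func(name: str) -> Callable:
    """Look a registered function up by name (toy version)."""
    return _GLOBAL_TABLE[name]


# Form 1: bare decorator, registered as "add_one".
@toy_register_func
def add_one(x):
    return x + 1


# Form 2: decorator factory with an explicit name.
@toy_register_func("demo.double")
def double(x):
    return 2 * x


# Form 3: direct call with the function as the second argument.
toy_register_func("demo.square", lambda x: x * x)
```

In this model, the `isinstance` check runs before any registration. A bad name therefore fails at decoration time rather than at lookup time, matching the order of checks in the real source.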

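register_func() registers a function in TVM's global function table. Once registered, the function can be looked up by name (for example with tvm.get_global_func()) and invoked from Python, from the C++ backend, or from compiled TVM code. It can be used in three ways:

- as a bare decorator, in which case the function's __name__ becomes the registered name;
- as a decorator factory with an explicit name string;
- as a direct call, register_func(name, f).

Pass override=True to replace an entry that already exists.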

Example:

import tvm
import tvm._ffi
from tvm.runtime.packed_func import PackedFunc

targs = (10, 10.0, "hello")

# Register my_packed_func in the global function table under its own name
@tvm._ffi.register_func
def my_packed_func(*args):
    assert tuple(args) == targs
    return 10

# Get it back out of the global function table
f = tvm.get_global_func("my_packed_func")
assert isinstance(f, PackedFunc)
y = f(*targs)
assert y == 10