TVM 拓展:Python 调用 C++#

下面逐步揭开 TVM 中 C++/C 与 Python 交互的机制。

在 C++ 中定义加法算子:

#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/op.h>

using namespace tvm::runtime;

namespace tvm_ext {
// Registers the global packed function "tvm_ext.sym_add".
// It reads its first two arguments as tir::Var and returns the expression lhs + rhs.
TVM_REGISTER_GLOBAL("tvm_ext.sym_add").set_body([](TVMArgs args, TVMRetValue* rv) {
  tvm::tir::Var lhs = args[0];
  tvm::tir::Var rhs = args[1];
  auto sum = lhs + rhs;
  *rv = sum;
});
} // namespace tvm_ext
# Sample Makefile for building a TVM extension module
# =================================================================================
# Directory holding this project's own headers.
IDIR = include
# Absolute path of the TVM source tree (machine-specific; adjust for your checkout).
TVM_ROOT=$(shell cd /media/pc/data/lxw/ai/tvm; pwd)
# Compiler flags: C++17, optimized, position-independent code, plus the include
# directories of TVM, dmlc-core, dlpack and this project.
PKG_CXXFLAGS = -std=c++17 -O2 -fPIC\
	-I${TVM_ROOT}/include\
	-I${TVM_ROOT}/3rdparty/dmlc-core/include\
	-I${TVM_ROOT}/3rdparty/dlpack/include\
	-I${IDIR}\
	-DDMLC_USE_LOGGING_LIBRARY=\<tvm/runtime/logging.h\>

# Linker flags; on macOS (Darwin) undefined symbols are deferred to load time.
PKG_LDFLAGS = -ldl -pthread
UNAME_S := $(shell uname -s)
ifeq ($(UNAME_S), Darwin)
	PKG_LDFLAGS += -undefined dynamic_lookup
endif

.PHONY: clean all

all: outputs/libs/libtvm_ext.so

# Build the custom TVM extension runtime library
# =================================================================================
# NOTE: -L only adds a library search path and no -l option names libtvm here, so
# TVM symbols stay undefined in the shared object. A libtvm must already be loaded
# in the process (for example via "import tvm") before this library is opened.
outputs/libs/libtvm_ext.so: src/tvm_ext.cc
	@mkdir -p $(@D)
	$(CXX) $(PKG_CXXFLAGS) -shared -o $@ $^ $(PKG_LDFLAGS) -L${TVM_ROOT}/build

# Remove every build artifact under outputs/.
clean:
	rm -rf outputs/*

编译:

%%bash
cd cpp/sym_add
make clean
make
rm -rf outputs/*
g++ -std=c++17 -O2 -fPIC -I/media/pc/data/lxw/ai/tvm/include -I/media/pc/data/lxw/ai/tvm/3rdparty/dmlc-core/include -I/media/pc/data/lxw/ai/tvm/3rdparty/dlpack/include -Iinclude -DDMLC_USE_LOGGING_LIBRARY=\<tvm/runtime/logging.h\> -shared -o outputs/libs/libtvm_ext.so src/tvm_ext.cc -ldl -pthread -L/media/pc/data/lxw/ai/tvm/build

Python 端加载 C++ 端动态库#

可以使用 ctypes 加载动态库:

import ctypes

# 作为全局加载,使全局外部符号对其他 dll 可见。
_LIB = ctypes.CDLL("cpp/sym_add/outputs/libs/libtvm_ext.so", ctypes.RTLD_GLOBAL)
---------------------------------------------------------------------------
OSError                                   Traceback (most recent call last)
Cell In[2], line 4
      1 import ctypes
      3 # 作为全局加载,使全局外部符号对其他 dll 可见。
----> 4 _LIB = ctypes.CDLL("cpp/sym_add/outputs/libs/libtvm_ext.so", ctypes.RTLD_GLOBAL)

File /media/pc/data/lxw/envs/anaconda3a/envs/ai/lib/python3.12/ctypes/__init__.py:379, in CDLL.__init__(self, name, mode, handle, use_errno, use_last_error, winmode)
    376 self._FuncPtr = _FuncPtr
    378 if handle is None:
--> 379     self._handle = _dlopen(self._name, mode)
    380 else:
    381     self._handle = handle

OSError: cpp/sym_add/outputs/libs/libtvm_ext.so: undefined symbol: _ZNK3tvm7runtime6Object11DerivedFromEj

加载失败的原因是:libtvm_ext.so 是在 libtvm.so 的基础上拓展的,编译时并未链接 libtvm.so,其中用到的 TVM 符号(例如上面报错的 tvm::runtime::Object::DerivedFrom)需要在运行时由已加载的 libtvm.so 提供。因此需要先加载 libtvm.so,或者直接 import tvm:

import set_env
import tvm
import ctypes

# 作为全局加载,使全局外部符号对其他 dll 可见。
_LIB = ctypes.CDLL("cpp/sym_add/outputs/libs/libtvm_ext.so", ctypes.RTLD_GLOBAL)

也可以直接使用 load_lib() 加载动态库:

import set_env
from tvm_book.tvm_ext.libinfo import load_lib

_LIB_EXT, _LIB_EXT_NAME = load_lib(name="libtvm_ext.so", search_path=["cpp/sym_add/outputs/libs"])

获取 C++ 端注册的全局函数:

import tvm
sym_add = tvm.get_global_func("tvm_ext.sym_add")

测试:

from tvm import te
x = te.var("x")
y = te.var("y")
z = sym_add(x, y)
assert z.a == x and z.b == y
print(z)
x + y

这些调用细节可以借助 FFI 机制进行隐藏。

使用 tvm._ffi._init_api() 管理 TVM 插件#

import set_env
from tvm_book.tvm_ext.libinfo import load_lib

_LIB_EXT, _LIB_EXT_NAME = load_lib(name="libtvm_ext.so", search_path=["cpp/sym_add/outputs/libs"])
import tvm

tvm._ffi._init_api("tvm_ext", __name__)

下面便可以直接使用 tvm_ext 下的函数了:

sym_add
<tvm.runtime.packed_func.PackedFunc at 0x7fee24b68ad0>

其他 C++ 打包函数的例子#

加法偏函数#

#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

using namespace tvm::runtime;

namespace tvm_ext {
// Registers the global packed function "tvm_ext.bind_add".
// Given a PackedFunc and an int, it returns a new PackedFunc that calls the
// original with the bound int as first argument and the caller's first argument
// as second argument.
TVM_REGISTER_GLOBAL("tvm_ext.bind_add").set_body([](TVMArgs outer_args, TVMRetValue* outer_rv) {
  PackedFunc callee = outer_args[0];
  int bound = outer_args[1];
  // Both captures are by value, so the returned function owns its own copies.
  auto partial = [callee, bound](TVMArgs args, TVMRetValue* rv) { *rv = callee(bound, args[0]); };
  *outer_rv = PackedFunc(partial);
});

} // namespace tvm_ext

编译:

%%bash
cd cpp/bind_add
make clean
make
rm -rf outputs/*
g++ -std=c++17 -O2 -fPIC -I/media/pc/data/lxw/ai/tvm/include -I/media/pc/data/lxw/ai/tvm/3rdparty/dmlc-core/include -I/media/pc/data/lxw/ai/tvm/3rdparty/dlpack/include -Iinclude -DDMLC_USE_LOGGING_LIBRARY=\<tvm/runtime/logging.h\> -shared -o outputs/libs/libtvm_ext.so src/tvm_ext.cc -ldl -pthread -L/media/pc/data/lxw/ai/tvm/build
import set_env
from tvm_book.tvm_ext.libinfo import load_lib
import tvm
_LIB_EXT, _LIB_EXT_NAME = load_lib(name="libtvm_ext.so", search_path=["cpp/bind_add/outputs/libs"])
tvm._ffi._init_api("tvm_ext", __name__)
bind_add
<tvm.runtime.packed_func.PackedFunc at 0x7fee24b63ed0>
def add(a, b):
    """Return ``a + b``; this is the Python callable handed to ``bind_add`` below."""
    total = a + b
    return total

f = bind_add(add, 7)
assert f(2) == 9

C++ 外部设备的例子#

#include <tvm/runtime/device_api.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
// #include <tvm/tir/op.h>

using namespace tvm::runtime;

namespace tvm_ext {
// Registers the global packed function "device_api.ext_dev".
// It forwards to the registered "device_api.cpu" function and returns its result.
TVM_REGISTER_GLOBAL("device_api.ext_dev").set_body([](TVMArgs args, TVMRetValue* rv) {
  const PackedFunc* cpu_api = tvm::runtime::Registry::Get("device_api.cpu");
  *rv = (*cpu_api)();
});

} // namespace tvm_ext

编译:

%%bash
cd cpp/device_api
make clean
make
rm -rf outputs/*
g++ -std=c++17 -O2 -fPIC -I/media/pc/data/lxw/ai/tvm/include -I/media/pc/data/lxw/ai/tvm/3rdparty/dmlc-core/include -I/media/pc/data/lxw/ai/tvm/3rdparty/dlpack/include -Iinclude -DDMLC_USE_LOGGING_LIBRARY=\<tvm/runtime/logging.h\> -shared -o outputs/libs/libtvm_ext.so src/tvm_ext.cc -ldl -pthread -L/media/pc/data/lxw/ai/tvm/build
import set_env
import tvm
from tvm_book.tvm_ext.libinfo import load_lib

_LIB_EXT, _LIB_EXT_NAME = load_lib(name="libtvm_ext.so", search_path=["cpp/device_api/outputs/libs"])
tvm._ffi._init_api("tvm_ext", __name__)
import numpy as np
from tvm import te
n = 10
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda *i: A(*i) + 1.0, name="B")
s = te.create_schedule(B.op)

def check_llvm():
    """Build ``B = A + 1`` for the ``ext_dev`` target, run it, and verify the result."""
    target = tvm.target.Target("ext_dev", "llvm")
    func = tvm.build(s, [A, B], target)
    device = tvm.ext_dev(0)
    # Place a random input and a zeroed output on the extension device.
    a_np = np.random.uniform(size=n).astype(A.dtype)
    a = tvm.nd.array(a_np, device)
    b_np = np.zeros(n, dtype=B.dtype)
    b = tvm.nd.array(b_np, device)
    # Launch the kernel and compare against the NumPy reference.
    func(a, b)
    np.testing.assert_allclose(b.numpy(), a.numpy() + 1)

check_llvm()

调用 C++ 端外部函数#

#include <tvm/runtime/registry.h>

// External function exposed to the runtime with C linkage: returns y plus one.
extern "C" float TVMTestAddOne(float y) {
  const float increment = 1;
  return y + increment;
}

编译:

%%bash
cd cpp/extern_func
make clean
make
rm -rf outputs/*
g++ -std=c++17 -O2 -fPIC -I/media/pc/data/lxw/ai/tvm/include -I/media/pc/data/lxw/ai/tvm/3rdparty/dmlc-core/include -I/media/pc/data/lxw/ai/tvm/3rdparty/dlpack/include -Iinclude -DDMLC_USE_LOGGING_LIBRARY=\<tvm/runtime/logging.h\> -shared -o outputs/libs/libtvm_ext.so src/tvm_ext.cc -ldl -pthread -L/media/pc/data/lxw/ai/tvm/build
import set_env
import tvm
from tvm_book.tvm_ext.libinfo import load_lib

_LIB_EXT, _LIB_EXT_NAME = load_lib(name="libtvm_ext.so", search_path=["cpp/extern_func/outputs/libs"])
tvm._ffi._init_api("tvm_ext", __name__)
import numpy as np
from tvm import te
n = 10
A = te.placeholder((n,), name="A")
B = te.compute(
    (n,), lambda *i: tvm.tir.call_extern("float32", "TVMTestAddOne", A(*i)), name="B"
)
s = te.create_schedule(B.op)

def check_llvm():
    """Build the ``TVMTestAddOne`` extern-call kernel with LLVM, run it on CPU, and verify it."""
    func = tvm.build(s, [A, B], "llvm")
    device = tvm.cpu(0)
    # Place a random input and a zeroed output on the CPU device.
    a_np = np.random.uniform(size=n).astype(A.dtype)
    a = tvm.nd.array(a_np, device)
    b_np = np.zeros(n, dtype=B.dtype)
    b = tvm.nd.array(b_np, device)
    # Launch the kernel and compare against the NumPy reference.
    func(a, b)
    np.testing.assert_allclose(b.numpy(), a.numpy() + 1)

check_llvm()

提取外部 C++ 函数#

#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/registry.h>

using namespace tvm::runtime;
// Entry point that lets TVM extract the functions this extension provides.
// This approach is helpful when only a header-only, minimal version of the TVM
// runtime is wanted.
//
// pregister: function handle, invoked as fregister(name, func) for each exported function.
// Returns 0 after registering "mul".
extern "C" int TVMExtDeclare(TVMFunctionHandle pregister) {
  const PackedFunc& fregister = GetRef<PackedFunc>(static_cast<PackedFuncObj*>(pregister));
  // NOTE(review): the original comment called this equivalent to
  // `*static_cast<PackedFunc*>(pregister)`; that holds only if the handle points at a
  // PackedFunc rather than a PackedFuncObj -- confirm against the TVM version in use.
  auto mul = [](TVMArgs args, TVMRetValue* rv) {
    int x = args[0];
    int y = args[1];
    *rv = x * y;
  };
  fregister("mul", PackedFunc(mul));
  return 0;
}

编译:

%%bash
cd cpp/mini_runtime
make clean
make
rm -rf outputs/*
g++ -std=c++17 -O2 -fPIC -I/media/pc/data/lxw/ai/tvm/include -I/media/pc/data/lxw/ai/tvm/3rdparty/dmlc-core/include -I/media/pc/data/lxw/ai/tvm/3rdparty/dlpack/include -Iinclude -DDMLC_USE_LOGGING_LIBRARY=\<tvm/runtime/logging.h\> -shared -o outputs/libs/libtvm_ext.so src/tvm_ext.cc -ldl -pthread -L/media/pc/data/lxw/ai/tvm/build
import set_env
import tvm
from tvm_book.tvm_ext.libinfo import load_lib

_LIB, _LIB_NAME = load_lib(name="libtvm_ext.so", search_path=["cpp/mini_runtime/outputs/libs"])
tvm._ffi._init_api("tvm_ext", __name__)
fdict = tvm._ffi.registry.extract_ext_funcs(_LIB.TVMExtDeclare)
assert fdict["mul"](3, 4) == 12

汇总 TVM 插件测试 demo#

将上述插件集中到:

extension/src
    testing/
        _make.cc
        bind_add.cc
        device_api.cc
        extern_func.cc
        mini_runtime.cc
        sym_add.cc
    tvm_ext.cc

编译:

%%bash
make clean
make
rm -rf outputs/*
g++ -std=c++17 -O2 -fPIC -I/media/pc/data/lxw/ai/tvm/include -I/media/pc/data/lxw/ai/tvm/3rdparty/dmlc-core/include -I/media/pc/data/lxw/ai/tvm/3rdparty/dlpack/include -Iinclude -DDMLC_USE_LOGGING_LIBRARY=\<tvm/runtime/logging.h\> -shared -o outputs/libs/libtvm_ext.so src/tvm_ext.cc -ldl -pthread -L/media/pc/data/lxw/ai/tvm/build
g++ -std=c++17 -O2 -fPIC -I/media/pc/data/lxw/ai/tvm/include -I/media/pc/data/lxw/ai/tvm/3rdparty/dmlc-core/include -I/media/pc/data/lxw/ai/tvm/3rdparty/dlpack/include -Iinclude -DDMLC_USE_LOGGING_LIBRARY=\<tvm/runtime/logging.h\> -shared -o outputs/libs/libtvm_plugin_module.so src/plugin_module.cc -ldl -pthread -L/media/pc/data/lxw/ai/tvm/build
import set_env
from tvm_book.tvm_ext.libinfo import load_lib
_LIB, _LIB_NAME = load_lib(name="libtvm_ext.so", search_path=["outputs/libs"])
from tvm_ext.testing import demo

比如:

from tvm import te
a = te.var("x")
b = te.var("y")
c = demo.sym_add(a, b)
assert c.a == a and c.b == b
print(c)
x + y