AttachExternModules

# AttachExternModules

import subprocess
import tempfile
from pathlib import Path

import numpy as np

import tvm
import tvm.testing
from tvm import relax
from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import spec
from tvm.relax.transform import AttachExternModules

def _compile_cc(src: Path, dst: Path):
    """Compile the C++ source file ``src`` into the object file ``dst`` using ``g++``.

    Every directory returned by ``find_include_path()`` is passed to the
    compiler with ``-I``. Compiler stdout and stderr are captured together.

    Parameters
    ----------
    src : Path
        The C++ source file to compile.
    dst : Path
        Where the resulting object file is written (``-o``).

    Raises
    ------
    RuntimeError
        If ``g++`` exits with a non-zero status. The message carries the
        captured compiler output and the full command line.
    """
    # pylint: disable=import-outside-toplevel
    from tvm._ffi.base import py_str
    from tvm._ffi.libinfo import find_include_path

    # pylint: enable=import-outside-toplevel

    command = ["g++", str(src)]
    for include_dir in find_include_path():
        command.extend(["-I", include_dir])
    command.extend(
        [
            "-DDMLC_USE_FOPEN64=0",
            "-DDMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>",
            "-c",
            # Emit position-independent code so the object can be linked dynamically later.
            "-fPIC",
            "-o",
            str(dst),
        ]
    )

    # stderr is merged into stdout so the error message shows one combined log.
    completed = subprocess.run(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        check=False,
    )
    if completed.returncode != 0:
        raise RuntimeError(
            "Compilation error:\n"
            + py_str(completed.stdout)
            + "\nCommand line: "
            + " ".join(command)
        )

推断测试:

def _infer_scalar_add(x, y):  # pylint: disable=invalid-name
    """Shape-inference function registered for the ``ext_scalar_add`` symbol.

    Asserts that both operands are ``nn.Tensor`` instances of rank 0 with
    dtype ``"float32"``, then returns a rank-0 ``"float32"`` placeholder
    tensor describing the result.
    """
    operands = (x, y)
    assert all(isinstance(operand, nn.Tensor) for operand in operands)
    assert all(operand.ndim == 0 and operand.dtype == "float32" for operand in operands)
    return nn.Tensor.placeholder(shape=(), dtype="float32")


def _infer_test_sym(a, b):  # pylint: disable=invalid-name
    """Shape-inference function registered for the ``ext_test_sym`` symbol.

    Asserts that ``a`` is a float32 ``nn.Tensor`` of shape ``[x, y, 1]`` and
    ``b`` is a float32 ``nn.Tensor`` of shape ``[y, z, 5]`` (sharing the same
    ``y``), then returns a float32 placeholder of shape ``(x, y, z, 9)``.
    """

    def _same_dim(lhs, rhs):
        # Structural comparison that also maps free variables onto each other.
        return tvm.ir.structural_equal(lhs, rhs, map_free_vars=True)

    assert isinstance(a, nn.Tensor)
    assert isinstance(b, nn.Tensor)
    assert a.ndim == 3 and a.dtype == "float32"  # [x, y, 1]
    assert b.ndim == 3 and b.dtype == "float32"  # [y, z, 5]

    a_dims = a.shape
    b_dims = b.shape
    # pylint: disable=invalid-name
    x = a_dims[0]
    y = b_dims[0]
    z = b_dims[1]
    # pylint: enable=invalid-name

    # `a` must be [x, y, 1]; the second check ties a's middle dim to b's leading dim.
    assert _same_dim(a_dims[0], x)
    assert _same_dim(a_dims[1], y)
    assert a_dims[2] == 1
    # `b` must be [y, z, 5].
    assert _same_dim(b_dims[0], y)
    assert _same_dim(b_dims[1], z)
    assert b_dims[2] == 5

    return nn.Tensor.placeholder(shape=(x, y, z, 9), dtype="float32")


def _test_scalar_add(func):
    """Call ``func`` on the float32 scalars 1.0 and 3.0 and check the result.

    The result (converted with ``.numpy()``) must be a rank-0 ``float32``
    value equal to 4.0.
    """
    lhs = tvm.nd.array(np.array(1.0, dtype="float32"))
    rhs = tvm.nd.array(np.array(3.0, dtype="float32"))
    total = func(lhs, rhs).numpy()
    assert total.ndim == 0
    assert total.dtype == "float32"
    assert float(total) == 4.0


def _test_infer_sym(func, x, y, z):  # pylint: disable=invalid-name
    """Call ``func`` on random float32 inputs and check the output shape.

    The inputs have shapes ``(x, y, 1)`` and ``(y, z, 5)``; the result
    (converted with ``.numpy()``) must have shape ``(x, y, z, 9)``.
    """
    lhs_shape = (x, y, 1)
    rhs_shape = (y, z, 5)
    lhs = tvm.nd.array(np.random.uniform(size=lhs_shape).astype("float32"))
    rhs = tvm.nd.array(np.random.uniform(size=rhs_shape).astype("float32"))
    result = func(lhs, rhs).numpy()
    assert result.shape == (x, y, z, 9)


def _check_ir_equality(mod):
    """Assert that ``mod`` is structurally equal to the expected IRModule.

    The expected module, written in TVMScript below, has two Relax functions,
    ``scalar_add`` and ``test_sym``. Each one forwards its two tensor
    arguments to a packed function (``"ext_scalar_add"`` / ``"ext_test_sym"``)
    through ``R.call_dps_packed`` inside a dataflow block and returns the
    result. ``tvm.ir.assert_structural_equal`` performs the comparison.
    """
    # pylint: disable=import-outside-toplevel
    from tvm.script import ir as I
    from tvm.script import relax as R
    from tvm.script import tir as T

    # pylint: enable=import-outside-toplevel

    # Reference module: float32 signatures below mirror the ones asserted by
    # `_infer_scalar_add` and `_infer_test_sym`.
    @I.ir_module
    class ExpectedModule:
        @R.function
        def scalar_add(
            a: R.Tensor((), dtype="float32"), b: R.Tensor((), dtype="float32")
        ) -> R.Tensor((), dtype="float32"):
            R.func_attr({"num_input": 2})
            with R.dataflow():
                ext_scalar_add = R.call_dps_packed(
                    "ext_scalar_add", (a, b), out_sinfo=R.Tensor((), dtype="float32")
                )
                gv: R.Tensor((), dtype="float32") = ext_scalar_add
                R.output(gv)
            return gv

        @R.function
        def test_sym(
            a: R.Tensor(("x", "y", 1), dtype="float32"), b: R.Tensor(("y", "z", 5), dtype="float32")
        ) -> R.Tensor(("x", "y", "z", 9), dtype="float32"):
            x = T.int64()
            y = T.int64()
            z = T.int64()
            R.func_attr({"num_input": 2})
            with R.dataflow():
                ext_test_sym = R.call_dps_packed(
                    "ext_test_sym", (a, b), out_sinfo=R.Tensor((x, y, z, 9), dtype="float32")
                )
                gv1: R.Tensor((x, y, z, 9), dtype="float32") = ext_test_sym
                R.output(gv1)
            return gv1

    tvm.ir.assert_structural_equal(ExpectedModule, mod)

## 测试外部对象

# Object-file flow: compile the C++ source to an object file, wrap it in an
# `nn.ObjectModule`, export the model, attach the external module and run it.
with tempfile.TemporaryDirectory() as temp_dir_str:
    path = Path(temp_dir_str) / "main.o"
    # NOTE(review): "__file__" is a string literal here, so `.parent` is the
    # relative path "." and the .cc file is resolved against the current
    # working directory — confirm this is intended.
    _compile_cc(
        src=Path("__file__").parent / "frontend_nn_extern_module.cc",
        dst=path,
    )

    class TestModule(nn.Module):
        """Module whose two methods dispatch to symbols of a compiled object file."""

        def __init__(self):
            # Built on first use by `_get_ext_mod`.
            self.ext_mod = None

        def _get_ext_mod(self):
            """Return the ``nn.ObjectModule``, creating and registering it on first call."""
            if self.ext_mod is not None:
                return self.ext_mod
            # Maps each external symbol name to its shape-inference function.
            infer_fns = {
                "ext_scalar_add": _infer_scalar_add,
                "ext_test_sym": _infer_test_sym,
            }
            self.ext_mod = nn.ObjectModule(infer_fns, path)
            nn.add_extern(self.ext_mod)
            return self.ext_mod

        def scalar_add(self, a: nn.Tensor, b: nn.Tensor):  # pylint: disable=invalid-name
            """Apply the external ``ext_scalar_add`` symbol to ``a`` and ``b``."""
            extern = self._get_ext_mod()
            return extern["ext_scalar_add"](a, b)

        def test_sym(self, a: nn.Tensor, b: nn.Tensor):  # pylint: disable=invalid-name
            """Apply the external ``ext_test_sym`` symbol to ``a`` and ``b``."""
            extern = self._get_ext_mod()
            return extern["ext_test_sym"](a, b)

    # The second element of the returned triple is not used here.
    mod, _, ext_mods = TestModule().export_tvm(
        spec={
            "scalar_add": {
                "a": spec.Tensor((), "float32"),
                "b": spec.Tensor((), "float32"),
            },
            "test_sym": {
                "a": spec.Tensor(("x", "y", 1), "float32"),
                "b": spec.Tensor(("y", "z", 5), "float32"),
            },
        },
        allow_extern=True,
    )
    # Check the exported IR before the external modules are attached.
    _check_ir_equality(mod)
    mod = AttachExternModules(ext_mods)(mod)  # pylint: disable=not-callable
    compiled = tvm.runtime.relax_vm.VirtualMachine(
        tvm.compile(mod, target="llvm"),
        device=tvm.cpu(),
    )
    # Run both exported functions end to end on the CPU.
    _test_scalar_add(compiled["scalar_add"])
    _test_infer_sym(compiled["test_sym"], x=3, y=4, z=2)

## 测试外部源码

# Path of the C++ source handed to `nn.SourceModule` below.
# NOTE(review): "__file__" is a string literal here, so `.parent` is the
# relative path "." and the file is resolved against the current working
# directory — confirm this is intended.
source = Path("__file__").parent / "frontend_nn_extern_module.cc"

class TestModule(nn.Module):
    """Module whose two methods dispatch to symbols of a C++ source module."""

    def __init__(self):
        # Built on first use by `_get_ext_mod`.
        self.ext_mod = None

    def _get_ext_mod(self):
        """Return the ``nn.SourceModule``, creating and registering it on first call."""
        if self.ext_mod is not None:
            return self.ext_mod
        # Maps each external symbol name to its shape-inference function.
        infer_fns = {
            "ext_scalar_add": _infer_scalar_add,
            "ext_test_sym": _infer_test_sym,
        }
        self.ext_mod = nn.SourceModule(
            infer_fns,
            source_code=source,
            source_format="cpp",
        )
        nn.add_extern(self.ext_mod)
        return self.ext_mod

    def scalar_add(self, a: nn.Tensor, b: nn.Tensor):  # pylint: disable=invalid-name
        """Apply the external ``ext_scalar_add`` symbol to ``a`` and ``b``."""
        extern = self._get_ext_mod()
        return extern["ext_scalar_add"](a, b)

    def test_sym(self, a: nn.Tensor, b: nn.Tensor):  # pylint: disable=invalid-name
        """Apply the external ``ext_test_sym`` symbol to ``a`` and ``b``."""
        extern = self._get_ext_mod()
        return extern["ext_test_sym"](a, b)

# Export the module with the given input specs; the second element of the
# returned triple is not used here.
mod, _, ext_mods = TestModule().export_tvm(
    spec={
        "scalar_add": {
            "a": spec.Tensor((), "float32"),
            "b": spec.Tensor((), "float32"),
        },
        "test_sym": {
            "a": spec.Tensor(("x", "y", 1), "float32"),
            "b": spec.Tensor(("y", "z", 5), "float32"),
        },
    },
    allow_extern=True,
)
# Check the exported IR before the external modules are attached.
_check_ir_equality(mod)
mod = AttachExternModules(ext_mods)(mod)  # pylint: disable=not-callable
# Compile for the LLVM target and wrap the result in a VM on the CPU device.
compiled = tvm.runtime.relax_vm.VirtualMachine(
    tvm.compile(mod, target="llvm"),
    device=tvm.cpu(),
)
# Run both exported functions end to end.
_test_scalar_add(compiled["scalar_add"])
_test_infer_sym(compiled["test_sym"], x=3, y=4, z=2)