多节点编译

多节点编译#

对于多节点编译这种情况,图划分会生成两个交由外部编译器处理的子图函数,但由于它们具有相同的输入 (x),理想情况下应该将它们合并为一个。

import set_env
# TVM source root (printed by set_env): /media/pc/data/lxw/ai/tvm
from tvm import relay
import tvm
from tvm.relay.testing.byoc import CcompilerAnnotator

def get_expr():
    """Build a Relay function with three branches sharing the same input ``x``.

    Two branches have the shape add -> subtract -> multiply (these get
    annotated for the external "ccompiler" target later); the third is a
    plain add -> subtract chain that stays on the host. The branch results
    are concatenated along axis 0.
    """
    x = relay.var("x", shape=(10, 10))
    weights = [relay.var(f"w{i}", shape=(10, 10)) for i in range(8)]
    w0, w1, w2, w3, w4, w5, w6, w7 = weights

    # Branch 1: (x + w0 - w1) * w2 — to be offloaded.
    q0 = relay.multiply(relay.subtract(relay.add(x, w0), w1), w2)
    # Branch 2: (x + w3 - w4) * w5 — structurally identical, same input x.
    q1 = relay.multiply(relay.subtract(relay.add(x, w3), w4), w5)
    # Branch 3: x + w6 - w7 — remains on the host.
    q2 = relay.subtract(relay.add(x, w6), w7)

    out = relay.concatenate((q0, q1, q2), axis=0)
    return relay.Function([x, *weights], out)
# Annotate the expression for the "ccompiler" external codegen, then
# partition the annotated regions into separate functions and re-infer types.
mod = tvm.IRModule()
annotator = CcompilerAnnotator()
mod["main"] = annotator.visit(get_expr())
mod.show()
mod = relay.transform.InferType()(relay.transform.PartitionGraph()(mod))
mod.show()
def @main(%x: Tensor[(10, 10), float32], %w0: Tensor[(10, 10), float32], %w1: Tensor[(10, 10), float32], %w2: Tensor[(10, 10), float32], %w3: Tensor[(10, 10), float32], %w4: Tensor[(10, 10), float32], %w5: Tensor[(10, 10), float32], %w6: Tensor[(10, 10), float32], %w7: Tensor[(10, 10), float32]) {
  %0 = annotation.compiler_begin(%x, compiler="ccompiler");
  %1 = annotation.compiler_begin(%w0, compiler="ccompiler");
  %2 = add(%0, %1);
  %3 = annotation.compiler_begin(%w1, compiler="ccompiler");
  %4 = subtract(%2, %3);
  %5 = annotation.compiler_begin(%w2, compiler="ccompiler");
  %6 = multiply(%4, %5);
  %7 = annotation.compiler_begin(%x, compiler="ccompiler");
  %8 = annotation.compiler_begin(%w3, compiler="ccompiler");
  %9 = add(%7, %8);
  %10 = annotation.compiler_begin(%w4, compiler="ccompiler");
  %11 = subtract(%9, %10);
  %12 = annotation.compiler_begin(%w5, compiler="ccompiler");
  %13 = multiply(%11, %12);
  %14 = add(%x, %w6);
  %15 = annotation.compiler_end(%6, compiler="ccompiler");
  %16 = annotation.compiler_end(%13, compiler="ccompiler");
  %17 = subtract(%14, %w7);
  %18 = (%15, %16, %17);
  concatenate(%18)
}
def @main(%x: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %w0: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %w1: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %w2: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %w3: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %w4: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %w5: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %w6: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %w7: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */) -> Tensor[(30, 10), float32] {
  %0 = add(%x, %w6) /* ty=Tensor[(10, 10), float32] */;
  %1 = @tvmgen_default_ccompiler_main_0(%x, %w0, %w1, %w2) /* ty=Tensor[(10, 10), float32] */;
  %2 = @tvmgen_default_ccompiler_main_4(%x, %w3, %w4, %w5) /* ty=Tensor[(10, 10), float32] */;
  %3 = subtract(%0, %w7) /* ty=Tensor[(10, 10), float32] */;
  %4 = (%1, %2, %3) /* ty=(Tensor[(10, 10), float32], Tensor[(10, 10), float32], Tensor[(10, 10), float32]) */;
  concatenate(%4) /* ty=Tensor[(30, 10), float32] */
}

def @tvmgen_default_ccompiler_main_0(%ccompiler_0_i0: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %ccompiler_0_i1: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %ccompiler_0_i2: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %ccompiler_0_i3: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, Compiler="ccompiler", Primitive=1, Inline=1, global_symbol="tvmgen_default_ccompiler_main_0") -> Tensor[(10, 10), float32] {
  %5 = add(%ccompiler_0_i0, %ccompiler_0_i1) /* ty=Tensor[(10, 10), float32] */;
  %6 = subtract(%5, %ccompiler_0_i2) /* ty=Tensor[(10, 10), float32] */;
  multiply(%6, %ccompiler_0_i3) /* ty=Tensor[(10, 10), float32] */
}

def @tvmgen_default_ccompiler_main_4(%ccompiler_4_i0: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %ccompiler_4_i1: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %ccompiler_4_i2: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %ccompiler_4_i3: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, Compiler="ccompiler", Primitive=1, Inline=1, global_symbol="tvmgen_default_ccompiler_main_4") -> Tensor[(10, 10), float32] {
  %7 = add(%ccompiler_4_i0, %ccompiler_4_i1) /* ty=Tensor[(10, 10), float32] */;
  %8 = subtract(%7, %ccompiler_4_i2) /* ty=Tensor[(10, 10), float32] */;
  multiply(%8, %ccompiler_4_i3) /* ty=Tensor[(10, 10), float32] */
}
import numpy as np
from tvm.relay.backend.runtime import Runtime
from tvm.relay.backend import te_compiler
from tvm.contrib.utils import tempdir

def update_lib(lib, source_dir="/media/pc/data/lxw/ai/tvm"):
    """Export *lib* as a shared object with the TVM headers on the include
    path, then reload it as a runtime module.

    The generated external C source needs the TVM runtime/contrib headers,
    so the include directories under *source_dir* are passed to the
    compiler.
    """
    include_dirs = [
        f"{source_dir}/src/runtime/contrib",
        f"{source_dir}/include",
        f"{source_dir}/3rdparty/dlpack/include",
        f"{source_dir}/3rdparty/dmlc-core/include",
    ]
    options = ["-O2", "-std=c++17"] + [f"-I{d}" for d in include_dirs]
    workspace = tempdir()
    lib_path = workspace.relpath("lib.so")
    # NOTE(review): fcompile=False forwards the options to the default
    # compile wrapper — this mirrors the idiom used in TVM's own BYOC tests.
    lib.export_library(lib_path, fcompile=False, options=options)
    return tvm.runtime.load_module(lib_path)

def check_result(
    mod,
    map_inputs,
    out_shape,
    result,
    tol=1e-5,
    target="llvm",
    device=tvm.cpu(),
    params=None,
    runtime=Runtime("cpp"),
):
    """Compile *mod* with the Relay VM, run it on *map_inputs*, and compare
    the output(s) against *result* within *tol*.

    Parameters
    ----------
    mod : tvm.IRModule
        The (partitioned) module to compile.
    map_inputs : dict[str, numpy.ndarray]
        Keyword inputs passed to the VM's ``run``.
    out_shape : tuple
        Expected output shape. Kept for API compatibility; not checked on
        the VM path.
    result : numpy.ndarray or list[numpy.ndarray]
        Reference output(s).
    tol : float
        Absolute and relative tolerance for the comparison.
    target, device, params, runtime
        Compilation/execution settings; *runtime* is unused on the VM path.
    """
    # tvm.testing is a submodule that `import tvm` alone does not expose;
    # import it explicitly so assert_allclose is available.
    import tvm.testing

    def check_vm_result():
        # Clear the TE compiler cache so earlier compilations don't leak
        # into this run.
        te_compiler.get().clear()
        with tvm.transform.PassContext(opt_level=3):
            exe = relay.vm.compile(mod, target=target, params=params)
        # Serialize the executable, link the external C sources into a
        # loadable shared library, and reload both.
        code, lib = exe.save()
        lib = update_lib(lib)
        exe = tvm.runtime.vm.Executable.load_exec(code, lib)
        vm = tvm.runtime.vm.VirtualMachine(exe, device)
        outs = vm.run(**map_inputs)
        # A tuple result comes back as an ADT container; normalize to list.
        outs = outs if isinstance(outs, tvm.runtime.container.ADT) else [outs]
        results = result if isinstance(result, list) else [result]
        for out, ref in zip(outs, results):
            tvm.testing.assert_allclose(out.numpy(), ref, rtol=tol, atol=tol)

    check_vm_result()
# Random test data: one shared input x and eight weight matrices.
x_data = np.random.rand(10, 10).astype("float32")
w_data = [np.random.rand(10, 10).astype("float32") for _ in range(8)]

map_inputs = {"x": x_data}
map_inputs.update({f"w{i}": w for i, w in enumerate(w_data)})
params = None

# Reference output computed with NumPy: the three branches of get_expr,
# concatenated along axis 0.
expected = np.concatenate(
    (
        ((x_data + w_data[0]) - w_data[1]) * w_data[2],
        ((x_data + w_data[3]) - w_data[4]) * w_data[5],
        x_data + w_data[6] - w_data[7],
    ),
    axis=0,
)

# Exercise both the LLVM/C++ runtime and the C backend with the CRT runtime.
for tgt, rt in [("llvm", Runtime("cpp")), ("c", Runtime("crt", {"system-lib": True}))]:
    check_result(
        mod,
        map_inputs,
        (30, 10),
        expected,
        target=tgt,
        runtime=rt,
    )