# 多节点编译
对于多节点编译这种情况,我们生成了两个 `ccompiler` 编译器分区,但由于它们具有相同的输入 (`x`),应该将它们合并为一个。
import set_env
# (set_env 输出) /media/pc/data/lxw/ai/tvm
from tvm import relay
import tvm
from tvm.relay.testing.byoc import CcompilerAnnotator
def get_expr():
    """Build a Relay function with three parallel branches that all read ``x``.

    The first two branches compute ``((x + w) - w') * w''`` and the third
    computes ``(x + w6) - w7``; the three results are concatenated along
    axis 0.  Returns a ``relay.Function`` over ``x`` and ``w0``..``w7``,
    each a (10, 10) tensor.
    """
    x = relay.var("x", shape=(10, 10))
    weights = [relay.var(f"w{i}", shape=(10, 10)) for i in range(8)]
    w0, w1, w2, w3, w4, w5, w6, w7 = weights
    # Branch 0: ((x + w0) - w1) * w2
    branch0 = relay.multiply(relay.subtract(relay.add(x, w0), w1), w2)
    # Branch 1: ((x + w3) - w4) * w5
    branch1 = relay.multiply(relay.subtract(relay.add(x, w3), w4), w5)
    # Branch 2: (x + w6) - w7
    branch2 = relay.subtract(relay.add(x, w6), w7)
    out = relay.concatenate((branch0, branch1, branch2), axis=0)
    return relay.Function([x] + weights, out)
# Annotate the workload for the external "ccompiler" codegen, then partition
# the annotated regions into separate functions and re-infer types.
annotator = CcompilerAnnotator()
mod = tvm.IRModule()
mod["main"] = annotator.visit(get_expr())
mod.show()
for transform in (relay.transform.PartitionGraph(), relay.transform.InferType()):
    mod = transform(mod)
mod.show()
def @main(%x: Tensor[(10, 10), float32], %w0: Tensor[(10, 10), float32], %w1: Tensor[(10, 10), float32], %w2: Tensor[(10, 10), float32], %w3: Tensor[(10, 10), float32], %w4: Tensor[(10, 10), float32], %w5: Tensor[(10, 10), float32], %w6: Tensor[(10, 10), float32], %w7: Tensor[(10, 10), float32]) {
%0 = annotation.compiler_begin(%x, compiler="ccompiler");
%1 = annotation.compiler_begin(%w0, compiler="ccompiler");
%2 = add(%0, %1);
%3 = annotation.compiler_begin(%w1, compiler="ccompiler");
%4 = subtract(%2, %3);
%5 = annotation.compiler_begin(%w2, compiler="ccompiler");
%6 = multiply(%4, %5);
%7 = annotation.compiler_begin(%x, compiler="ccompiler");
%8 = annotation.compiler_begin(%w3, compiler="ccompiler");
%9 = add(%7, %8);
%10 = annotation.compiler_begin(%w4, compiler="ccompiler");
%11 = subtract(%9, %10);
%12 = annotation.compiler_begin(%w5, compiler="ccompiler");
%13 = multiply(%11, %12);
%14 = add(%x, %w6);
%15 = annotation.compiler_end(%6, compiler="ccompiler");
%16 = annotation.compiler_end(%13, compiler="ccompiler");
%17 = subtract(%14, %w7);
%18 = (%15, %16, %17);
concatenate(%18)
}
def @main(%x: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %w0: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %w1: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %w2: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %w3: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %w4: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %w5: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %w6: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %w7: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */) -> Tensor[(30, 10), float32] {
%0 = add(%x, %w6) /* ty=Tensor[(10, 10), float32] */;
%1 = @tvmgen_default_ccompiler_main_0(%x, %w0, %w1, %w2) /* ty=Tensor[(10, 10), float32] */;
%2 = @tvmgen_default_ccompiler_main_4(%x, %w3, %w4, %w5) /* ty=Tensor[(10, 10), float32] */;
%3 = subtract(%0, %w7) /* ty=Tensor[(10, 10), float32] */;
%4 = (%1, %2, %3) /* ty=(Tensor[(10, 10), float32], Tensor[(10, 10), float32], Tensor[(10, 10), float32]) */;
concatenate(%4) /* ty=Tensor[(30, 10), float32] */
}
def @tvmgen_default_ccompiler_main_0(%ccompiler_0_i0: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %ccompiler_0_i1: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %ccompiler_0_i2: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %ccompiler_0_i3: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, Compiler="ccompiler", Primitive=1, Inline=1, global_symbol="tvmgen_default_ccompiler_main_0") -> Tensor[(10, 10), float32] {
%5 = add(%ccompiler_0_i0, %ccompiler_0_i1) /* ty=Tensor[(10, 10), float32] */;
%6 = subtract(%5, %ccompiler_0_i2) /* ty=Tensor[(10, 10), float32] */;
multiply(%6, %ccompiler_0_i3) /* ty=Tensor[(10, 10), float32] */
}
def @tvmgen_default_ccompiler_main_4(%ccompiler_4_i0: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %ccompiler_4_i1: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %ccompiler_4_i2: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %ccompiler_4_i3: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, Compiler="ccompiler", Primitive=1, Inline=1, global_symbol="tvmgen_default_ccompiler_main_4") -> Tensor[(10, 10), float32] {
%7 = add(%ccompiler_4_i0, %ccompiler_4_i1) /* ty=Tensor[(10, 10), float32] */;
%8 = subtract(%7, %ccompiler_4_i2) /* ty=Tensor[(10, 10), float32] */;
multiply(%8, %ccompiler_4_i3) /* ty=Tensor[(10, 10), float32] */
}
import numpy as np
from tvm.relay.backend.runtime import Runtime
from tvm.relay.backend import te_compiler
from tvm.contrib.utils import tempdir
def update_lib(lib, source_dir="/media/pc/data/lxw/ai/tvm"):
    """Export ``lib`` as a shared object compiled against the TVM source tree
    and reload it as a runtime module.

    The extra ``-I`` flags let the external "ccompiler" C sources find the
    TVM/DLPack/dmlc headers.  Returns the reloaded ``tvm.runtime.Module``.
    """
    include_dirs = [
        f"{source_dir}/src/runtime/contrib",
        f"{source_dir}/include",
        f"{source_dir}/3rdparty/dlpack/include",
        f"{source_dir}/3rdparty/dmlc-core/include",
    ]
    options = ["-O2", "-std=c++17"] + [f"-I{d}" for d in include_dirs]
    tmp_path = tempdir()
    lib_path = tmp_path.relpath("lib.so")
    # fcompile=False: use the default system compiler invocation.
    lib.export_library(lib_path, fcompile=False, options=options)
    return tvm.runtime.load_module(lib_path)
def check_result(
    mod,
    map_inputs,
    out_shape,
    result,
    tol=1e-5,
    target="llvm",
    device=tvm.cpu(),
    params=None,
    runtime=Runtime("cpp"),
):
    """Compile ``mod`` with the Relay VM, run it on ``map_inputs``, and
    assert the output(s) match ``result`` within ``tol``.

    NOTE(review): ``out_shape`` and ``runtime`` are accepted but unused by
    this VM-only path — presumably kept for signature parity with a
    graph-executor variant; confirm before removing.
    """
    te_compiler.get().clear()
    with tvm.transform.PassContext(opt_level=3):
        exe = relay.vm.compile(mod, target=target, params=params)
    code, lib = exe.save()
    # Round-trip the library through export/load so the external C sources
    # are actually compiled and linked.
    lib = update_lib(lib)
    loaded = tvm.runtime.vm.Executable.load_exec(code, lib)
    vm = tvm.runtime.vm.VirtualMachine(loaded, device)
    outs = vm.run(**map_inputs)
    if not isinstance(outs, tvm.runtime.container.ADT):
        outs = [outs]
    refs = result if isinstance(result, list) else [result]
    for got, ref in zip(outs, refs):
        tvm.testing.assert_allclose(got.numpy(), ref, rtol=tol, atol=tol)
# Random (10, 10) float32 inputs: one activation ``x`` and eight weights.
x_data = np.random.rand(10, 10).astype("float32")
w_data = [np.random.rand(10, 10).astype("float32") for _ in range(8)]
map_inputs = {f"w{i}": w for i, w in enumerate(w_data)}
map_inputs["x"] = x_data
params = None
# Reference result computed with NumPy: the three branches of get_expr(),
# concatenated along axis 0 into a (30, 10) array.  Hoisted out of the loop
# since it does not depend on the target.
expected = np.concatenate(
    (
        ((x_data + w_data[0]) - w_data[1]) * w_data[2],
        ((x_data + w_data[3]) - w_data[4]) * w_data[5],
        x_data + w_data[6] - w_data[7],
    ),
    axis=0,
)
targets = [("llvm", Runtime("cpp")), ("c", Runtime("crt", {"system-lib": True}))]
for tgt, rt in targets:
    check_result(mod, map_inputs, (30, 10), expected, target=tgt, runtime=rt)