Compiling external libraries#
import set_env
/media/pc/data/lxw/ai/tvm
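Note: set_env is a notebook-local helper, not part of TVM; presumably it points Python at the locally built TVM, and the printed line is the TVM source root that update_lib below reuses for include paths.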
Load the required packages:
import numpy as np
import tvm
from tvm import relay
from tvm.relay import ExprMutator, transform
from tvm.relay.op.annotation import compiler_begin, compiler_end
from tvm.relay.backend.runtime import Runtime
from tvm.relay.backend import te_compiler
from tvm.contrib.utils import tempdir
def update_lib(lib, source_dir="/media/pc/data/lxw/ai/tvm"):
    """Export the module to a shared library and load it back.

    The extra options compile the BYOC-generated C source against
    TVM's runtime headers.
    """
    kwargs = {
        "options": [
            "-O2", "-std=c++17",
            f"-I{source_dir}/src/runtime/contrib",
            f"-I{source_dir}/include",
            f"-I{source_dir}/3rdparty/dlpack/include",
            f"-I{source_dir}/3rdparty/dmlc-core/include",
        ]
    }
    tmp_path = tempdir()
    lib_name = "lib.so"
    lib_path = tmp_path.relpath(lib_name)
    lib.export_library(lib_path, fcompile=False, **kwargs)
    lib = tvm.runtime.load_module(lib_path)
    return lib
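These options matter because the example "ccompiler" codegen emits C source that includes TVM runtime and contrib headers; export_library shells out to the system C++ compiler, so the include paths must point at a TVM source tree.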
def check_result(
    mod,
    map_inputs,
    out_shape,
    result,
    tol=1e-5,
    target="llvm",
    device=tvm.cpu(),
    params=None,
    runtime=Runtime("cpp"),
):
    # out_shape and runtime are unused below; they are kept to mirror the
    # original helper's signature.
    def check_vm_result():
        te_compiler.get().clear()
        # Compile with the Relay VM, round-trip through a shared library,
        # then compare against the reference result.
        with tvm.transform.PassContext(opt_level=3):
            exe = relay.vm.compile(mod, target=target, params=params)
        code, lib = exe.save()
        lib = update_lib(lib)
        exe = tvm.runtime.vm.Executable.load_exec(code, lib)
        vm = tvm.runtime.vm.VirtualMachine(exe, device)
        outs = vm.run(**map_inputs)
        outs = outs if isinstance(outs, tvm.runtime.container.ADT) else [outs]
        results = result if isinstance(result, list) else [result]
        for out, ref in zip(outs, results):
            np.testing.assert_allclose(out.numpy(), ref, rtol=tol, atol=tol)

    check_vm_result()
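check_result exercises the Relay VM path. For the straight-line graphs used below, a graph-executor variant is another option; the following is an unverified sketch, and check_graph_executor_result is a name made up here:

from tvm.contrib import graph_executor

def check_graph_executor_result(mod, map_inputs, result, tol=1e-5, device=tvm.cpu()):
    with tvm.transform.PassContext(opt_level=3):
        factory = relay.build(mod, target="llvm")
    # Reuse update_lib so the BYOC-generated C source is compiled with the right flags.
    loaded = update_lib(factory.module)
    rt_mod = graph_executor.GraphModule(loaded["default"](device))
    rt_mod.run(**map_inputs)
    np.testing.assert_allclose(rt_mod.get_output(0).numpy(), result, rtol=tol, atol=tol)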
Take z = x + y as an example:
x = relay.var("x", shape=(8, 8))
y = relay.var("y", shape=(8, 8))
z = x + y
f = relay.Function([x, y], z)
mod = tvm.IRModule()
mod["main"] = f
mod.show()
def @main(%x: Tensor[(8, 8), float32], %y: Tensor[(8, 8), float32]) {
  add(%x, %y)
}
Write a simple annotation pass:
@relay.transform.function_pass(opt_level=0)
class MyAnnotator:
    def transform_function(self, func, mod, dev):
        class Annotator(ExprMutator):
            def visit_call(self, call):
                new_args = []
                for arg in call.args:
                    ann = compiler_begin(self.visit(arg), "ccompiler")
                    new_args.append(ann)
                new_call = relay.Call(call.op, new_args)
                return compiler_end(new_call, "ccompiler")

        return Annotator().visit(func)
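For reference, the annotated module this pass produces for z = x + y can also be built by hand with compiler_begin/compiler_end; a minimal sketch (the names x2, y2, manual_mod are invented for this illustration):

x2 = relay.var("x2", shape=(8, 8))
y2 = relay.var("y2", shape=(8, 8))
lhs = compiler_begin(x2, "ccompiler")  # region entry for each input
rhs = compiler_begin(y2, "ccompiler")
out = compiler_end(relay.add(lhs, rhs), "ccompiler")  # region exit at the output
manual_mod = tvm.IRModule.from_expr(relay.Function([x2, y2], out))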
Annotate the inputs and outputs of the add op:
mod = MyAnnotator()(mod)
mod.show()
def @main(%x: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */, %y: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */) -> Tensor[(8, 8), float32] {
  %0 = annotation.compiler_begin(%x, compiler="ccompiler") /* ty=Tensor[(8, 8), float32] */;
  %1 = annotation.compiler_begin(%y, compiler="ccompiler") /* ty=Tensor[(8, 8), float32] */;
  %2 = add(%0, %1) /* ty=Tensor[(8, 8), float32] */;
  annotation.compiler_end(%2, compiler="ccompiler") /* ty=Tensor[(8, 8), float32] */
}
Partition the computation graph with PartitionGraph:
mod = relay.transform.PartitionGraph()(mod)
mod.show()
def @main(%x: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */, %y: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */) -> Tensor[(8, 8), float32] {
  @tvmgen_default_ccompiler_main_0(%x, %y) /* ty=Tensor[(8, 8), float32] */
}

def @tvmgen_default_ccompiler_main_0(%ccompiler_0_i0: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */, %ccompiler_0_i1: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */, Compiler="ccompiler", Primitive=1, Inline=1, global_symbol="tvmgen_default_ccompiler_main_0") -> Tensor[(8, 8), float32] {
  add(%ccompiler_0_i0, %ccompiler_0_i1) /* ty=Tensor[(8, 8), float32] */
}
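After partitioning, everything the external codegen needs is carried on the attributes of the extracted function. A quick way to inspect them (a small sketch; the global symbol name is taken from the output above):

gv = mod.get_global_var("tvmgen_default_ccompiler_main_0")
print(mod[gv].attrs["Compiler"])       # ccompiler
print(mod[gv].attrs["global_symbol"])  # tvmgen_default_ccompiler_main_0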
Verify that the results agree:
x_data = np.random.rand(8, 8).astype("float32")
y_data = np.random.rand(8, 8).astype("float32")
check_result(mod, {"x": x_data, "y": y_data}, (8, 8), x_data + y_data)
[15:56:02] /media/pc/data/lxw/ai/tvm/src/relay/backend/vm/compiler.cc:1199: All lowered functions have been build by BYOC -- generating an empty TVM module
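The log message is expected here: every lowered function was generated by the external (BYOC) compiler, so TVM itself has nothing left to compile and emits an empty module.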
Annotation allow-list#
# Write a simple allow-list annotator using the pass manager
@relay.transform.function_pass(opt_level=0)
class AllowedListAnnotator:
    def __init__(self, op_list, compiler):
        assert isinstance(op_list, (list, tuple, set))
        self.op_list = op_list
        self.compiler = compiler

    def transform_function(self, func, mod, dev):
        annotator = self

        class Annotator(tvm.relay.ExprMutator):
            def visit_call(self, call):
                op_name = call.op.name
                if op_name in annotator.op_list:
                    new_args = []
                    for arg in call.args:
                        ann = compiler_begin(super().visit(arg), annotator.compiler)
                        new_args.append(ann)
                    new_call = relay.Call(call.op, new_args, call.attrs, call.type_args)
                    return compiler_end(new_call, annotator.compiler)
                else:
                    return super().visit_call(call)

        return Annotator().visit(func)
x = relay.var("x", shape=(8, 8))
y = relay.var("y", shape=(8, 8))
add = x + y
log = relay.log(add)
exp = relay.exp(add)
concat = relay.concatenate([log, exp], axis=0)
f = relay.Function([x, y], concat)
mod = tvm.IRModule()
mod["main"] = f
mod.show()
def @main(%x: Tensor[(8, 8), float32], %y: Tensor[(8, 8), float32]) {
  %0 = add(%x, %y);
  %1 = log(%0);
  %2 = exp(%0);
  %3 = (%1, %2);
  concatenate(%3)
}
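To verify the partitioning, construct the expected post-partitioning module by hand. First, a helper that attaches the attributes marking a function for an external compiler: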
def set_func_attr(func, compile_name, symbol_name):
    func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
    func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
    func = func.with_attr("Compiler", compile_name)
    func = func.with_attr("global_symbol", symbol_name)
    return func
def expected():
    mod = tvm.IRModule()
    x = relay.var("x", shape=(8, 8))
    y = relay.var("y", shape=(8, 8))
    x0 = relay.var("x0", shape=(8, 8))
    y0 = relay.var("y0", shape=(8, 8))
    add = x0 + y0
    # Function that uses C compiler
    func = relay.Function([x0, y0], add)
    func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_main_0")
    glb_0 = relay.GlobalVar("tvmgen_default_ccompiler_main_0")
    mod[glb_0] = func
    add_call = relay.Call(glb_0, [x, y])
    # Function that uses default compiler. Ops are fused in this function.
    p0 = relay.var("p0", shape=(8, 8))
    log = relay.log(p0)
    exp = relay.exp(p0)
    concat = relay.concatenate([log, exp], axis=0)
    fused_func = relay.Function([p0], concat)
    fused_func = fused_func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
    fused_call = relay.Call(fused_func, [add_call])
    main = relay.Function([x, y], fused_call)
    mod["main"] = main
    mod = relay.transform.InferType()(mod)
    return mod
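Rebuild the original module and run the full pipeline: annotate the allow-listed ops, partition, fuse the remaining ops, and compare against the hand-built module, both structurally and numerically: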
x = relay.var("x", shape=(8, 8))
y = relay.var("y", shape=(8, 8))
add = x + y
log = relay.log(add)
exp = relay.exp(add)
concat = relay.concatenate([log, exp], axis=0)
f = relay.Function([x, y], concat)
mod = tvm.IRModule()
mod["main"] = f
mod = AllowedListAnnotator(["add", "subtract", "multiply"], "ccompiler")(mod)
mod = relay.transform.PartitionGraph()(mod)
fused_mod = relay.transform.FuseOps(2)(mod)
expected_mod = expected()
assert tvm.ir.structural_equal(fused_mod, expected_mod, map_free_vars=True)
x_data = np.random.rand(8, 8).astype("float32")
y_data = np.random.rand(8, 8).astype("float32")
np_add = x_data + y_data
res = np.concatenate([np.log(np_add), np.exp(np_add)])
check_result(mod, {"x": x_data, "y": y_data}, (16, 8), res)
expected_mod.show()
def @main(%x: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */, %y: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */) -> Tensor[(16, 8), float32] {
  %3 = @tvmgen_default_ccompiler_main_0(%x, %y) /* ty=Tensor[(8, 8), float32] */;
  %4 = fn (%p0: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */, Primitive=1) -> Tensor[(16, 8), float32] {
    %0 = log(%p0) /* ty=Tensor[(8, 8), float32] */;
    %1 = exp(%p0) /* ty=Tensor[(8, 8), float32] */;
    %2 = (%0, %1) /* ty=(Tensor[(8, 8), float32], Tensor[(8, 8), float32]) */;
    concatenate(%2) /* ty=Tensor[(16, 8), float32] */
  } /* ty=fn (Tensor[(8, 8), float32]) -> Tensor[(16, 8), float32] */;
  %4(%3) /* ty=Tensor[(16, 8), float32] */
}

def @tvmgen_default_ccompiler_main_0(%x0: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */, %y0: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */, Primitive=1, Inline=1, Compiler="ccompiler", global_symbol="tvmgen_default_ccompiler_main_0") -> Tensor[(8, 8), float32] {
  add(%x0, %y0) /* ty=Tensor[(8, 8), float32] */
}
Support for other external compilers#
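The pipeline generalizes beyond the "ccompiler" backend. The first test below checks symbol-name sanitization: a compiler name that is not a valid C identifier, "unsanitary-name++", is turned into the symbol tvmgen_default_unsanitary_name___main_0.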
def test_extern_compiler_sanitized_ops():
    def expected():
        mod = tvm.IRModule()
        x = relay.var("x", shape=(8, 8))
        y = relay.var("y", shape=(8, 8))
        x0 = relay.var("x0", shape=(8, 8))
        y0 = relay.var("y0", shape=(8, 8))
        add = x0 + y0
        # Function that uses C compiler
        func = relay.Function([x0, y0], add)
        func = set_func_attr(func, "unsanitary-name++", "tvmgen_default_unsanitary_name___main_0")
        glb_0 = relay.GlobalVar("tvmgen_default_unsanitary_name___main_0")
        mod[glb_0] = func
        add_call = relay.Call(glb_0, [x, y])
        # Function that uses default compiler. Ops are fused in this function.
        p0 = relay.var("p0", shape=(8, 8))
        log = relay.log(p0)
        exp = relay.exp(p0)
        concat = relay.concatenate([log, exp], axis=0)
        fused_func = relay.Function([p0], concat)
        fused_func = fused_func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        fused_call = relay.Call(fused_func, [add_call])
        main = relay.Function([x, y], fused_call)
        mod["main"] = main
        mod = transform.InferType()(mod)
        return mod

    x = relay.var("x", shape=(8, 8))
    y = relay.var("y", shape=(8, 8))
    add = x + y
    log = relay.log(add)
    exp = relay.exp(add)
    concat = relay.concatenate([log, exp], axis=0)
    f = relay.Function([x, y], concat)
    mod = tvm.IRModule()
    mod["main"] = f
    mod = AllowedListAnnotator(["add", "subtract", "multiply"], "unsanitary-name++")(mod)
    mod = transform.PartitionGraph()(mod)
    fused_mod = transform.FuseOps(2)(mod)
    expected_mod = expected()
    assert tvm.ir.structural_equal(fused_mod, expected_mod, map_free_vars=True)
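A module with more than one entry function is partitioned function by function; both main and subfunction below get their own external add region: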
def test_extern_ccompiler_multiple_functions():
    def expected():
        mod = tvm.IRModule()
        x = relay.var("x", shape=(8, 8))
        y = relay.var("y", shape=(8, 8))
        x0 = relay.var("x0", shape=(8, 8))
        y0 = relay.var("y0", shape=(8, 8))
        add = x0 + y0
        # Function that uses C compiler
        func = relay.Function([x0, y0], add)
        func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_main_0")
        glb_0 = relay.GlobalVar("tvmgen_default_ccompiler_main_0")
        mod[glb_0] = func
        add_call = relay.Call(glb_0, [x, y])
        # Function that uses default compiler. Ops are fused in this function.
        p0 = relay.var("p0", shape=(8, 8))
        log = relay.log(p0)
        exp = relay.exp(p0)
        concat = relay.concatenate([log, exp], axis=0)
        fused_func = relay.Function([p0], concat)
        fused_func = fused_func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        fused_call = relay.Call(fused_func, [add_call])
        main = relay.Function([x, y], fused_call)
        mod["main"] = main
        # define the second one
        a = relay.var("a", shape=(16, 16))
        b = relay.var("b", shape=(16, 16))
        a0 = relay.var("a0", shape=(16, 16))
        b0 = relay.var("b0", shape=(16, 16))
        add = a0 + b0
        # Function that uses C compiler
        func = relay.Function([a0, b0], add)
        func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_subfunction_0")
        glb_0 = relay.GlobalVar("tvmgen_default_ccompiler_subfunction_0")
        mod[glb_0] = func
        add_call = relay.Call(glb_0, [a, b])
        # Function that uses default compiler. Ops are fused in this function.
        p0 = relay.var("p0", shape=(16, 16))
        log = relay.log(p0)
        exp = relay.exp(p0)
        concat = relay.concatenate([log, exp], axis=0)
        fused_func = relay.Function([p0], concat)
        fused_func = fused_func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        fused_call = relay.Call(fused_func, [add_call])
        subfunction = relay.Function([a, b], fused_call)
        mod["subfunction"] = subfunction
        mod = transform.InferType()(mod)
        return mod

    x = relay.var("x", shape=(8, 8))
    y = relay.var("y", shape=(8, 8))
    add = x + y
    log = relay.log(add)
    exp = relay.exp(add)
    concat = relay.concatenate([log, exp], axis=0)
    f = relay.Function([x, y], concat)
    mod = tvm.IRModule()
    mod["main"] = f
    # define second function
    a = relay.var("a", shape=(16, 16))
    b = relay.var("b", shape=(16, 16))
    add = a + b
    log = relay.log(add)
    exp = relay.exp(add)
    concat = relay.concatenate([log, exp], axis=0)
    f2 = relay.Function([a, b], concat)
    mod["subfunction"] = f2
    mod = AllowedListAnnotator(["add", "subtract", "multiply"], "ccompiler")(mod)
    mod = transform.PartitionGraph()(mod)
    fused_mod = transform.FuseOps(2)(mod)
    expected_mod = expected()
    assert tvm.ir.structural_equal(fused_mod, expected_mod, map_free_vars=True)
    x_data = np.random.rand(8, 8).astype("float32")
    y_data = np.random.rand(8, 8).astype("float32")
    np_add = x_data + y_data
    res = np.concatenate([np.log(np_add), np.exp(np_add)])
    check_result(mod, {"x": x_data, "y": y_data}, (16, 8), res)
def test_extern_ccompiler():
    x = relay.var("x", shape=(2, 2))
    y = relay.var("y", shape=(2, 2))
    z = x + x
    p = y * y
    f = relay.Function([x, y], p - z)
    x_data = np.random.rand(2, 2).astype("float32")
    y_data = np.random.rand(2, 2).astype("float32")
    mod = tvm.IRModule()
    mod["main"] = f
    mod = AllowedListAnnotator(["add", "subtract", "multiply"], "ccompiler")(mod)
    mod = transform.PartitionGraph()(mod)
    check_result(mod, {"x": x_data, "y": y_data}, (2, 2), (y_data * y_data) - (x_data + x_data))
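The annotators above mark individual calls. To hand an entire graph to a single external compiler (DNNL here), it is enough to wrap only the graph's variable inputs in compiler_begin and only the final call in compiler_end: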
class WholeGraphAnnotator(ExprMutator):
    """An annotator that assigns an entire graph to one external compiler."""

    def __init__(self, compiler):
        super().__init__()
        self.compiler = compiler
        self.last_call = True

    def visit_call(self, call):
        # Only the outermost call receives compiler_end.
        curr_last = self.last_call
        self.last_call = False

        params = []
        for arg in call.args:
            param = super().visit(arg)
            # Only free variables (graph inputs) receive compiler_begin.
            if isinstance(param, relay.expr.Var):
                param = compiler_begin(param, self.compiler)
            params.append(param)

        new_call = relay.Call(call.op, params, call.attrs)
        if curr_last:
            new_call = compiler_end(new_call, self.compiler)
        return new_call
dtype = "float32"
ishape = (1, 32, 14, 14)
w1shape = (32, 1, 3, 3)

def get_func():
    data = relay.var("data", shape=ishape, dtype=dtype)
    weight1 = relay.var("weight1", shape=w1shape, dtype=dtype)
    depthwise_conv2d_1 = relay.nn.conv2d(
        data, weight1, kernel_size=(3, 3), padding=(1, 1), groups=32
    )
    depthwise_conv2d_2 = relay.nn.conv2d(
        depthwise_conv2d_1, weight1, kernel_size=(3, 3), padding=(1, 1), groups=32
    )
    out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2)
    return relay.Function([data, weight1], out)

func = get_func()
mod = tvm.IRModule()
mod["main"] = WholeGraphAnnotator("dnnl").visit(func)
mod = relay.transform.PartitionGraph()(mod)
mod = relay.transform.InferType()(mod)
mod.show()
def @main(%data: Tensor[(1, 32, 14, 14), float32] /* ty=Tensor[(1, 32, 14, 14), float32] */, %weight1: Tensor[(32, 1, 3, 3), float32] /* ty=Tensor[(32, 1, 3, 3), float32] */) -> Tensor[(1, 32, 14, 14), float32] {
  @tvmgen_default_dnnl_main_0(%data, %weight1) /* ty=Tensor[(1, 32, 14, 14), float32] */
}

def @tvmgen_default_dnnl_main_0(%dnnl_0_i0: Tensor[(1, 32, 14, 14), float32] /* ty=Tensor[(1, 32, 14, 14), float32] */, %dnnl_0_i1: Tensor[(32, 1, 3, 3), float32] /* ty=Tensor[(32, 1, 3, 3), float32] */, Compiler="dnnl", Primitive=1, Inline=1, global_symbol="tvmgen_default_dnnl_main_0") -> Tensor[(1, 32, 14, 14), float32] {
  %0 = nn.conv2d(%dnnl_0_i0, %dnnl_0_i1, padding=[1, 1, 1, 1], groups=32, kernel_size=[3, 3]) /* ty=Tensor[(1, 32, 14, 14), float32] */;
  %1 = nn.conv2d(%0, %dnnl_0_i1, padding=[1, 1, 1, 1], groups=32, kernel_size=[3, 3]) /* ty=Tensor[(1, 32, 14, 14), float32] */;
  add(%0, %1) /* ty=Tensor[(1, 32, 14, 14), float32] */
}
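Note that actually running this partitioned module requires a TVM build with the DNNL BYOC codegen enabled (USE_DNNL_CODEGEN, or USE_DNNL in newer versions, in config.cmake); the example above only demonstrates annotation and partitioning.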