# 函数 lifting

In [1]:
import set_env

/media/pc/data/lxw/ai/tvm


In [2]:
import numpy as np
import tvm
from tvm import relay
from tvm.relay import ExprMutator
from tvm.relay.op.annotation import compiler_begin, compiler_end
from tvm.relay.backend.runtime import Runtime
from tvm.relay.backend import te_compiler
from tvm.contrib.utils import tempdir

def update_lib(lib, source_dir="/media/pc/data/lxw/ai/tvm"):
    """Export ``lib`` as a shared library in a temporary directory and reload it.

    Parameters
    ----------
    lib :
        Object exposing ``export_library(path, fcompile=..., options=...)``.
    source_dir : str
        Root of the TVM source tree; its ``src/runtime/contrib``, ``include``,
        ``3rdparty/dlpack/include`` and ``3rdparty/dmlc-core/include``
        sub-directories are passed to the compiler as ``-I`` include paths.

    Returns
    -------
    The module returned by ``tvm.runtime.load_module`` for the exported file.
    """
    # Compiler flags first, then one -I entry per required header directory.
    compile_flags = ["-O2", "-std=c++17"]
    compile_flags += [
        f"-I{source_dir}/src/runtime/contrib",
        f"-I{source_dir}/include",
        f"-I{source_dir}/3rdparty/dlpack/include",
        f"-I{source_dir}/3rdparty/dmlc-core/include",
    ]

    scratch_dir = tempdir()
    shared_object = scratch_dir.relpath("lib.so")
    lib.export_library(shared_object, fcompile=False, options=compile_flags)
    return tvm.runtime.load_module(shared_object)

def check_result(
    mod,
    map_inputs,
    out_shape,
    result,
    tol=1e-5,
    target="llvm",
    device=tvm.cpu(),
    params=None,
    runtime=Runtime("cpp"),
):
    """Compile ``mod`` with the Relay VM, run it and compare against ``result``.

    Parameters
    ----------
    mod :
        Module handed to ``relay.vm.compile``.
    map_inputs : dict
        Keyword arguments forwarded to ``VirtualMachine.run``.
    out_shape :
        Not used by this helper; kept so existing callers keep working.
    result :
        Reference output, or a list of reference outputs (one per VM output).
    tol : float
        Used as both ``rtol`` and ``atol`` for ``np.testing.assert_allclose``.
    target :
        Compilation target passed to ``relay.vm.compile``.
    device :
        Device on which the virtual machine is created.
    params :
        Parameters passed to ``relay.vm.compile``.
    runtime :
        Not used by this helper; kept so existing callers keep working.

    Raises
    ------
    AssertionError
        If the number of VM outputs differs from the number of references, or
        if any output is not close to its reference.
    """

    def check_vm_result():
        # Reset the object returned by te_compiler.get() before compiling.
        te_compiler.get().clear()
        with tvm.transform.PassContext(opt_level=3):
            exe = relay.vm.compile(mod, target=target, params=params)
        # Round-trip the executable through a shared library on disk.
        code, lib = exe.save()
        lib = update_lib(lib)
        exe = tvm.runtime.vm.Executable.load_exec(code, lib)
        vm = tvm.runtime.vm.VirtualMachine(exe, device)
        outs = vm.run(**map_inputs)
        # Normalise both sides to sequences so single and multi output cases
        # share one comparison loop.
        outs = outs if isinstance(outs, tvm.runtime.container.ADT) else [outs]
        results = result if isinstance(result, list) else [result]
        # zip() would silently drop unmatched entries, letting a missing or
        # extra output go unnoticed; fail loudly instead.
        if len(outs) != len(results):
            raise AssertionError(
                f"expected {len(results)} output(s) but the VM produced {len(outs)}"
            )
        for out, ref in zip(outs, results):
            np.testing.assert_allclose(out.numpy(), ref, rtol=tol, atol=tol)

    check_vm_result()


In [3]:
# Graph inputs: a (1, 3, 224, 224) float32 tensor, a (16, 3, 3, 3) float32
# weight and four float32 vectors of length 16 used as batch_norm arguments.
data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
weight = relay.var("weight", relay.TensorType((16, 3, 3, 3), "float32"))
bn_gamma = relay.var("bn_gamma", relay.TensorType((16,), "float32"))
bn_beta = relay.var("bn_beta", relay.TensorType((16,), "float32"))
# Note: the Python names bn_mmean / bn_mvar hold Relay vars named
# "bn_mean" / "bn_var".
bn_mmean = relay.var("bn_mean", relay.TensorType((16,), "float32"))
bn_mvar = relay.var("bn_var", relay.TensorType((16,), "float32"))

# 3x3 convolution with 16 output channels and padding 1, followed by batch_norm.
conv = relay.nn.conv2d(
    data=data, weight=weight, kernel_size=(3, 3), channels=16, padding=(1, 1)
)
bn_output = relay.nn.batch_norm(conv, bn_gamma, bn_beta, bn_mmean, bn_mvar)

# The function body is the batch_norm result converted with astuple().
func = relay.Function(
    [data, weight, bn_gamma, bn_beta, bn_mmean, bn_mvar], bn_output.astuple()
)
# Wrap the function as "main" of a fresh module, infer types and print it.
mod = tvm.IRModule()
mod["main"] = func
mod = relay.transform.InferType()(mod)
mod.show()

In [4]:
# Use the pass manager to write a simple allow-list annotator.
@relay.transform.function_pass(opt_level=0)
class AllowedListAnnotator:
    """Function pass that annotates calls to allow-listed operators.

    Every call whose operator name is in ``op_list`` gets each argument
    wrapped in ``compiler_begin`` and its result wrapped in ``compiler_end``,
    both tagged with ``compiler``. All other expressions are rebuilt by the
    default ``ExprMutator`` traversal.

    Parameters
    ----------
    op_list : list, tuple or set of str
        Operator names (e.g. ``"nn.conv2d"``) to annotate.
    compiler : str
        Tag passed to ``compiler_begin`` / ``compiler_end``.
    """

    def __init__(self, op_list, compiler):
        assert isinstance(op_list, (list, tuple, set))
        self.op_list = op_list
        self.compiler = compiler

    def transform_function(self, func, mod, dev):
        """Return ``func`` with allow-listed operator calls annotated.

        ``mod`` and ``dev`` are accepted for the function-pass interface and
        are not used here.
        """
        # The nested mutator has its own ``self``; keep a handle on the pass.
        annotator = self

        class Annotator(tvm.relay.ExprMutator):
            def visit_call(self, call):
                callee = call.op
                # Only primitive operators carry a ``name``; calls to
                # functions or variables fall through to the default
                # handling instead of raising AttributeError.
                if isinstance(callee, tvm.ir.Op) and callee.name in annotator.op_list:
                    new_args = []
                    for arg in call.args:
                        # Visit the argument first so nested allow-listed
                        # calls are annotated as well.
                        ann = compiler_begin(super().visit(arg), annotator.compiler)
                        new_args.append(ann)
                    new_call = relay.Call(callee, new_args, call.attrs, call.type_args)
                    return compiler_end(new_call, annotator.compiler)
                return super().visit_call(call)

        return Annotator().visit(func)


In [5]:
# Annotate batch_norm and conv2d calls with the compiler tag "test_compiler".
# NOTE: `mod` is rebound here and again below, so re-running this cell feeds
# the already transformed module back through the annotator.
op_list = ["nn.batch_norm", "nn.conv2d"]
mod = AllowedListAnnotator(op_list, "test_compiler")(mod)

# Passes are applied in the listed order: type inference, graph partitioning,
# inference simplification, constant folding, then layout alteration.
opt_pass = tvm.transform.Sequential(
    [
        relay.transform.InferType(),
        relay.transform.PartitionGraph(),
        relay.transform.SimplifyInference(),
        relay.transform.FoldConstant(),
        relay.transform.AlterOpLayout(),
    ]
)

# Run the sequence at opt_level=3 and print the resulting module.
with tvm.transform.PassContext(opt_level=3):
    mod = opt_pass(mod)
mod.show()

In [6]:
# Graph inputs for a batch_norm-only function: a (1, 16, 224, 224) float32
# tensor and four float32 vectors of length 16.
data = relay.var("data", relay.TensorType((1, 16, 224, 224), "float32"))
bn_gamma = relay.var("bn_gamma", relay.TensorType((16,), "float32"))
bn_beta = relay.var("bn_beta", relay.TensorType((16,), "float32"))
# Note: the Python names bn_mmean / bn_mvar hold Relay vars named
# "bn_mean" / "bn_var".
bn_mmean = relay.var("bn_mean", relay.TensorType((16,), "float32"))
bn_mvar = relay.var("bn_var", relay.TensorType((16,), "float32"))

# batch_norm is applied directly to `data`; there is no convolution here.
bn_output = relay.nn.batch_norm(data, bn_gamma, bn_beta, bn_mmean, bn_mvar)

func = relay.Function([data, bn_gamma, bn_beta, bn_mmean, bn_mvar], bn_output.astuple())
# Build a fresh module and annotate it with the compiler tag "test_compiler".
# "nn.conv2d" stays in the list although this graph contains no conv2d call.
mod = tvm.IRModule()
mod["main"] = func
op_list = ["nn.batch_norm", "nn.conv2d"]
mod = AllowedListAnnotator(op_list, "test_compiler")(mod)

# Passes run in the listed order; Inline() is the final step.
opt_pass = tvm.transform.Sequential(
    [
        relay.transform.InferType(),
        relay.transform.PartitionGraph(),
        relay.transform.SimplifyInference(),
        relay.transform.FoldConstant(),
        relay.transform.AlterOpLayout(),
        relay.transform.Inline(),
    ]
)

# Run the sequence at opt_level=3 and print the resulting module.
with tvm.transform.PassContext(opt_level=3):
    mod = opt_pass(mod)

mod.show()

## 注解常量折叠

In [8]:
from tvm.relay.build_module import bind_params_by_name

In [10]:
# Build log(x + y) on (8, 8) inputs and bind x to a constant array of ones,
# leaving y as the only runtime input.
ones = np.ones(shape=(8, 8), dtype="float32")
x = relay.var("x", shape=(8, 8))
y = relay.var("y", shape=(8, 8))
add = x + y
log = relay.log(add)
f = relay.Function([x, y], log)
f = bind_params_by_name(f, {"x": tvm.nd.array(ones)})
mod = tvm.IRModule()
mod["main"] = f
# Annotate only the "add" operator with the compiler tag "ccompiler", then
# partition, infer types and print the module.
mod = AllowedListAnnotator(["add"], "ccompiler")(mod)
mod = relay.transform.PartitionGraph()(mod)
mod = relay.transform.InferType()(mod)
mod.show()
# Seeded generator so the checked input is identical on every run of the
# notebook (the unseeded np.random.rand produced different data each time).
y_data = np.random.default_rng(0).random((8, 8)).astype("float32")
np_add = ones + y_data
# Compare the compiled module's output against the NumPy reference.
check_result(mod, {"y": y_data}, (8, 8), np.log(np_add))

## 多输出

In [11]:
def create_graph():
    """Build an annotated conv2d -> batch_norm -> relu Relay function.

    All six inputs are wrapped in ``compiler_begin`` for ``"test_target"``.
    The region has three ``compiler_end`` exits: the relu of batch_norm
    output 0, and batch_norm outputs 1 and 2. Outside the region the latter
    two go through ``relay.abs``.

    Returns
    -------
    relay.Function
        Function over (data, weight, bn_gamma, bn_beta, bn_mean, bn_var)
        whose body is the 3-tuple described above.
    """
    target = "test_target"
    dtype = "float32"

    # (name, shape) of every function parameter, in signature order.
    input_specs = (
        ("data", (1, 3, 224, 224)),
        ("weight", (16, 3, 3, 3)),
        ("bn_gamma", (16,)),
        ("bn_beta", (16,)),
        ("bn_mean", (16,)),
        ("bn_var", (16,)),
    )
    params = [relay.var(name, relay.TensorType(shape, dtype)) for name, shape in input_specs]

    # Region entry points: one compiler_begin per parameter.
    x, kernel, gamma, beta, mean, var = [compiler_begin(p, target) for p in params]

    conv = relay.nn.conv2d(
        data=x, weight=kernel, kernel_size=(3, 3), channels=16, padding=(1, 1)
    )
    bn = relay.nn.batch_norm(conv, gamma, beta, mean, var)

    # Region exit points: relu of output 0, plus outputs 1 and 2 as-is.
    relu_exit = compiler_end(relay.nn.relu(bn[0]), target)
    mean_exit = compiler_end(bn[1], target)
    var_exit = compiler_end(bn[2], target)

    body = relay.Tuple((relu_exit, relay.abs(mean_exit), relay.abs(var_exit)))
    return relay.Function(params, body)


In [13]:
# Install the hand-annotated multi-output graph as "main" of a fresh module,
# run PartitionGraph on it and print the partitioned module.
mod = tvm.IRModule()
mod["main"] = create_graph()
partitioned = relay.transform.PartitionGraph()(mod)
partitioned.show()