
from pathlib import Path
import numpy as np
import tvm
from tvm.relay.backend import te_compiler
from tvm.relay.backend.runtime import Runtime
import tvm.relay.testing
import tvm.relay.op as reg
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import byoc
from tvm.contrib import utils
from tvm.relay.expr_functor import ExprMutator
from tvm.relay.op.annotation import compiler_begin, compiler_end
from tvm.relay.op.contrib.register import get_pattern_table
from tvm.relay.build_module import bind_params_by_name

def set_func_attr(func, compile_name, symbol_name):
    """Attach the external-codegen attributes used by the expected modules.

    Parameters
    ----------
    func : relay.Function
        Function to annotate.
    compile_name : str
        Value stored under the ``Compiler`` attribute.
    symbol_name : str
        Value stored under the ``global_symbol`` attribute.

    Returns
    -------
    relay.Function
        ``func`` with ``Primitive``, ``Inline``, ``Compiler`` and
        ``global_symbol`` set, in that order.
    """
    attr_pairs = (
        ("Primitive", tvm.tir.IntImm("int32", 1)),
        ("Inline", tvm.tir.IntImm("int32", 1)),
        ("Compiler", compile_name),
        ("global_symbol", symbol_name),
    )
    annotated = func
    for key, value in attr_pairs:
        annotated = annotated.with_attr(key, value)
    return annotated

def update_lib(lib, source_dir="/media/pc/data/lxw/ai/tvm"):
    """Export ``lib`` to a temporary shared library and load it back.

    Parameters
    ----------
    lib : tvm.runtime.Module
        Module to export, e.g. the kernel library returned by
        ``Executable.save()``.
    source_dir : str or os.PathLike
        Root of a TVM source checkout. ``<source_dir>/src/runtime/contrib`` is
        added to the compiler include path. The default is a machine-specific
        path and will usually need to be overridden.

    Returns
    -------
    tvm.runtime.Module
        The module loaded from the newly exported ``lib.so``.
    """
    # NOTE(review): the option list was truncated in this copy; the contrib
    # include derived from ``source_dir`` is assumed to be what it held.
    contrib_path = Path(source_dir) / "src" / "runtime" / "contrib"
    kwargs = {
        "options": [
            "-O2",
            "-std=c++17",
            f"-I{contrib_path}",
        ]
    }
    tmp_path = utils.tempdir()
    lib_path = tmp_path.relpath("lib.so")
    lib.export_library(lib_path, fcompile=False, **kwargs)
    return tvm.runtime.load_module(lib_path)

class MobileNetAnnotator(ExprMutator):
    """Annotate mobilenet until global_avg_pool.

    The mutator rebuilds every call. Once an ``nn.global_avg_pool2d`` call has
    been visited, every ``Var`` argument met in calls visited from then on
    (the pool's inputs and everything beneath them) is wrapped in
    ``compiler_begin``. The pool's own inputs are wrapped in ``compiler_end``,
    so the annotated region stops right before the pooling op.
    """

    def __init__(self, compiler):
        super(MobileNetAnnotator, self).__init__()
        # Target name written into the compiler_begin/compiler_end annotations.
        self.compiler = compiler
        # Set once nn.global_avg_pool2d is visited; calls visited afterwards
        # (reached recursively through its arguments) are inside the region.
        self.compiler_open = False

    def visit_call(self, call):
        # NOTE(review): assumes every call targets a primitive Op; calls whose
        # op is a Function/GlobalVar have no ``.name`` and would raise here.
        is_pool = call.op.name == "nn.global_avg_pool2d"
        if is_pool:
            self.compiler_open = True
        compiler_open = self.compiler_open

        params = []
        for arg in call.args:
            param = super().visit(arg)
            if is_pool:
                # Values consumed by the pool leave the offloaded region.
                param = compiler_end(param, self.compiler)
            if compiler_open and isinstance(param, relay.expr.Var):
                # Free inputs feeding the region enter it here.
                param = compiler_begin(param, self.compiler)
            # Without this append the rebuilt call would have no arguments.
            params.append(param)

        return relay.Call(call.op, params, call.attrs)

def check_result(
    # NOTE(review): this copy has lost check_result's parameter list. The
    # nested helper below reads `mod`, `target`, `params`, `device`,
    # `map_inputs`, `result` and `tol` from it, and nothing visible calls
    # check_vm_result() -- restore the signature and call before running.
    def check_vm_result():
        """Run ``mod`` on the Relay VM and compare its outputs with ``result``.

        The kernel library produced by ``exe.save()`` is passed through
        ``update_lib`` before the executable is reloaded from it.
        """
        with tvm.transform.PassContext(opt_level=3):
            exe = relay.vm.compile(mod, target=target, params=params)
        # Split into bytecode and kernel library so the library can be rebuilt.
        code, lib = exe.save()
        lib = update_lib(lib)
        exe = tvm.runtime.vm.Executable.load_exec(code, lib)
        vm = tvm.runtime.vm.VirtualMachine(exe, device)
        outs = vm.run(**map_inputs)
        # Normalize both sides to sequences; a non-ADT result is presumably a
        # single tensor output.
        outs = outs if isinstance(outs, tvm.runtime.container.ADT) else [outs]
        results = result if isinstance(result, list) else [result]
        # NOTE(review): zip() silently stops at the shorter sequence, so a
        # mismatch in the number of outputs is not detected here.
        for out, ref in zip(outs, results):
            tvm.testing.assert_allclose(out.numpy(), ref, rtol=tol, atol=tol)


def create_graph():
    """Build a hand-annotated graph with two ``test_target`` regions.

    Region one computes ``abs`` and ``relu(abs)`` and exports both results.
    ``tanh`` of the exported ``abs`` runs outside any region. Region two adds
    the exported ``relu`` result to the ``tanh`` result.

    Returns
    -------
    relay.Function
        Function of a single ``data`` input of shape ``(10, 10)``.
    """
    tgt = "test_target"
    data = relay.var("data", shape=(10, 10))

    # Region 1: abs and relu(abs); both values leave the region.
    abs_val = relay.abs(compiler_begin(data, tgt))
    abs_exit = compiler_end(abs_val, tgt)
    relu_exit = compiler_end(relay.nn.relu(abs_val), tgt)

    # Host-side op between the two regions.
    tanh_val = relay.tanh(abs_exit)

    # Region 2: relu result + tanh result.
    summed = relay.add(compiler_begin(relu_exit, tgt), compiler_begin(tanh_val, tgt))
    return relay.Function([data], compiler_end(summed, tgt))

# Type-check and partition the hand-annotated graph from create_graph() at
# import time.
# NOTE(review): `partitioned` is never inspected in the visible code; this
# looks like exploratory scaffolding that belongs inside a test function.
mod = tvm.IRModule()
mod["main"] = create_graph()
mod = transform.InferType()(mod)

partitioned = transform.PartitionGraph()(mod)


def test_multiple_use_of_an_output():
    """Partition a graph in which one value (``x + y``) feeds several ops.

    Two scenarios share ``get_mod``:

    * offloading subtract/log/multiply yields one outlined function;
    * offloading only subtract/log yields two outlined functions whose
      results are multiplied on the host.
    """

    def expected_same_output_region():
        mod = tvm.IRModule()
        x = relay.var("x", shape=(8, 8))
        y = relay.var("y", shape=(8, 8))
        z = relay.var("z", shape=(8, 8))
        x0 = relay.var("x0", shape=(8, 8))
        y0 = relay.var("y0", shape=(8, 8))
        log = relay.log(x0)
        sub = x0 - y0
        mul = log * sub
        # The partitioned graph contains log, subtract, and multiply
        func = relay.Function([x0, y0], mul)
        func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_main_0")
        glb_0 = relay.GlobalVar("tvmgen_default_ccompiler_main_0")
        mod[glb_0] = func
        mod = transform.InferType()(mod)

        add = x + y
        call = relay.Call(glb_0, [add, z])
        main = relay.Function([x, y, z], call)
        mod["main"] = main
        mod = transform.InferType()(mod)
        return mod

    def expected_different_output_region():
        mod = tvm.IRModule()
        x = relay.var("x", shape=(8, 8))
        y = relay.var("y", shape=(8, 8))
        z = relay.var("z", shape=(8, 8))

        # The partitioned graph contains log
        i0 = relay.var("i0", shape=(8, 8))
        log = relay.log(i0)
        func = relay.Function([i0], log)
        func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_main_0")
        glb_0 = relay.GlobalVar("tvmgen_default_ccompiler_main_0")
        mod[glb_0] = func
        mod = transform.InferType()(mod)

        # The partitioned graph contains subtract
        x0 = relay.var("x0", shape=(8, 8))
        y0 = relay.var("y0", shape=(8, 8))
        sub = x0 - y0
        func = relay.Function([x0, y0], sub)
        func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_main_1")
        glb_1 = relay.GlobalVar("tvmgen_default_ccompiler_main_1")
        mod[glb_1] = func
        mod = transform.InferType()(mod)

        add = x + y
        call_log = relay.Call(glb_0, [add])
        call_sub = relay.Call(glb_1, [add, z])
        main = relay.Function([x, y, z], call_log * call_sub)
        mod["main"] = main
        mod = transform.InferType()(mod)
        return mod

    def get_mod():
        x = relay.var("x", shape=(8, 8))
        y = relay.var("y", shape=(8, 8))
        z = relay.var("z", shape=(8, 8))
        add = x + y
        sub = add - z
        log = relay.log(add)
        sub1 = log * sub
        f = relay.Function([x, y, z], sub1)
        mod = tvm.IRModule()
        mod["main"] = f
        return mod

    def test_same_output_region():
        mod = get_mod()
        mod = AllowedListAnnotator(["subtract", "log", "multiply"], "ccompiler")(mod)
        mod = transform.MergeCompilerRegions()(mod)
        mod = transform.PartitionGraph()(mod)

        expected_mod = expected_same_output_region()
        assert tvm.ir.structural_equal(mod, expected_mod, map_free_vars=True)

    def test_different_output_region():
        mod = get_mod()
        mod = AllowedListAnnotator(["subtract", "log"], "ccompiler")(mod)
        mod = transform.MergeCompilerRegions()(mod)
        mod = transform.PartitionGraph()(mod)

        expected_mod = expected_different_output_region()
        assert tvm.ir.structural_equal(mod, expected_mod, map_free_vars=True)

    # Without these calls the nested scenarios are defined but never run.
    test_same_output_region()
    test_different_output_region()

def test_duplicate_outputs():
    """An offloaded ``abs`` whose single result feeds three host ops.

    The expected module contains one outlined function with one output, and
    ``relu``, ``tanh`` and ``log`` in ``main`` all consume that same call.
    """
    target = "test_duplicate_outputs"

    @tvm.ir.register_op_attr("abs", "target." + target)
    def abs_supported(expr):  # pylint: disable=unused-variable
        return True

    def create_graph():
        data = relay.var("data", shape=(10, 10))
        x = relay.abs(data)
        out_1 = relay.nn.relu(x)
        out_2 = relay.tanh(x)
        out_3 = relay.log(x)
        out = relay.Tuple([out_1, out_2, out_3])
        func = relay.Function([data], out)
        return func

    def expected():
        mod = tvm.IRModule()

        # function 0
        f0_i0 = relay.var(target + "_0_i0", shape=(10, 10))
        f0_o0 = relay.abs(f0_i0)
        func0 = relay.Function([f0_i0], f0_o0)
        symbol = "tvmgen_default_" + target + "_main_0"
        func0 = set_func_attr(func0, target, symbol)
        gv0 = relay.GlobalVar(symbol)
        mod[gv0] = func0
        mod = transform.InferType()(mod)

        # body
        data = relay.var("data", shape=(10, 10))
        function_out = gv0(data)
        out_1 = relay.nn.relu(function_out)
        out_2 = relay.tanh(function_out)
        out_3 = relay.log(function_out)
        out = relay.Tuple([out_1, out_2, out_3])
        func = relay.Function([data], out)
        mod["main"] = func
        mod = transform.InferType()(mod)
        return mod

    mod = tvm.IRModule()
    mod["main"] = create_graph()

    # Annotate supported ops, merge adjacent regions, then outline them.
    seq = tvm.transform.Sequential(
        [
            transform.AnnotateTarget(target),
            transform.MergeCompilerRegions(),
            transform.PartitionGraph(),
        ]
    )

    ref_mod = expected()
    partitioned = seq(mod)
    assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)

def test_constant_tuples():
    """Constant arguments of an offloaded ``qnn.concatenate`` survive partitioning.

    In the outlined function, the scale and zero-point tuples must still be
    ``Tuple`` nodes, and the output scale/zero point must still be
    ``Constant`` nodes.
    """

    @tvm.ir.register_op_attr("qnn.concatenate", "target.const_tuples")
    def qnn_concatenate_supported(expr):  # pylint: disable=unused-variable
        return True

    def create_graph():
        a = relay.var("a", shape=(10, 10), dtype="uint8")
        b = relay.var("b", shape=(10, 10), dtype="uint8")
        a1 = relay.abs(a)

        # NOTE: despite its name, `zeroi` holds the value 1.
        zeroi = relay.const(1, "int32")
        zerof = relay.const(0, "float32")
        con = relay.qnn.op.concatenate(
            (a1, b),
            input_scales=(zerof, zerof),
            input_zero_points=(zeroi, zeroi),
            output_scale=zerof,
            output_zero_point=zeroi,
            axis=1,
        )

        f = relay.Function([a, b], con)
        mod = tvm.IRModule.from_expr(f)
        mod = transform.InferType()(mod)
        return mod

    seq = tvm.transform.Sequential(
        [
            transform.AnnotateTarget("const_tuples"),
            transform.InferType(),
            transform.MergeCompilerRegions(),
            transform.PartitionGraph(),
        ]
    )

    partitioned = seq(create_graph())

    concat = partitioned["tvmgen_default_const_tuples_main_0"].body
    assert isinstance(concat.args[1], relay.Tuple)
    assert isinstance(concat.args[2], relay.Tuple)
    assert isinstance(concat.args[3], relay.Constant)
    assert isinstance(concat.args[4], relay.Constant)

def test_flatten_tuple_output():
    """A tuple-producing ``split`` inside a region is flattened at its boundary.

    The expected outlined function returns a flat 3-tuple (both split halves
    plus the ``abs`` result). ``main`` rebuilds the tuple consumed by
    ``concatenate`` from ``TupleGetItem`` nodes.
    """
    target = "test_flatten_tuple_output"

    @tvm.ir.register_op_attr("split", "target." + target)
    def split_supported(expr):  # pylint: disable=unused-variable
        return True

    @tvm.ir.register_op_attr("abs", "target." + target)
    def abs_supported(expr):  # pylint: disable=unused-variable
        return True

    def create_graph():
        a = relay.var("a", shape=(10, 10), dtype="uint8")

        a_split = relay.split(a, 2)
        a_split_0 = relay.TupleGetItem(a_split.astuple(), 0)
        a_split_0_abs = relay.abs(a_split_0)

        a_con = relay.concatenate(a_split, 0)
        a_split_0_relu = relay.nn.relu(a_split_0_abs)

        out = relay.Tuple((a_con, a_split_0_relu))
        f = relay.Function([a], out)
        mod = tvm.IRModule.from_expr(f)
        mod = transform.InferType()(mod)
        return mod

    def expected():
        mod = tvm.IRModule()

        # function 0
        f0_i0 = relay.var(target + "_0_i0", shape=(10, 10), dtype="uint8")
        a_split = relay.split(f0_i0, 2)
        a_split_0 = relay.TupleGetItem(a_split.astuple(), 0)
        a_split_1 = relay.TupleGetItem(a_split.astuple(), 1)
        a_split_abs_in = relay.TupleGetItem(a_split.astuple(), 0)
        abs_out = relay.abs(a_split_abs_in)
        tuple_out = relay.Tuple((a_split_0, a_split_1, abs_out))
        func0 = relay.Function([f0_i0], tuple_out)
        symbol = "tvmgen_default_" + target + "_main_0"
        func0 = set_func_attr(func0, target, symbol)
        gv0 = relay.GlobalVar(symbol)
        mod[gv0] = func0
        mod = transform.InferType()(mod)

        # body
        data = relay.var("a", shape=(10, 10), dtype="uint8")
        f_out = gv0(data)
        f_out_0 = relay.TupleGetItem(f_out, 0)
        f_out_1 = relay.TupleGetItem(f_out, 1)
        concat_in = relay.Tuple((f_out_0, f_out_1))
        concat = relay.concatenate(concat_in, 0)
        f_out_2 = relay.TupleGetItem(f_out, 2)
        relu = relay.nn.relu(f_out_2)
        ret_tuple = relay.Tuple((concat, relu))
        mod["main"] = relay.Function([data], ret_tuple)
        mod = transform.InferType()(mod)
        return mod

    # Annotate supported ops, merge adjacent regions, then outline them.
    seq = tvm.transform.Sequential(
        [
            transform.AnnotateTarget(target),
            transform.MergeCompilerRegions(),
            transform.PartitionGraph(),
        ]
    )

    partitioned = seq(create_graph())
    partitioned = transform.InferType()(partitioned)
    expected_mod = transform.InferType()(expected())
    assert tvm.ir.structural_equal(partitioned, expected_mod, map_free_vars=True)

def test_tuple_output_exec():
    """Test C codegen and runtime for a subgraph with a tuple output"""
    a = relay.var("a", shape=(10, 10), dtype="float32")
    b = relay.var("b", shape=(10, 10), dtype="float32")
    ba = relay.annotation.compiler_begin(a, "ccompiler")
    bb = relay.annotation.compiler_begin(b, "ccompiler")
    add = relay.add(ba, bb)
    sub = relay.subtract(ba, bb)
    # One region with two outputs, closed by a single compiler_end on the tuple.
    out = relay.Tuple((add, sub))
    eout = relay.annotation.compiler_end(out, "ccompiler")
    func = relay.Function([a, b], eout)

    mod = tvm.IRModule()
    mod["main"] = func
    mod = transform.InferType()(mod)
    mod = transform.PartitionGraph()(mod)

    a_data = np.random.rand(10, 10).astype("float32")
    b_data = np.random.rand(10, 10).astype("float32")

    # Arguments: module, named inputs, output shapes, reference outputs.
    check_result(
        mod,
        {"a": a_data, "b": b_data},
        [(10, 10), (10, 10)],
        [(a_data + b_data), (a_data - b_data)],
    )

def test_extern_opt():
    """An external target's registered optimizer runs on its partitioned function.

    ``relay.ext.test_target.optimize`` is registered as ``FoldConstant``.
    With ``y0`` and ``y1`` bound to ones, the offloaded ``y0 + y1`` is expected
    to fold into a single constant filled with 2.
    """

    def Optimize(mod):
        return relay.transform.FoldConstant()(mod)

    tvm.register_func("relay.ext.test_target.optimize", Optimize)

    x = relay.var("x", shape=(2, 2))
    y0 = relay.var("y0", shape=(2, 2))
    y1 = relay.var("y1", shape=(2, 2))
    yy0 = relay.annotation.compiler_begin(y0, "test_target")
    yy1 = relay.annotation.compiler_begin(y1, "test_target")
    z = yy0 + yy1
    end = relay.annotation.compiler_end(z, "test_target")
    f = relay.Function([x, y0, y1], end * x)
    c = np.ones(shape=(2, 2), dtype="float32")
    f = bind_params_by_name(f, {"y0": tvm.nd.array(c), "y1": tvm.nd.array(c)})
    mod = tvm.IRModule()
    mod["main"] = f
    mod = transform.InferType()(mod)
    mod = transform.PartitionGraph()(mod)

    try:
        t0 = mod["tvmgen_default_test_target_main_0"]
    except Exception as err:  # lookup failure is not guaranteed to be a KeyError
        raise KeyError("test_target_main_0 not found") from err

    assert isinstance(t0.body, relay.Constant)
    # ones + ones: np.empty() here would compare against uninitialized memory.
    expected = np.full((2, 2), 2.0, dtype="float32")
    tvm.testing.assert_allclose(t0.body.data.numpy(), expected, rtol=1e-5, atol=1e-5)

def test_preserve_type_import():
    """Test to make sure type definition and imports are preserved during the BYOC pipeline."""
    from tvm.relay.prelude import Prelude, StaticTensorArrayOps

    def run(dtype, shape):
        mod = tvm.IRModule()
        p = Prelude(mod)
        static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
        # The static tensor-array functions must be added to the module before
        # get_global_var_static() can look them up below.
        static_tensor_array_ops.register()

        tensor_array = p.get_global_var_static("tensor_array", dtype, shape)
        tensor = p.get_tensor_ctor_static("tensor_constructor", dtype, shape)
        write = p.get_global_var_static("tensor_array_write", dtype, shape)
        gather = p.get_global_var_static("tensor_array_gather", dtype, shape)
        v = relay.var("v")
        indice = relay.var("indice")
        init_tensor_array = tensor_array(relay.const(3))
        tensor_array1 = write(init_tensor_array, relay.const(0), tensor(v))
        tensor_array2 = write(tensor_array1, relay.const(1), tensor(v))
        tensor_array3 = write(tensor_array2, relay.const(2), tensor(v))
        out = gather(tensor_array3, indice)
        mod["main"] = relay.Function([v, indice], out)
        mod = transform.RemoveUnusedFunctions()(mod)
        # Passes if partitioning completes on a module carrying ADT definitions.
        mod = transform.PartitionGraph()(mod)

    run("float32", [2, 3])

def test_not_bind_constant():
    """``PartitionGraph(bind_constants=...)`` controls where bound weights live.

    With ``bind_constants=True``, ``main`` is expected to call the DNNL
    function with a single argument. With ``False``, the constants stay in
    ``main`` and are passed as extra arguments (three in total).
    """

    def get_net(prefix, data, out_channel):
        weight = relay.var(prefix + "weight")
        bn_gamma = relay.var(prefix + "bn_gamma")
        bn_beta = relay.var(prefix + "bn_beta")
        bn_mmean = relay.var(prefix + "bn_mean")
        bn_mvar = relay.var(prefix + "bn_var")

        layer = relay.nn.conv2d(
            data=data, weight=weight, kernel_size=(3, 3), channels=out_channel, padding=(1, 1)
        )
        bn_output = relay.nn.batch_norm(layer, bn_gamma, bn_beta, bn_mmean, bn_mvar)
        out = relay.nn.relu(bn_output[0])
        return relay.Function(relay.analysis.free_vars(out), out)

    def get_partitoned_mod(mod, params, pattern_table, bind_constants):
        mod["main"] = bind_params_by_name(mod["main"], params)
        # Lower batch_norm and fold it into the conv so DNNL patterns can match.
        remove_bn_pass = tvm.transform.Sequential(
            [
                transform.InferType(),
                transform.SimplifyInference(),
                transform.FoldConstant(),
                transform.FoldScaleAxis(),
            ]
        )
        composite_partition = tvm.transform.Sequential(
            [
                remove_bn_pass,
                transform.MergeComposite(pattern_table),
                transform.AnnotateTarget("dnnl"),
                transform.PartitionGraph(bind_constants=bind_constants),
            ]
        )

        with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
            return composite_partition(mod)

    def build_workload():
        data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
        net = get_net("block_", data, 8)
        return tvm.relay.testing.create_workload(net)

    # Use a fresh module for each setting: re-partitioning an already
    # partitioned module would not exercise bind_constants=False.
    mod, params = build_workload()
    mod = get_partitoned_mod(mod, params, get_pattern_table("dnnl"), bind_constants=True)
    assert len(mod["main"].body.args) == 1

    mod, params = build_workload()
    mod = get_partitoned_mod(mod, params, get_pattern_table("dnnl"), bind_constants=False)
    assert len(mod["main"].body.args) == 3