#!/usr/bin/env python
import copy
from collections import OrderedDict

from hailo_sdk_client.exposed_definitions import JoinAction, JoinOutputLayersOrder, States
from hailo_sdk_client.runner.exceptions import UnsupportedRunnerJoinException
from hailo_sdk_client.sdk_backend.script_parser.commands import SupportedCommands
from hailo_sdk_client.sdk_backend.script_parser.model_script_parser import ModelScriptParser
from hailo_sdk_common.hailo_nn.hailo_nn import HailoNN, NetParams
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.hailo_nn.hn_layers import Layer
from hailo_sdk_common.logger.logger import default_logger
from hailo_sdk_common.model_params.model_params import ModelParams


class HNMergeError(Exception):
    """Raised when two HailoNN graphs cannot be merged, e.g. because of an invalid join specification."""


class ModelScriptMergeError(Exception):
    """Raised when the model scripts of the networks being joined contain conflicting global commands."""


def get_new_layer_name(layer, scope_names_mapping):
    """Return ``layer``'s full name with its scope replaced through ``scope_names_mapping``.

    Args:
        layer: Either a ``Layer`` (its ``scope`` / ``name_without_scope`` are used) or a
            ``"<scope>/<name>"`` string, which is split on the first ``/``.
        scope_names_mapping: Mapping from the current scope name to the new scope name.

    Returns:
        str: ``"<new scope>/<name without scope>"``.
    """
    if not isinstance(layer, Layer):
        scope, short_name = layer.split("/", 1)
    else:
        scope, short_name = layer.scope, layer.name_without_scope
    new_scope = scope_names_mapping[scope]
    return f"{new_scope}/{short_name}"


def merge_hns(hn1, hn2, scope_names1, scope_names2, join_action=JoinAction.NONE, join_action_info=None):
    """Merge two HailoNN graphs into a single new HailoNN.

    Every layer of ``hn1`` / ``hn2`` is deep-copied into the new graph with its scope renamed through
    ``scope_names1`` / ``scope_names2`` (see ``get_new_layer_name``). Depending on ``join_action`` and
    ``join_action_info`` the two graphs are either placed side by side, share input layers
    (``join_inputs``), or are chained so that outputs of ``hn1`` feed inputs of ``hn2``
    (``chain_networks``).

    Args:
        hn1: First network.
        hn2: Second network.
        scope_names1: Mapping from ``hn1`` scope names to the scope names used in the merged graph.
        scope_names2: Same mapping for ``hn2``.
        join_action: How to join the networks; validated by ``handle_join_action_info``.
        join_action_info: Optional dict mapping re-scoped ``hn1`` input/output layer names to re-scoped
            ``hn2`` input layer names (accepted only with ``JoinAction.CUSTOM``). It may also carry an
            ``"output_layers_order"`` entry.

    Returns:
        HailoNN: The merged network, named ``joined_<new scopes of hn1>_<new scopes of hn2>``.

    Raises:
        HNMergeError: Propagated from the join-action validation helpers.
    """
    hn1_inputs = hn1.get_input_layers()
    hn2_inputs = hn2.get_input_layers()
    hn1_outputs = hn1.get_output_layers()

    # Turn join_action / join_action_info into an explicit mapping (re-scoped hn1 layer name ->
    # re-scoped hn2 input name) plus flags describing which kind of join it is.
    join_action_info, join_inputs, chain_networks, output_layers_order = handle_join_action_info(
        join_action,
        join_action_info,
        scope_names1,
        scope_names2,
        hn1_inputs,
        hn2_inputs,
        hn1_outputs,
    )
    if join_action_info:
        # Reverse lookup: re-scoped hn2 input name -> re-scoped hn1 layer joined to it.
        # NOTE: only bound when a mapping exists. Every use below sits behind join_inputs/chain_networks,
        # which can only be True in that case.
        join_action_info_reverse = {v: k for k, v in join_action_info.items()}

    if join_action_info and not chain_networks and not join_inputs:
        # Custom mapping: derive the join kind from whether the hn1 keys are input or output layers.
        join_inputs, chain_networks = validate_join_action_info(join_action_info, hn1, hn2, scope_names1, scope_names2)

    new_hn_name = "joined_" + "_".join(scope_names1.values()) + "_" + "_".join(scope_names2.values())
    net_params = {
        "version": hn1.net_params.version.value,
        "stage": hn1.net_params.stage.value,
        "net_scopes": list(scope_names1.values()) + list(scope_names2.values()),
    }
    new_hailo_nn = HailoNN(network_name=new_hn_name, stage=hn1.net_params.stage.value)

    # Copy hn1 layers under their new scope. When chaining, hn1 output layers that are keys of
    # join_action_info are not copied; handle_network_chain later connects their producers to hn2.
    for layer in hn1.nodes:
        layer_name = get_new_layer_name(layer, scope_names1)
        if chain_networks and layer.op == LayerType.output_layer and layer_name in join_action_info:
            continue
        new_layer = copy.deepcopy(layer)
        new_layer.name = get_new_layer_name(layer, scope_names1)
        new_layer.inputs = [get_new_layer_name(inp, scope_names1) for inp in layer.inputs]
        new_layer.outputs = [get_new_layer_name(out, scope_names1) for out in layer.outputs]
        new_hailo_nn.add_node(new_layer)
    # Copy hn2 layers under their new scope, redirecting joined inputs.
    for layer in hn2.nodes:
        new_layer = copy.deepcopy(layer)
        new_layer.name = get_new_layer_name(layer, scope_names2)
        new_layer.outputs = [get_new_layer_name(out, scope_names2) for out in layer.outputs]
        if join_inputs:
            new_layer.inputs = []
            is_connected_to_hn1_output = False
            # hn1 layer this hn2 layer is joined to (only hn2 input layers can appear as values).
            hn1_output_layer_name = join_action_info_reverse.get(new_layer.name)
            if hn1_output_layer_name is not None:
                out_scope, out_name = hn1_output_layer_name.split("/", 1)
                # Map the new hn1 scope back to its original name to look the layer up in hn1.
                original_scope = next(key for key in scope_names1 if scope_names1[key] == out_scope)
                is_connected_to_hn1_output = (
                    hn1.get_layer_by_name(f"{original_scope}/{out_name}").op == LayerType.output_layer
                    if hn1_output_layer_name
                    else False
                )  # connecting output layer to input layer (chain) should be skipped when joining inputs
            for inp in layer.inputs:
                inp = get_new_layer_name(inp, scope_names2)
                inp_from_join_action = join_action_info_reverse.get(inp, inp)
                out_scope, out_name = inp_from_join_action.split("/", 1)
                if out_scope in scope_names1.values():
                    original_scope = next(key for key in scope_names1 if scope_names1[key] == out_scope)
                    hn1_layer_name = f"{original_scope}/{out_name}"
                else:
                    hn1_layer_name = inp_from_join_action
                # If the joined hn1 layer is an output layer, keep the hn2 input name for now; it is
                # rewritten by handle_network_chain. Otherwise redirect to the joined hn1 input layer.
                is_inp_connected_to_hn1_output = (
                    hn1.get_layer_by_name(hn1_layer_name).op == LayerType.output_layer
                    if inp_from_join_action != inp
                    else False
                )
                new_layer.inputs.append(inp if is_inp_connected_to_hn1_output else inp_from_join_action)
            if (
                layer.op == LayerType.input_layer
                and new_layer.name in join_action_info_reverse
                and not is_connected_to_hn1_output
            ):
                # hn2 input joined to an hn1 input: do not copy it. Instead, append its successors and
                # output shapes to the already-copied hn1 input layer.
                inp_layer = new_hailo_nn.get_layer_by_name(join_action_info_reverse[new_layer.name])
                inp_layer.outputs = inp_layer.outputs + new_layer.outputs
                inp_layer.output_shapes = inp_layer.output_shapes + layer.output_shapes
            elif not is_connected_to_hn1_output:  # add new layer only if it is not connected to hn1 output layer
                new_hailo_nn.add_node(new_layer)
        else:
            new_layer.inputs = [get_new_layer_name(inp, scope_names2) for inp in layer.inputs]
            # When chaining, hn2 input layers joined to hn1 outputs are not copied.
            if not (
                chain_networks and layer.op == LayerType.input_layer and new_layer.name in join_action_info_reverse
            ):
                new_hailo_nn.add_node(new_layer)

    new_out_layers_order = get_new_out_layers_order(
        hn1,
        hn2,
        scope_names1,
        scope_names2,
        output_layers_order,
        chain_networks,
        join_action_info,
    )
    net_params["output_layers_order"] = new_out_layers_order
    new_hailo_nn.net_params = NetParams(net_params)

    if chain_networks:
        # Point hn2 consumers of the dropped hn2 input layers at hn1's real producers.
        handle_network_chain(join_action_info_reverse, hn1, scope_names1, new_hailo_nn)

    # Create graph edges from each layer's ``inputs`` name list.
    for layer in new_hailo_nn.nodes:
        input_list = []
        for inp in layer.inputs:
            inp_layer = new_hailo_nn.get_layer_by_name(inp)
            new_hailo_nn.add_edge(inp_layer, layer)
            input_list.append(inp_layer)
        if hasattr(layer, "input_list"):
            layer.input_list = input_list
        # conv / dw / normalization with two inputs: register the second input as the
        # elementwise connection.
        if len(input_list) == 2 and layer.op in [LayerType.conv, LayerType.dw, LayerType.normalization]:
            layer.clear_ew_connections()
            layer.add_ew_connection(input_list[1])

    update_io_indices(new_hailo_nn)
    return new_hailo_nn


def get_new_out_layers_order(hn1, hn2, scopes1, scopes2, layers_order, chain_networks, join_info):
    """Compute ``output_layers_order`` for the merged network.

    Returns an empty list unless both networks define ``output_layers_order``.

    When chaining, an entry of hn1's order is dropped if any of its successors (re-scoped with
    ``scopes1``) is a key of ``join_info``, because it no longer produces a network output. The
    remaining hn1 names and all hn2 names are re-scoped and combined according to ``layers_order``:

    * ``NEW_OUTPUTS_FIRST``: hn2 outputs, then hn1 outputs.
    * ``NEW_OUTPUTS_IN_PLACE``: hn2 outputs inserted at the position of the first dropped hn1 entry
      (appended after all hn1 outputs when nothing was dropped).
    * ``NEW_OUTPUTS_LAST``: hn1 outputs, then hn2 outputs.

    Any other ``layers_order`` value yields an empty list.
    """
    hn1_order = hn1.net_params.output_layers_order
    hn2_order = hn2.net_params.output_layers_order
    if not (hn1_order and hn2_order):
        return []

    kept_hn1 = hn1_order.copy()
    insert_at = len(hn1_order)
    if chain_networks:
        for position, name in enumerate(hn1_order):
            successors = hn1.get_layer_by_name(name).outputs
            if not any(get_new_layer_name(succ, scopes1) in join_info for succ in successors):
                continue
            # Only the first dropped position matters for NEW_OUTPUTS_IN_PLACE.
            insert_at = min(insert_at, position)
            kept_hn1.remove(name)

    hn1_names = [get_new_layer_name(name, scopes1) for name in kept_hn1]
    hn2_names = [get_new_layer_name(name, scopes2) for name in hn2_order]

    if layers_order == JoinOutputLayersOrder.NEW_OUTPUTS_FIRST:
        default_logger().info(
            "Order outputs by joined network outputs are first and original network outputs are second",
        )
        return hn2_names + hn1_names

    if layers_order == JoinOutputLayersOrder.NEW_OUTPUTS_IN_PLACE:
        default_logger().info(
            "Order outputs by original network outputs, then all joined network "
            "outputs where the first original output is missing and rest of "
            "original network outputs later",
        )
        return hn1_names[:insert_at] + hn2_names + hn1_names[insert_at:]

    if layers_order == JoinOutputLayersOrder.NEW_OUTPUTS_LAST:
        default_logger().info(
            "Order outputs by original network outputs are first and joined network outputs are second",
        )
        return hn1_names + hn2_names

    return []


def merge_params(npz1, npz2, params_kind, scope_names1, scope_names2):
    """Combine two parameter mappings into one, re-scoping every parameter key.

    The result starts with a ``"params_kind"`` entry holding the encoded ``params_kind``. It is
    followed by every entry of ``npz1`` and then ``npz2``, with each ``"<scope>/<rest>"`` key
    re-scoped via ``scope_names1`` / ``scope_names2``. The inputs' own ``"params_kind"`` entries
    are skipped. A later key that re-scopes to an existing name overwrites it.

    Returns:
        OrderedDict: The merged parameters.
    """
    merged = OrderedDict()
    merged["params_kind"] = [ModelParams.PARAMS_KIND_TO_VALUE_DICT[params_kind]]
    for source, scope_mapping in ((npz1, scope_names1), (npz2, scope_names2)):
        for param_key, param_value in source.items():
            if param_key == "params_kind":
                continue
            merged[get_new_layer_name(param_key, scope_mapping)] = param_value
    return merged


def handle_model_script_join(
    hn1,
    script1,
    scope_names1,
    hn2,
    script2,
    scope_names2,
    state,
    join_action=JoinAction.NONE,
):
    """Prepare the model scripts of two networks for joining and reject conflicting global commands.

    Both scripts are parsed and re-scoped via ``prepare_script_to_join``. Their global commands are
    then compared:

    * model-modification commands are checked by ``validate_global_model_modification_commands``;
    * strategy / pre- / post-quantization optimization commands present in both scripts must have
      identical text;
    * any other global command present in both scripts without conflict is removed from script 2
      (first match only), so only script 1 keeps it;
    * a global command whose function name appears in only one of the scripts is a conflict.

    Returns:
        tuple[str, str]: The two modified model scripts.

    Raises:
        ModelScriptMergeError: Listing every conflict found.
    """
    globals1, ignored1, parser1 = prepare_script_to_join(hn1, scope_names1, script1, state)
    globals2, ignored2, parser2 = prepare_script_to_join(hn2, scope_names2, script2, state)

    conflicts = []
    validate_global_model_modification_commands(globals1, globals2, join_action, conflicts, parser2)

    must_be_identical = [
        SupportedCommands.STRATEGY,
        SupportedCommands.PRE_QUANTIZATION_OPTIMIZATION,
        SupportedCommands.POST_QUANTIZATION_OPTIMIZATION,
    ]
    for cmd1 in globals1:
        counterpart = next((cmd2 for cmd2 in globals2 if cmd2.function_name == cmd1.function_name), None)
        if counterpart is None:
            continue
        if cmd1.function_name in must_be_identical and str(cmd1) != str(counterpart):
            conflicts.append((cmd1, counterpart))
        elif counterpart in parser2.commands:
            # Duplicate global command: keep only script 1's copy.
            parser2.commands.remove(counterpart)

    names1 = [cmd.function_name for cmd in globals1]
    names2 = [cmd.function_name for cmd in globals2]
    conflicts += [cmd for cmd in globals1 if cmd.function_name not in names2]
    conflicts += [cmd for cmd in globals2 if cmd.function_name not in names1]

    if conflicts:
        # Tuples are (cmd1, cmd2) mismatches; bare commands exist in only one script.
        details = "".join(
            f"\nParameter {conflict[0].function_name.value} found in both scripts does not match"
            if isinstance(conflict, tuple)
            else f"\nParameter {conflict.function_name.value} was only found in one of the scripts"
            for conflict in conflicts
        )
        raise ModelScriptMergeError("Merging model scripts failed. The errors found:" + details)

    return (
        create_modified_script(hn1, ignored1, parser1),
        create_modified_script(hn2, ignored2, parser2),
    )


def validate_global_model_modification_commands(global_cmd1, global_cmd2, join_action, conflicted, script_parser2):
    """Check whether the layer-less model-modification commands of two scripts can be merged.

    Only normalization, resize, input_conversion and transpose commands found in ``global_cmd1`` /
    ``global_cmd2`` are examined. Every conflict is appended to ``conflicted``: either a bare
    command, or a ``(cmd1, cmd2)`` tuple when both scripts hold the command with different text.
    Commands that have been settled here are removed from ``global_cmd1`` / ``global_cmd2`` so the
    caller's "only found in one of the scripts" check does not report them a second time.

    Rules:
        * a command only in model 2 is a conflict;
        * a transpose only in model 1 is a conflict;
        * any other command only in model 1 is a conflict unless chaining
          (``JoinAction.AUTO_CHAIN_NETWORKS``), where it is allowed and stays in script 1;
        * a command in both models is a conflict when chaining;
        * otherwise normalization / resize in both must have identical text, and for
          input_conversion / transpose in both, model 2's copy is removed from ``script_parser2``.

    Args:
        global_cmd1: Global commands of model 1 (mutated).
        global_cmd2: Global commands of model 2 (mutated).
        join_action: The join action being performed.
        conflicted: List collecting conflicts (mutated).
        script_parser2: Parser of model 2's script; its ``commands`` list may be mutated.
    """
    modification_command_types = [
        SupportedCommands.NORMALIZATION,
        SupportedCommands.RESIZE,
        SupportedCommands.INPUT_CONVERSION,
        SupportedCommands.TRANSPOSE,
    ]
    global_modification1 = [cmd1 for cmd1 in global_cmd1 if cmd1.function_name in modification_command_types]
    global_modification2 = [cmd2 for cmd2 in global_cmd2 if cmd2.function_name in modification_command_types]

    def _discard(commands, cmd):
        # Tolerate commands that were already removed (e.g. duplicated entries) instead of raising ValueError.
        if cmd in commands:
            commands.remove(cmd)

    model1_names = [cmd.function_name for cmd in global_modification1]
    for cmd2 in global_modification2:
        if cmd2.function_name not in model1_names:
            # cmd only in model 2
            conflicted.append(cmd2)
            global_cmd2.remove(cmd2)

    is_chain = join_action == JoinAction.AUTO_CHAIN_NETWORKS
    for cmd1 in global_modification1:
        # Resolve model 2's counterpart(s) explicitly for this cmd1. Do not reuse a loop variable from
        # the loop above, which would point at an unrelated command.
        matches_from2 = [cmd2 for cmd2 in global_modification2 if cmd2.function_name == cmd1.function_name]
        if not matches_from2:
            if cmd1.function_name == SupportedCommands.TRANSPOSE or not is_chain:
                # transpose only in model 1, or any cmd only in model 1 when not chaining
                conflicted.append(cmd1)
            # When chaining, a non-transpose cmd that only model 1 has is allowed. Either way it is
            # settled here, so remove it from the caller's one-sided check.
            _discard(global_cmd1, cmd1)
        elif is_chain:
            # cmd in both models, join action is chain
            conflicted.append(cmd1)
            _discard(global_cmd1, cmd1)
            for cmd2 in matches_from2:
                _discard(global_cmd2, cmd2)
        elif cmd1.function_name in [SupportedCommands.NORMALIZATION, SupportedCommands.RESIZE]:
            mismatched = [cmd2 for cmd2 in matches_from2 if str(cmd1) != str(cmd2)]
            if mismatched:
                # normalization / resize in both models but doesn't match
                conflicted.append((cmd1, mismatched[0]))
                _discard(global_cmd1, cmd1)
                for cmd2 in mismatched:
                    _discard(global_cmd2, cmd2)
        else:
            # input_conversion / transpose in both models: keep only model 1's copy
            _discard(script_parser2.commands, matches_from2[0])


def create_modified_script(hn, ignored_commands_str, script_parser):
    """Render ``script_parser`` back to text, appending the ignored commands (if any) as a comment block.

    The ignored block is headed by ``# <hn.name>`` and is added only when ``ignored_commands_str``
    is non-empty.
    """
    script_text = str(script_parser)
    if not ignored_commands_str:
        return script_text
    return script_text + f"# {hn.name}\n" + ignored_commands_str


def prepare_script_to_join(hn, scope_names, script, state):
    """Parse ``script`` for ``hn``, apply the new scopes and strip global commands to be ignored.

    Returns:
        tuple: ``(global_commands, ignored_commands_str, script_parser)``.
            ``ignored_commands_str`` holds one ``# <command>`` line per command removed from the
            parser.
    """
    parser = ModelScriptParser(hn)
    if script:
        # Keep the user's command order while parsing.
        parser.sorting_disabled = True
        parser.parse_script(script)

    if scope_names:
        parser.add_scope_to_commands(scope_names)

    global_commands, to_ignore = get_global_commands(parser.commands, state)

    commented_lines = []
    for command in to_ignore:
        parser.commands.remove(command)
        commented_lines.append(f"# {command!s}\n")

    return global_commands, "".join(commented_lines), parser


def get_global_commands(commands, state):
    """Select the script commands that apply to the whole model rather than to specific layers.

    Allocation commands (strategy, print_buffers, optimize_buffers) are always global.
    Quantization and model-modification commands count only when they reference no layers. For
    ``States.QUANTIZED_MODEL`` / ``States.QUANTIZED_SLIM_MODEL`` these go to the "to ignore" list;
    otherwise they are global.

    Returns:
        tuple[list, list]: ``(global_commands, global_commands_to_ignore)``, each in input order.
    """
    allocation_types = [
        SupportedCommands.STRATEGY,
        SupportedCommands.PRINT_BUFFERS,
        SupportedCommands.OPTIMIZE_BUFFERS,
    ]
    layerless_types = [
        # quantization
        SupportedCommands.PRE_QUANTIZATION_OPTIMIZATION,
        SupportedCommands.POST_QUANTIZATION_OPTIMIZATION,
        SupportedCommands.MODEL_OPTIMIZATION_CONFIG,
        # model modification
        SupportedCommands.NORMALIZATION,
        SupportedCommands.RESIZE,
        SupportedCommands.INPUT_CONVERSION,
        SupportedCommands.TRANSPOSE,
    ]
    kept, ignored = [], []
    for cmd in commands:
        if cmd.function_name in allocation_types:
            kept.append(cmd)
        elif cmd.function_name in layerless_types and not cmd.has_layers():
            is_quantized = state in [States.QUANTIZED_MODEL, States.QUANTIZED_SLIM_MODEL]
            (ignored if is_quantized else kept).append(cmd)
    return kept, ignored


def handle_join_action_info(
    join_action,
    join_action_info,
    scope_names1,
    scope_names2,
    hn1_inputs,
    hn2_inputs,
    hn1_outputs,
):
    """Validate ``join_action`` / ``join_action_info`` and derive the layer mapping for the join.

    * With a custom ``join_action_info`` (only allowed for ``JoinAction.CUSTOM``), an optional
      ``"output_layers_order"`` entry is extracted. The caller's dict is NOT modified; a shallow copy
      without that entry is returned.
    * ``AUTO_JOIN_INPUTS`` requires one input per network with equal ``input_shape`` and
      ``transposed``, and maps the hn1 input to the hn2 input.
    * ``AUTO_CHAIN_NETWORKS`` requires a single hn1 output and a single hn2 input with equal
      ``input_shape`` and ``transposed``, and maps the hn1 output to the hn2 input.

    Returns:
        tuple: ``(join_action_info, join_inputs, chain_networks, output_layers_order)``. The mapping
        keys are re-scoped hn1 layer names and the values are re-scoped hn2 input names.

    Raises:
        HNMergeError: When the arguments are inconsistent or the auto-join preconditions fail.
    """
    join_inputs = False
    chain_networks = False
    output_layers_order = JoinOutputLayersOrder.NEW_OUTPUTS_LAST
    if join_action_info:
        if not isinstance(join_action_info, dict):
            raise HNMergeError(
                "join_action_info should be a dictionary where the keys are the input/output layer and "
                "the values are the corresponding input layer in the other network",
            )
        if join_action != JoinAction.CUSTOM:
            raise HNMergeError("join_action_info should be specified only if join action is custom")
        # Work on a copy so that popping the ordering option does not mutate the caller's dict
        # (a reused dict would otherwise silently lose its "output_layers_order").
        join_action_info = dict(join_action_info)
        if "output_layers_order" in join_action_info:
            output_layers_order = join_action_info.pop("output_layers_order")
    if join_action == JoinAction.AUTO_JOIN_INPUTS:
        if not len(hn1_inputs) == len(hn2_inputs) == 1:
            raise HNMergeError("Auto input joining is only supported when both models have a single input")
        elif hn1_inputs[0].input_shape != hn2_inputs[0].input_shape:
            raise HNMergeError("Auto input joining is only supported when both model inputs have the same shape")
        elif hn1_inputs[0].transposed != hn2_inputs[0].transposed:
            raise HNMergeError(
                "Auto input joining is only supported when both model inputs are transposed or both are not transposed",
            )
        hn1_input = get_new_layer_name(hn1_inputs[0], scope_names1)
        hn2_input = get_new_layer_name(hn2_inputs[0], scope_names2)
        join_action_info = {hn1_input: hn2_input}
        join_inputs = True
    elif join_action == JoinAction.AUTO_CHAIN_NETWORKS:
        if not len(hn1_outputs) == len(hn2_inputs) == 1:
            raise HNMergeError(
                "Auto chaining networks is only supported when the first network has a single output, "
                "and the second network has a single input",
            )
        elif hn1_outputs[0].input_shape != hn2_inputs[0].input_shape:
            raise HNMergeError(
                "Auto chaining networks is only supported when the first network's output shape is equal"
                " to the second networks input shape",
            )
        elif hn1_outputs[0].transposed != hn2_inputs[0].transposed:
            raise HNMergeError(
                "Auto chaining networks is only supported when the first network's output is transposed"
                " when the second networks input is transposed or both are not transposed",
            )
        hn1_output = get_new_layer_name(hn1_outputs[0], scope_names1)
        hn2_input = get_new_layer_name(hn2_inputs[0], scope_names2)
        join_action_info = {hn1_output: hn2_input}
        chain_networks = True

    return join_action_info, join_inputs, chain_networks, output_layers_order


def validate_join_action_info(join_action_info, hn1, hn2, scope_names1, scope_names2):
    """Validate a custom ``join_action_info`` mapping and classify the join it describes.

    Each key must name an input or output layer of ``hn1``, and each value an input layer of ``hn2``.
    Both are written as ``<new scope>/<layer>``, with scopes taken from the values of
    ``scope_names1`` / ``scope_names2``. Paired layers must agree on ``output_shape`` and
    ``transposed``.

    Returns:
        tuple[bool, bool]: ``(join_inputs, chain_networks)``, i.e. whether any key is an input layer
        and whether any key is an output layer.

    Raises:
        HNMergeError: On the first malformed or incompatible pair.
    """
    join_inputs = chain_networks = False
    for hn1_name, hn2_name in join_action_info.items():
        hn1_parts = hn1_name.split("/", 1)
        if len(hn1_parts) != 2:
            raise HNMergeError(
                f"Layer {hn1_name} given as a key in join_action_info should be in the form of <scope>/<layer>",
            )
        hn1_scope, hn1_short_name = hn1_parts
        if hn1_scope not in scope_names1.values():
            raise HNMergeError(
                f"Layer {hn1_name} given as a key in join_action_info does not belong to any of the runner scopes",
            )
        hn1_original_scope = next(old for old, new in scope_names1.items() if new == hn1_scope)
        hn1_layer = hn1.get_layer_by_name(f"{hn1_original_scope}/{hn1_short_name}")
        if hn1_layer.op not in [LayerType.input_layer, LayerType.output_layer]:
            raise HNMergeError(f"Layer {hn1_name} is not an input layer or an output layer")
        join_inputs = join_inputs or hn1_layer.op == LayerType.input_layer
        chain_networks = chain_networks or hn1_layer.op == LayerType.output_layer

        hn2_parts = hn2_name.split("/", 1)
        if len(hn2_parts) != 2:
            raise HNMergeError(
                f"Layer {hn2_name} given as a value in join_action_info should be in the "
                f"form of <scope>/<layer>",
            )
        hn2_scope, hn2_short_name = hn2_parts
        if hn2_scope not in scope_names2.values():
            raise HNMergeError(
                f"Layer {hn2_name} given as a value in join_action_info does not belong to "
                f"any of the parameter runner scopes",
            )
        hn2_original_scope = next(old for old, new in scope_names2.items() if new == hn2_scope)
        hn2_layer = hn2.get_layer_by_name(f"{hn2_original_scope}/{hn2_short_name}")
        if hn2_layer.op != LayerType.input_layer:
            raise HNMergeError(
                f"Layer {hn2_name} is not an input layer, but all values in "
                f"join_action_info must be input layers",
            )
        if hn2_layer.output_shape != hn1_layer.output_shape:
            raise HNMergeError(f"Layer {hn1_name} does not have the same shape as layer {hn2_name}")
        if hn1_layer.transposed != hn2_layer.transposed:
            raise HNMergeError(f"Layers {hn1_name} and {hn2_name} must be both transposed or not")
    return join_inputs, chain_networks


def handle_network_chain(join_action_info_reverse, hn1, scope_names1, new_hailo_nn):
    """Rewire hn2 layers that consumed a chained hn2 input onto hn1's real producer layer.

    ``join_action_info_reverse`` maps re-scoped hn2 input-layer names to the re-scoped hn1 layer they
    are joined with. Entries whose hn1 side is an input layer belong to input joining and are skipped.
    For every other entry, the hn1 side is an output layer that was not copied into
    ``new_hailo_nn``. Its predecessor (``inputs[0]`` in hn1) becomes the producer for every layer of
    ``new_hailo_nn`` that lists the hn2 input name in its ``inputs``.

    The consumers' ``inputs`` lists and the producers' ``outputs`` lists are rewritten in place.
    In the producer, the dropped output-layer name is replaced by ALL of its new consumers, so a
    chained input that fans out to several layers keeps every edge.
    """
    # hn2 input name -> hn1 output-layer name, restricted to chained (output -> input) pairs
    chained_inputs = {}
    # hn1 output-layer name -> layer in new_hailo_nn that fed that output layer in hn1
    hn1_output_name_to_real_output = {}
    for hn2_input_name, hn1_layer_name in join_action_info_reverse.items():
        out_scope, out_name = hn1_layer_name.split("/", 1)
        original_scope = next(scope for scope, new_scope in scope_names1.items() if new_scope == out_scope)
        hn1_layer = hn1.get_layer_by_name(f"{original_scope}/{out_name}")
        if hn1_layer.op == LayerType.input_layer:
            continue  # input-to-input pairs are handled by the join_inputs path in merge_hns
        chained_inputs[hn2_input_name] = hn1_layer_name
        real_output_name = get_new_layer_name(hn1_layer.inputs[0], scope_names1)
        hn1_output_name_to_real_output[hn1_layer_name] = new_hailo_nn.get_layer_by_name(real_output_name)

    consumers = [node for node in new_hailo_nn.nodes if any(inp in chained_inputs for inp in node.inputs)]

    # hn1 output-layer name -> names of every layer now fed by its real producer (in graph order)
    new_successors = {}
    for consumer in consumers:
        new_inputs = []
        for inp in consumer.inputs:
            hn1_output_name = chained_inputs.get(inp)
            if hn1_output_name is None:
                new_inputs.append(inp)
                continue
            producer = hn1_output_name_to_real_output[hn1_output_name]
            successors = new_successors.setdefault(hn1_output_name, [])
            if consumer.name not in successors:
                successors.append(consumer.name)
            new_inputs.append(producer.name)
        consumer.inputs = new_inputs

    for hn1_output_name, successors in new_successors.items():
        producer = hn1_output_name_to_real_output[hn1_output_name]
        new_outputs = []
        for out in producer.outputs:
            if out == hn1_output_name:
                new_outputs.extend(successors)
            else:
                new_outputs.append(out)
        producer.outputs = new_outputs


def update_io_indices(hailo_nn):
    """Refresh every layer's ``input_indices`` / ``output_indices`` from its ``inputs`` / ``outputs`` names.

    Each name is resolved with ``hailo_nn.get_layer_by_name`` and replaced by that layer's ``index``.
    Layers are visited in ``hailo_nn.stable_toposort()`` order.
    """

    def _index_of(layer_name):
        return hailo_nn.get_layer_by_name(layer_name).index

    for layer in hailo_nn.stable_toposort():
        layer.input_indices = list(map(_index_of, layer.inputs))
        layer.output_indices = list(map(_index_of, layer.outputs))


def get_valid_scope_names(scope_names, net_scopes):
    """Normalize user-supplied new scope names into a ``{current scope: new scope}`` mapping.

    Args:
        scope_names: Either a dict keyed by existing scopes, or a single new name. A single name is
            allowed only when ``net_scopes`` has exactly one scope. Falsy values mean "no renaming".
        net_scopes: The scopes currently present in the runner.

    Returns:
        dict | None: ``None`` when ``scope_names`` is falsy. Otherwise, each new name is passed
        through ``HailoNN.get_valid_input_identifier``.

    Raises:
        UnsupportedRunnerJoinException: For a single name with several existing scopes, or dict keys
            that are not existing scopes.
    """
    if not scope_names:
        return None

    if isinstance(scope_names, dict):
        unknown_scopes = [scope for scope in scope_names if scope not in net_scopes]
        if unknown_scopes:
            raise UnsupportedRunnerJoinException(
                f"{unknown_scopes} scope names were given for the current runner, but are not in "
                f"the current scope names {net_scopes}",
            )
    elif len(net_scopes) == 1:
        scope_names = {net_scopes[0]: scope_names}
    else:
        raise UnsupportedRunnerJoinException(
            f"scope name {scope_names} was given for the current runner, but multiple "
            f"scope names {net_scopes} already exist. Please use dictionary as explained "
            "in the documentation ",
        )

    validated = {}
    for old_scope, requested_scope in scope_names.items():
        validated[old_scope] = HailoNN.get_valid_input_identifier(requested_scope, "scope_name")
    return validated


def get_runner_scopes(new_scope_names, existing_net_scopes):
    """Return ``{existing scope: scope to use}`` for every existing scope.

    Scopes without a requested rename map to themselves. Renames are validated by
    ``get_valid_scope_names``.
    """
    renames = get_valid_scope_names(new_scope_names, existing_net_scopes)
    scopes = dict(zip(existing_net_scopes, existing_net_scopes))
    scopes.update(renames or {})
    return scopes
