import copy

import numpy as np
from safetensors import safe_open

from hailo_model_optimization.acceleras.utils.acceleras_definitions import OpStates
from hailo_sdk_client.runner.exceptions import UnsupporteLoraAdapterException
from hailo_sdk_client.tools.fuser.fuser_helper import FuserHelper
from hailo_sdk_common.hailo_nn.hailo_nn import HailoNN, hn_to_npz_key
from hailo_sdk_common.hailo_nn.hn_definitions import ActivationType, PaddingType
from hailo_sdk_common.hailo_nn.hn_layers import (
    FusedConv2DLayer,
    FusedStandaloneEWAddLayer,
)
from hailo_sdk_common.logger.logger import default_logger


def duplicate_network_group(model: HailoNN, base_network_group: str, new_network_group: str):
    """Clone every layer of ``base_network_group`` into a new network-group scope.

    Layers of the base group's component are deep-copied in the order returned by
    ``stable_toposort()``. Each clone:
    - is renamed to ``<new_network_group>/<name without scope>``;
    - gets a fresh index starting at ``model.get_next_index()``;
    - records its source layer name in ``base_layer``;
    - is added to ``model``.

    Clones of layers listed in ``net_params.output_layers_order`` are appended to that
    list. The clones' inputs/outputs are then re-pointed at the matching clones, edges
    are added to the graph, and ``new_network_group`` is registered in ``net_scopes``.

    Returns:
        The same ``model`` object, modified in place.
    """
    first_free_index = model.get_next_index()
    base_component = model.get_component_by_scope(base_network_group)

    clones = []
    for offset, source in enumerate(base_component.stable_toposort()):
        clone = copy.deepcopy(source)
        clone.name = f"{new_network_group}/{source.name_without_scope}"
        clone.base_layer = source.name
        clone.index = first_free_index + offset
        model.add_node(clone)
        clones.append(clone)
        if source.name in model.net_params.output_layers_order:
            model.net_params.output_layers_order.append(clone.name)

    def counterpart(layer_name):
        # Resolve ``layer_name`` and return the layer with the same unscoped name in the new group.
        original_layer = model.get_layer_by_name(layer_name)
        return model.get_layer_by_name(f"{new_network_group}/{original_layer.name_without_scope}")

    for clone in clones:
        # The deep copy still lists the base group's neighbour names; rewire to the clones.
        upstream = []
        for neighbour_name in clone.inputs:
            neighbour = counterpart(neighbour_name)
            model.add_edge(neighbour, clone)
            upstream.append(neighbour)
        clone.inputs = [layer.name for layer in upstream]
        clone.input_indices = [layer.index for layer in upstream]

        downstream = []
        for neighbour_name in clone.outputs:
            neighbour = counterpart(neighbour_name)
            model.add_edge(clone, neighbour)
            downstream.append(neighbour)
        clone.outputs = [layer.name for layer in downstream]
        clone.output_indices = [layer.index for layer in downstream]

    model.net_params.net_scopes.append(new_network_group)
    return model


def _build_lora_conv_layer(name, index, input_layer, kernel_shape):
    """Create (but do not add to the graph) a linear FusedConv2DLayer fed by ``input_layer``.

    The layer uses unit strides/dilations, valid padding, the given ``kernel_shape``
    and its input shape is taken from ``input_layer.output_shape``.
    """
    layer = FusedConv2DLayer()
    layer.index = index
    layer.name = name
    layer.inputs = [input_layer.name]
    layer.input_indices = [input_layer.index]
    layer.strides = [1, 1, 1, 1]
    layer.dilations = [1, 1, 1, 1]
    layer.padding = PaddingType.valid
    layer.kernel_shape = kernel_shape
    layer.activation = ActivationType.linear
    layer.input_shapes = [input_layer.output_shape]
    layer.update_output_shapes()
    return layer


def _lora_conv_params(layer, weights):
    """Return the kernel (``weights`` reshaped to the layer kernel, float32) and zero-bias npz entries."""
    return {
        hn_to_npz_key(layer.name, "kernel"): np.reshape(weights, layer.kernel_shape).astype(np.float32),
        hn_to_npz_key(layer.name, "bias"): np.zeros(layer.kernel_shape[-1], dtype=np.float32),
    }


def _mark_fp_only_params(params, layer_name):
    """Write the acceleras entries restricting ``layer_name`` to ``OpStates.FP`` with native activation."""
    params[hn_to_npz_key(layer_name, "layer_supported_states")] = np.array([OpStates.FP.value])
    params[hn_to_npz_key(layer_name, "ignore_io_shapes_verification")] = np.array(False)
    params[hn_to_npz_key(layer_name, "act_fully_native")] = np.array(True)


def create_new_lora_layers(
    model, params, lora_weights_mapping, lora_weights, scope_name=None, is_acceleras_params=False
):
    """Insert a LoRA branch next to every layer listed in ``lora_weights_mapping``.

    For each mapped layer ("LoRA head") with predecessor ``pred``, three layers are added::

        pred -> head ---------------------> lora_add -> (head's former successors)
        pred -> lora_up -> lora_down ----/

    The head's successors are re-pointed to ``lora_add`` through ``FuserHelper``. If the
    head was a model output, it is replaced by ``lora_add`` in ``output_layers_order``.

    Args:
        model: HailoNN modified in place.
        params: Parameter mapping updated in place with the new conv kernels/biases
            (plus acceleras flags for all three layers when ``is_acceleras_params``).
        lora_weights_mapping: original layer name -> {"lora_up": {...}, "lora_down": {...}}.
            Each inner dict holds at least "name" (key into ``lora_weights``) and "shape".
            It is mutated in place: an "hn_name" entry with the created layer name is added.
        lora_weights: Mapping from the inner "name" keys to weight arrays.
        scope_name: Forwarded to ``model.get_layer_by_original_name``.
        is_acceleras_params: Whether to write the acceleras-specific entries.

    Returns:
        Tuple ``(model, params, lora_weights_mapping)``.

    Raises:
        UnsupporteLoraAdapterException: if a mapped layer is not in the model or has no
            input layer, or if the LoRA matrices do not chain from the head's input
            features to its output features.
    """
    fuser_helper = FuserHelper(model)
    curr_idx = model.get_next_index()
    for orig_name, lora_block_info in lora_weights_mapping.items():
        lora_head = model.get_layer_by_original_name(orig_name, scope_name=scope_name)
        if lora_head is None:
            raise UnsupporteLoraAdapterException(
                f"Layer {orig_name} was not found in the model. "
                "This means that the original model and the LoRA weights do not match."
            )

        # A bare next() would leak StopIteration for a head without inputs.
        pred = next(model.predecessors(lora_head), None)
        if pred is None:
            raise UnsupporteLoraAdapterException(
                f"Layer {orig_name} has no input layer, so a LoRA branch cannot be attached to it."
            )
        succs = list(model.successors(lora_head))

        up_info = lora_block_info["lora_up"]
        down_info = lora_block_info["lora_down"]
        up_kernel_shape = [1, 1, *up_info["shape"]]
        down_kernel_shape = [1, 1, *down_info["shape"]]

        # Validate before touching params, the mapping or the graph for this layer.
        # The branch output is summed with the head's *output* (see lora_add below), so the
        # last LoRA dimension must match the head's output features, not its input features.
        input_features = lora_head.input_shape[-1]
        output_features = lora_head.output_shape[-1]
        if (
            up_kernel_shape[-2] != input_features
            or up_kernel_shape[-1] != down_kernel_shape[-2]
            or down_kernel_shape[-1] != output_features
        ):
            raise UnsupporteLoraAdapterException(
                f"Layer {orig_name} LoRA weights have incompatible shapes. "
                "This means that the original model and the LoRA weights do not match."
            )

        lora_up = _build_lora_conv_layer(f"{lora_head.name}_lora_up", curr_idx, pred, up_kernel_shape)
        up_info["hn_name"] = lora_up.name
        params.update(_lora_conv_params(lora_up, lora_weights[up_info["name"]]))
        if is_acceleras_params:
            _mark_fp_only_params(params, lora_up.name)
        curr_idx += 1

        lora_down = _build_lora_conv_layer(f"{lora_head.name}_lora_down", curr_idx, lora_up, down_kernel_shape)
        down_info["hn_name"] = lora_down.name
        params.update(_lora_conv_params(lora_down, lora_weights[down_info["name"]]))
        if is_acceleras_params:
            _mark_fp_only_params(params, lora_down.name)
        curr_idx += 1

        lora_add = FusedStandaloneEWAddLayer()
        lora_add.index = curr_idx
        lora_add.name = f"{lora_head.name}_lora_add"
        lora_add.inputs = [lora_head.name, lora_down.name]
        # NOTE(review): `inputs` is set explicitly above and append_to_input_list is also
        # called for both inputs; confirm append_to_input_list does not extend `inputs` too.
        lora_add.append_to_input_list(lora_head)
        lora_add.append_to_input_list(lora_down)
        lora_add.input_indices = [lora_head.index, lora_down.index]
        lora_add.outputs = lora_head.outputs.copy()
        lora_add.output_indices = lora_head.output_indices.copy()
        lora_add.input_shapes = [lora_head.output_shape, lora_down.output_shape]
        lora_add.update_output_shapes()
        if is_acceleras_params:
            _mark_fp_only_params(params, lora_add.name)
        curr_idx += 1

        model.add_node(lora_up)
        model.add_node(lora_down)
        model.add_node(lora_add)
        model.add_edge(pred, lora_up)
        model.add_edge(lora_up, lora_down)
        model.add_edge(lora_down, lora_add)
        model.add_edge(lora_head, lora_add)

        # Rebuild (not mutate) pred's output lists in case they are shared with other objects.
        pred.outputs = pred.outputs.copy() + [lora_up.name]
        pred.output_indices = pred.output_indices.copy() + [lora_up.index]
        pred.output_shapes = pred.output_shapes.copy() + [lora_up.input_shape]

        for succ in succs:
            fuser_helper.replace_pred(succ, lora_head, lora_add)
            fuser_helper.replace_succ(lora_head, succ, lora_add)

        if lora_head.name in model.net_params.output_layers_order:
            index = model.net_params.output_layers_order.index(lora_head.name)
            model.net_params.output_layers_order[index] = lora_add.name

    return model, params, lora_weights_mapping


def _fill_default_lora_weights(lora_weights_mapping, lora_weights):
    for lora_block_info in lora_weights_mapping.values():
        for matrix_info in lora_block_info.values():
            key = matrix_info["name"]
            shape = tuple(matrix_info["shape"])
            if key not in lora_weights.keys():
                lora_weights[key] = np.zeros(shape, dtype=np.float32)
    return lora_weights


def _validate_extra_lora_weights(lora_weights_mapping, lora_weights):
    expected_keys = {
        matrix_info["name"]
        for lora_block_info in lora_weights_mapping.values()
        for matrix_info in lora_block_info.values()
    }
    if len(set(lora_weights.keys()) - expected_keys) > 0:
        missing_keys = set(lora_weights.keys()) - expected_keys
        raise UnsupporteLoraAdapterException(
            f"LoRA layers {missing_keys} are missing in the provided weights mapping file or in the new adapter weights file. "
            "Please make sure each adapter matches the set of weights that were initialized in the Hailo LoRA model."
        )


def _validate_lora_weights_shape(lora_weights_mapping, lora_weights):
    for lora_block_info in lora_weights_mapping.values():
        for matrix_info in lora_block_info.values():
            key = matrix_info["name"]
            shape = tuple(matrix_info["shape"])
            if lora_weights[key].shape != shape:
                raise UnsupporteLoraAdapterException(
                    f"LoRA weights for layer {key} have incompatible shape. "
                    "This means that the original model and the LoRA weights do not match. "
                    f"Expected shape: {shape}, got shape: {lora_weights[key].shape}"
                )


def load_lora_weights(
    model, params, lora_layers_metadata, lora_weights_file, lora_adapter_name, is_acceleras_params=False, log_info=False
):
    """Add a new LoRA adapter to ``model`` from a safetensors weights file.

    The network group of the first registered adapter is duplicated under the
    sanitized adapter name. LoRA layers are then created in the copy with the
    weights read from ``lora_weights_file``. Matrices declared in
    ``lora_layers_metadata`` but absent from the file are filled with zeros.

    Args:
        model: HailoNN that already holds at least one LoRA adapter network group.
        params: Parameter mapping that receives the new layers' parameters.
        lora_layers_metadata: layer -> {"lora_up": {...}, "lora_down": {...}} descriptors
            (each with "name" and "shape"). Not modified by this function.
        lora_weights_file: Path to the adapter's safetensors file.
        lora_adapter_name: Requested adapter name, sanitized via
            ``HailoNN.get_valid_input_identifier``.
        is_acceleras_params: Forwarded to ``create_new_lora_layers``.
        log_info: When True, log an info message after the adapter is added.

    Returns:
        Tuple ``(model, params, new_lora_metadata)``. ``new_lora_metadata`` is an
        independent copy of ``lora_layers_metadata`` annotated with the new adapter's
        layer names.

    Raises:
        UnsupporteLoraAdapterException: if the adapter name already exists, the model
            has no base adapter, or the weights do not match the metadata.
    """
    valid_adapter_name = HailoNN.get_valid_input_identifier(lora_adapter_name, "lora_adapter_name")
    if valid_adapter_name in model.net_params.lora_adapters:
        raise UnsupporteLoraAdapterException(f"Adapter {lora_adapter_name} already exists in the model.")
    # The first adapter is used as the template below; fail early with a clear message
    # instead of an IndexError after the weights file has been read.
    if not model.net_params.lora_adapters:
        raise UnsupporteLoraAdapterException(
            "The model has no LoRA adapter to use as a base for the new adapter. "
            "Please use a model that was built with LoRA layers."
        )

    with safe_open(lora_weights_file, framework="np", device="cpu") as f:
        tensors = {key: f.get_tensor(key) for key in f.keys()}

    tensors = _fill_default_lora_weights(lora_layers_metadata, tensors)

    # Currently we only support LoRA weights with the same structure as defined in the lora_layers_metadata.
    _validate_extra_lora_weights(lora_layers_metadata, tensors)
    _validate_lora_weights_shape(lora_layers_metadata, tensors)

    base_network_group = model.net_params.lora_adapters[0]
    model = duplicate_network_group(model, base_network_group, valid_adapter_name)

    # create_new_lora_layers writes "hn_name" into the nested descriptor dicts. A deep copy
    # keeps the caller's metadata (describing the base adapter) intact; a shallow .copy()
    # would share the inner dicts and overwrite them.
    model, params, new_lora_metadata = create_new_lora_layers(
        model=model,
        params=params,
        lora_weights_mapping=copy.deepcopy(lora_layers_metadata),
        lora_weights=tensors,
        scope_name=valid_adapter_name,
        is_acceleras_params=is_acceleras_params,
    )

    model.net_params.lora_adapters.append(valid_adapter_name)
    if log_info:
        default_logger().info(f"Added LoRA weights for new adapter: {valid_adapter_name}")

    return model, params, new_lora_metadata
