import numpy as np

from hailo_model_optimization.acceleras.utils.acceleras_definitions import PreQuantizationDefuseType
from hailo_sdk_client.model_translator.graph_lookup import FwdChainNode, get_all_nodes_in_chain
from hailo_sdk_client.post_fuser.algorithms.exceptions import MHADefuseException
from hailo_sdk_client.post_fuser.algorithms.fuser_algorithm import FuserAlgorithm
from hailo_sdk_client.tools.fuser.fuser_helper import FuserHelper
from hailo_sdk_common.hailo_nn.hailo_nn import hn_to_npz_key
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.hailo_nn.hn_layers import ConcatLayer
from hailo_sdk_common.hailo_nn.hn_layers.feature_splitter import FeatureSplitterLayer


class MHADefuse(FuserAlgorithm):
    NAME = "mha_defuse"

    def __init__(self, model, params, model_config, hw_arch, **kwargs):
        """Set up the algorithm and a fuser helper bound to the model.

        Args:
            model: Forwarded to the ``FuserAlgorithm`` constructor.
            params: Forwarded to the ``FuserAlgorithm`` constructor.
            model_config: Forwarded to the ``FuserAlgorithm`` constructor.
            hw_arch: Forwarded to the ``FuserAlgorithm`` constructor.
            **kwargs: Accepted but not used by this class.
        """
        super().__init__(model, params, model_config, hw_arch)
        # Helper used later to create the feature-splitter layers.
        helper = FuserHelper(self.model)
        self._fuser_helper = helper

    def get_algo_config(self):
        return self._model_config.defuse

    def _setup(self):
        pass

    def should_skip_algo(self):
        # TODO: get from config
        return False

    def log_config(self):
        pass

    def _run_int(self):
        self._split_mha_blocks()

    def _split_mha_blocks(self):
        """Split every configured MHA block into ``num_splits`` parallel branches.

        A layer is processed when the defuse config lists it with type
        ``PreQuantizationDefuseType.MHA`` and a ``num_splits`` other than ``None`` or 1.
        The block returned by ``_is_valid_mha_block`` is then rewired as:

        * a feature splitter after each of K, Q and V (and after the mask input when the
          softmax part is the multi-layer masked-softmax block),
        * ``num_splits`` copies of ``matmul -> softmax block -> matmul``,
        * a concat layer that collects the second-matmul copies and takes over the
          successors of the original second matmul.

        All layers created on the way are finally passed to
        ``relax_new_layer_into_graph``.

        Raises:
            MHADefuseException: If the configured layer does not start a valid MHA block,
                or if its number of groups is not divisible by ``num_splits``.
        """
        new_layers = []
        successors_meta_data = {}
        algo_cfg = self.get_algo_config()
        # Iterate over a snapshot: nodes are added to the model inside the loop.
        for layer in list(self._model):
            if layer.name not in algo_cfg.layers:
                continue
            layer_cfg = algo_cfg.layers[layer.name]
            if layer_cfg.defuse_type != PreQuantizationDefuseType.MHA:
                continue

            num_splits = layer_cfg.num_splits
            if num_splits is None or num_splits == 1:
                continue

            # On success `result` is the block description, otherwise the failure reason.
            is_valid, result = self._is_valid_mha_block(layer)
            if not is_valid:
                error_reason = result
                raise MHADefuseException(
                    f"Invalid MHA block {layer.name}: {error_reason}.\n"
                    f"Expected MHA block represented by its first "
                    f"matmul with two inputs (Q, K) when matmul input transposed, followed by "
                    f"softmax and matmul, when the second matmul input are the softmax and V - "
                    f"matmul(softmax(matmul(Q, K_transpose)), V)",
                )

            if layer.groups % num_splits != 0:
                raise MHADefuseException(
                    f"Invalid command parameters - the number of groups ({layer.groups}) in the "
                    f"MHA block isn’t divisible by the requested number of splits ({num_splits})",
                )

            kq_matmul, softmax_block, k, q, v = result
            # The last layer of the found chain is the second matmul; what remains is the softmax part.
            v_matmul = softmax_block.pop()
            self._logger.info(f"Splitting MHA block of matmul: {kq_matmul.name_without_scope} into {num_splits} splits")

            k_split = self._create_kqv_split(k, new_layers, num_splits)
            q_split = self._create_kqv_split(q, new_layers, num_splits)
            v_split = self._create_kqv_split(v, new_layers, num_splits)
            mask_split = None
            if len(softmax_block) > 1:
                # Masked softmax: the mask is the second input of the first ew_mult (block position 2).
                ew_mult = softmax_block[2]
                mask = self.model.get_layer_by_name(ew_mult.inputs[1])
                mask_split = self._create_kqv_split(mask, new_layers, num_splits)

            kq_matmul_defuses = self._defuse_layer(kq_matmul, num_splits, new_layers)
            softmax_blocks = self._defuse_softmax_block(softmax_block, num_splits, new_layers)
            v_matmul_defuses = self._defuse_layer(v_matmul, num_splits, new_layers)

            # NOTE(review): _defuse_layer has already appended "_d0" to v_matmul.name, so both
            # _add_concat and replace_input_layer below see the post-rename name - confirm that
            # successors' input lists and output_layers_order hold that same name at this point.
            concat = self._add_concat(v_matmul, new_layers)

            # Snapshot the successors: the edges being iterated are removed inside the loop.
            for succ in list(self.model.successors(v_matmul)):
                self.model.remove_edge(v_matmul, succ)
                self.model.add_edge(concat, succ)
                succ.replace_input_layer(v_matmul.name, concat.name)
                succ.replace_input_index(v_matmul.index, concat.index)

            branches = zip(kq_matmul_defuses, softmax_blocks, v_matmul_defuses)
            for i, (kq_matmul_di, softmax_block_i, v_matmul_di) in enumerate(branches):
                first_sm = softmax_block_i[0]
                last_sm = softmax_block_i[-1]

                # Q/K splitters -> first matmul of branch i (uses the i-th splitter output shape).
                kq_matmul_di.inputs = [q_split.name, k_split.name]
                kq_matmul_di.input_indices = [q_split.index, k_split.index]
                kq_matmul_di.input_shapes = [q_split.output_shapes[i], k_split.output_shapes[i]]
                kq_matmul_di.output_shapes = [q_split.output_shapes[i][:-1] + [kq_matmul_di.kernel_shape[3]]]
                k_split.append_output_layer(kq_matmul_di.name)
                k_split.append_output_index(kq_matmul_di.index)
                q_split.append_output_layer(kq_matmul_di.name)
                q_split.append_output_index(kq_matmul_di.index)
                self.model.add_edge(k_split, kq_matmul_di)
                self.model.add_edge(q_split, kq_matmul_di)

                # First matmul -> first layer of the softmax block.
                kq_matmul_di.outputs = [first_sm.name]
                kq_matmul_di.output_indices = [first_sm.index]
                first_sm.inputs = [kq_matmul_di.name]
                first_sm.input_indices = [kq_matmul_di.index]
                first_sm.input_shapes = [kq_matmul_di.output_shape]
                self.model.add_edge(kq_matmul_di, first_sm)

                # Wire the inner layers of the softmax block (nothing to do for a single softmax layer).
                self._attach_softmax_block(softmax_block_i, kq_matmul_di, mask_split)

                # Last softmax-block layer + V splitter -> second matmul of branch i.
                last_sm.outputs = [v_matmul_di.name]
                last_sm.output_indices = [v_matmul_di.index]
                last_sm.update_output_shapes()
                v_split.append_output_layer(v_matmul_di.name)
                v_split.append_output_index(v_matmul_di.index)
                v_matmul_di.inputs = [last_sm.name, v_split.name]
                v_matmul_di.input_indices = [last_sm.index, v_split.index]
                v_matmul_di.input_shapes = [last_sm.output_shape, v_split.output_shapes[i]]
                v_matmul_di.output_shapes = [last_sm.output_shape[:-1] + [v_matmul_di.kernel_shape[3]]]
                self.model.add_edge(last_sm, v_matmul_di)
                self.model.add_edge(v_split, v_matmul_di)

                # Second matmul of branch i -> concat.
                v_matmul_di.outputs = [concat.name]
                v_matmul_di.output_indices = [concat.index]
                concat.append_input_layer(v_matmul_di.name)
                concat.append_input_index(v_matmul_di.index)
                concat.append_input_shapes([v_matmul_di.output_shape])
                self.model.add_edge(v_matmul_di, concat)

        for layer in new_layers:
            self._model.relax_new_layer_into_graph(layer, successors_meta_data)

    def _attach_softmax_block(self, block, kq_matmul, mask_split):
        """Wire the layers of one defused softmax block to each other.

        Every layer after the first is connected to its predecessor in ``block``.
        ``ew_sub`` / ``ew_mult`` layers additionally get a first input, chosen by their
        position in the masked-softmax layout
        ``reduce_max, ew_sub, ew_mult, reduce_sum, ew_mult``:

        * ``ew_sub``     <- ``kq_matmul``
        * first ``ew_mult``  <- ``mask_split``
        * second ``ew_mult`` <- the first ``ew_mult`` (``block[2]``)

        A block holding a single layer (plain softmax) is left untouched.

        Args:
            block: Layers of one branch's softmax block, in chain order.
            kq_matmul: The first matmul of the same branch.
            mask_split: Feature splitter of the mask input, or ``None`` when there is no mask.

        Raises:
            MHADefuseException: If an ``ew_sub`` / ``ew_mult`` layer appears at a position
                the masked-softmax layout does not define, or needs a mask splitter that
                was not provided.
        """
        for i, layer in enumerate(block[1:]):
            # masked softmax - reduce_max, ew_sub, ew_mult, reduce_sum, ew_mult.
            # i is always one smaller than the actual index because of enumerate(block[1:])
            pred = block[i]
            self.model.add_edge(pred, layer)
            pred.outputs = [layer.name]
            pred.output_indices = [layer.index]
            pred.update_output_shapes()
            layer.inputs = [pred.name]
            layer.input_indices = [pred.index]
            layer.input_shapes = [pred.output_shape]
            if layer.op not in (LayerType.ew_sub, LayerType.ew_mult):
                continue

            if i == 0:  # ew_sub(qk_matmul, reduce_max)
                inp0 = kq_matmul
            elif i == 1:  # ew_mult1(ew_sub, mask_split)
                inp0 = mask_split
            elif i == 3:  # ew_mult2(ew_mult1, reduce_sum)
                inp0 = block[2]
            else:
                # Previously this fell through with `inp0` unbound or left over from an
                # earlier iteration, which would wire the wrong producer.
                inp0 = None
            if inp0 is None:
                raise MHADefuseException(
                    f"Unexpected layer {layer.name} ({layer.op}) at position {i + 1} of the softmax block: "
                    f"no first input is defined for it",
                )

            # The extra producer becomes input 0; the chain predecessor stays as input 1.
            inp0.append_output_layer(layer.name)
            inp0.append_output_index(layer.index)
            layer.inputs.insert(0, inp0.name)
            layer.input_indices.insert(0, inp0.index)
            layer.input_shapes.insert(0, inp0._get_output_shape(layer_name=layer.name, layer_index=layer.index))
            layer.input_list = [inp0, pred]
            self.model.add_edge(inp0, layer)

    def _is_valid_mha_block(self, layer):
        """Check whether ``layer`` is the first matmul of a supported MHA block.

        Args:
            layer: Candidate first matmul (the Q x K matmul).

        Returns:
            ``(True, [kq_matmul, chain, k, q, v])`` when the block is valid, where ``chain``
            is the list of layers found after the first matmul (ending with the second
            matmul) and ``k``/``q``/``v`` are the producer layers of the matmul inputs;
            otherwise ``(False, reason)`` with a human-readable reason string.
        """
        if layer.op != LayerType.matmul:
            return False, f"{layer.name} op is not matmul"
        if not layer.transpose_matmul_input:
            return False, f"{layer.name} input is not transposed"

        # Try the plain softmax -> matmul chain first, then fall back to the masked-softmax block.
        expected_chain = [FwdChainNode(op=LayerType.softmax), FwdChainNode(op=LayerType.matmul)]
        chain = get_all_nodes_in_chain(self.model, layer, expected_chain, exact_match=True)
        if chain is None:
            chain = FuserHelper.is_layer_in_masked_softmax_block(self.model, layer)
        if chain is None:
            return False, f"Couldn't find softmax->matmul after {layer.name}"

        kq_matmul = layer
        # in case of masked softmax the first layer is reduce max
        softmax = chain[0]
        v_matmul = chain[-1]
        if v_matmul.transpose_matmul_input:
            return False, f"Second matmul {v_matmul.name} input is transposed"

        if kq_matmul.groups != softmax.groups or softmax.groups != v_matmul.groups:
            reason = (
                f"{kq_matmul.name}, {softmax.name}, {v_matmul.name} should have the same groups number but "
                f"they have {kq_matmul.groups}, {softmax.groups}, {v_matmul.groups}"
            )
            return False, reason

        # First matmul inputs are (Q, K); V is the second input of the second matmul.
        k = self.model.get_layer_by_name(kq_matmul.inputs[1])
        q = self.model.get_layer_by_name(kq_matmul.inputs[0])
        v = self.model.get_layer_by_name(v_matmul.inputs[1])

        heads_ok, heads_error = self._are_valid_heads([k, q, v])
        if heads_ok:
            return True, [kq_matmul, chain, k, q, v]
        return False, heads_error

    def _defuse_softmax_block(self, block_to_defuse, num_splits, new_layers):
        block_defuses = [[] for _ in range(num_splits)]
        for layer in block_to_defuse:
            layer_defuses = self._defuse_layer(layer, num_splits, new_layers)
            for i, block in enumerate(block_defuses):
                block.append(layer_defuses[i])
        return block_defuses

    def _defuse_layer(self, layer_to_defuse, num_splits, new_layers):
        """Turn one layer into ``num_splits`` layers named ``<name>_d0`` .. ``<name>_d{n-1}``.

        The original layer is shrunk in place (groups, and for matmul the last kernel
        dimension, are divided by ``num_splits``), removed from all configs under its old
        name and renamed to ``<name>_d0``. The remaining copies are cloned from it, added
        to the model and appended to ``new_layers``. For a softmax layer that has an
        ``additive_mask`` param, the mask is split along its last axis and one chunk is
        stored per resulting layer.

        Args:
            layer_to_defuse: Layer to split; modified and renamed in place.
            num_splits: Number of resulting layers.
            new_layers: List extended with the newly created copies.

        Returns:
            ``[layer_to_defuse, copy_1, ..., copy_{num_splits - 1}]``.
        """
        base_name = layer_to_defuse.name
        if hasattr(layer_to_defuse, "groups"):
            layer_to_defuse._groups = layer_to_defuse.groups // num_splits

        mask_chunks = None
        if layer_to_defuse.op == LayerType.matmul:
            kernel_shape = layer_to_defuse.kernel_shape
            layer_to_defuse._kernel_shape = kernel_shape[:-1] + [kernel_shape[-1] // num_splits]
        elif layer_to_defuse.op == LayerType.softmax:
            additive_mask = self.params.get(hn_to_npz_key(base_name, "additive_mask"))
            if additive_mask is not None:
                mask_chunks = np.split(additive_mask, num_splits, axis=-1)

        has_mask = mask_chunks is not None
        mask_params = {}
        if has_mask:
            # Chunk 0 belongs to the original layer under its upcoming "_d0" name.
            mask_params[hn_to_npz_key(f"{base_name}_d0", "additive_mask")] = mask_chunks[0]

        clones = []
        for split in range(1, num_splits):
            # Cloned after the shrink above, so each copy starts from the reduced layer.
            clone = type(layer_to_defuse).from_layer(layer_to_defuse)
            clone.name = f"{base_name}_d{split}"
            clone.index = self.model.get_next_index()
            clone.move_params(layer_to_defuse)
            if has_mask:
                # softmax with additive mask
                mask_params[hn_to_npz_key(clone.name, "additive_mask")] = mask_chunks[split]
            self.model.add_node(clone)
            new_layers.append(clone)
            clones.append(clone)

        # Drop the config entries of the old name before the layer is renamed.
        self._model_config.remove_layer_from_all_configs(base_name)
        layer_to_defuse.name = base_name + "_d0"
        self.params.update(mask_params)
        return [layer_to_defuse, *clones]

    def _create_kqv_split(self, layer, new_layers, num_splits):
        """Insert a feature splitter after ``layer`` with ``num_splits`` equal outputs.

        The splitter becomes the only output of ``layer``; the edge from ``layer`` to its
        first successor is removed. Connecting the splitter's outputs is left to the caller.

        Args:
            layer: Producer layer whose output features are split.
            new_layers: Passed through to ``FuserHelper.create_layer``.
            num_splits: Number of splitter outputs.

        Returns:
            The created feature-splitter layer.
        """
        splitter_index = self.model.get_next_index()
        features_per_split = layer.output_features // num_splits
        leading_dims = layer.output_shape[:-1]
        split_shapes = [[*leading_dims, features_per_split] for _ in range(num_splits)]
        splitter = self._fuser_helper.create_layer(
            FeatureSplitterLayer,
            splitter_index,
            "feature_splitter",
            layer,
            new_layers,
            split_shapes,
        )
        splitter.split_sizes = [features_per_split] * num_splits

        # layer -> splitter
        splitter.inputs = [layer.name]
        splitter.input_indices = [layer.index]
        splitter.input_shapes = [layer.output_shape]
        layer.outputs = [splitter.name]
        layer.output_indices = [splitter.index]

        # Swap the edge to the first successor for an edge to the splitter.
        first_successor = next(iter(self.model.successors(layer)))
        self.model.remove_edge(layer, first_successor)
        self.model.add_edge(layer, splitter)
        return splitter

    def _add_concat(self, layer, new_layers):
        """Create a concat layer derived from ``layer`` and register it in the model.

        The concat is named ``<scope>/<block>concat_<layer>`` from ``layer``'s current
        name, gets a fresh index, takes ``layer``'s params via ``move_params`` and replaces
        ``layer`` in the network's output layers order wherever its current name appears.
        No edges are added here.

        Args:
            layer: Layer the concat is created from.
            new_layers: List extended with the created concat.

        Returns:
            The created concat layer.
        """
        concat_layer = ConcatLayer.from_layer(layer)
        block_name, layer_name = self.get_block_and_layer_names(layer.name_without_scope)
        concat_layer.name = f"{layer.scope}/{block_name}concat_{layer_name}"
        concat_layer.index = self.model.get_next_index()
        concat_layer.move_params(layer)
        self.model.add_node(concat_layer)
        new_layers.append(concat_layer)

        # If the layer was a network output, the concat takes its place in the order.
        outputs_order = self.model.net_params.output_layers_order
        for position, output_name in enumerate(outputs_order):
            if output_name == layer.name:
                outputs_order[position] = concat_layer.name

        return concat_layer

    def _are_valid_heads(self, heads):
        for head in heads:
            if len(list(self.model.successors(head))) != 1:
                return False, f"Head {head.name} must have exactly 1 successor"

        return True, ""
