import logging

import networkx as nx
import numpy as np

from hailo_model_optimization.acceleras.hailo_layers.base_hailo_conv import BaseHailoConv
from hailo_model_optimization.acceleras.hailo_layers.hailo_conv import HailoConv
from hailo_model_optimization.acceleras.hailo_layers.hailo_element_wise_add import HailoElementwiseAdd
from hailo_model_optimization.acceleras.hailo_layers.hailo_feature_splitter import HailoFeatureSplitter
from hailo_model_optimization.acceleras.model.hailo_model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerEqualizationConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ActivationType,
    PrecisionMode,
)
from hailo_model_optimization.acceleras.utils.hn_npz_utils import LayerParams
from hailo_model_optimization.acceleras.utils.modification_tracker.modification_tracker_utils import (
    GatherTracker,
    SplitTracker,
)
from hailo_model_optimization.algorithms.mixed_precision.create_mixed_precision import CreateMixedPrecision
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm
from hailo_model_optimization.algorithms.stats_collection.stats_collection import StatsCollector


class ConvDecomposition(OptimizationAlgorithm):
    """
    ConvDecomposition class is an optimization algorithm that replaces each conv layer listed in the
    conv_decomposition config with a block of smaller layers: a feature splitter that divides the input channels
    into sub groups, one mini conv per sub group, and element-wise adds that sum the mini conv outputs.
    When requested by the config, the weights are first reordered by activation statistics.
    """

    def __init__(self, model: HailoModel, model_config, logger_level, dataset, **kwargs):
        """
        create the conv decomposition algorithm

        Args:
            model: model whose conv layers are decomposed
            model_config: optimization config; its conv_decomposition section drives this algorithm
            logger_level: forwarded to the base algorithm
            dataset: stored for the statistics collection done in reorder_weights_by_act_stats
            **kwargs: forwarded untouched to the base algorithm
        """
        algo_name = "conv_decomposition"
        super().__init__(model, model_config, algo_name, logger_level, **kwargs)
        self._dataset = dataset

    def _get_valid_layer_cfg(self, lname, cfg):
        return cfg

    def get_algo_config(self):
        return self._model_config.conv_decomposition

    def finalize_global_cfg(self, algo_config):
        pass

    def should_skip_algo(self):
        layer_to_decompose = list(self._model_config.conv_decomposition.layers.keys())
        return len(layer_to_decompose) == 0

    def _setup(self):
        if any([l_config.sort_channels_by_stats for l_config in self._model_config.conv_decomposition.layers.values()]):
            self._logger.debug("ConvDecompose: setup reordering weights")
            self.reorder_weights_by_act_stats()

    def _run_int(self):
        """
        decompose every configured conv layer, then run CreateMixedPrecision on the modified model
        """
        self._logger.debug("ConvDecomose: _run_int")

        # iterate over a snapshot of the names; the per-layer config is looked up again on each
        # iteration, right before the layer is decomposed
        layer_names = tuple(self._model_config.conv_decomposition.layers.keys())
        for layer_name in layer_names:
            self._logger.debug(f"ConvDecomose: {layer_name}")
            layer_cfg = self._model_config.conv_decomposition.layers[layer_name]
            self._conv_decompose(layer_name, layer_cfg)

        CreateMixedPrecision(
            model=self._model,
            model_config=self._model_config,
            logger_level=self._logger_level,
            logger=self._logger,
        ).run()

    def _conv_decompose(self, conv_name, conv_decompose_cfg):
        conv = self._model.layers[conv_name]
        conv_decomp_flow, entry_node, exit_node = self._build_conv_decomposition_flow(conv, conv_decompose_cfg)
        self._add_input_and_output_to_flow(conv_decomp_flow, conv, entry_node, exit_node)
        self.add_flow_to_model(conv_decomp_flow, conv)

    def add_flow_to_model(self, decompose_conv_flow: nx.DiGraph, conv_layer: HailoConv):
        """
        adding block to the model
        1. verify that no new layer collides with a layer that stays in the model
        2. remove the old layer from the model, the model flow and the configs
        3. add the new layers to the model and model flow
        4. add the edges between the layers
        5. fix the output layer order if applies

        Args:
            decompose_conv_flow: graph of the replacement block. every node has a "layer" attribute;
                nodes whose "layer" is None are not added to the model and only serve as edge endpoints
            conv_layer: the layer that is replaced

        Raises:
            ValueError: a new layer name already exists in the model (raised before the model is modified)
            RuntimeError: the replaced layer is listed in the output layer order and the block does not
                have exactly one output node
        """
        old_name = conv_layer.full_name
        ordered_names = list(nx.topological_sort(decompose_conv_flow))

        # Validate first, so that a name clash does not leave the model half-modified.
        # The old layer's own name is not a clash: it is removed before the new layers are added.
        seen_names = set()
        for lname in ordered_names:
            layer = decompose_conv_flow.nodes[lname]["layer"]
            if layer is None:
                continue
            new_name = layer.full_name
            already_in_model = new_name != old_name and (
                new_name in self._model.layers or new_name in self._model.flow.nodes
            )
            if already_in_model or new_name in seen_names:
                raise ValueError(
                    f"{layer.full_name} already exists in model,conv {conv_layer.full_name} decomposition failed"
                )
            seen_names.add(new_name)

        # Remove old layer from model & model flow
        del self._model.layers[old_name]
        self._model.flow.remove_node(old_name)
        if self._model_config is not None:
            self._model_config.remove_layer_from_all_configs(old_name)

        # Add nodes to the model & model flow
        for lname in ordered_names:
            self._logger.debug(f"Adding layer {lname} to model")
            layer = decompose_conv_flow.nodes[lname]["layer"]
            if layer is not None:
                self._model.layers[layer.full_name] = layer
                self._model.flow.add_node(layer.full_name)

        # Add edges to the model flow (nodes of the line graph are the edges of the block)
        edges_graph = nx.line_graph(decompose_conv_flow)
        for u, v in nx.topological_sort(edges_graph):  # topological sort for edges
            edge_attr = decompose_conv_flow.edges[(u, v)]
            self._model.flow.add_edge(u, v, **edge_attr)

        # Fix output layer order if applies
        if old_name in self._model.flow.output_layer_order:
            # Out-degree 0 in the line graph marks an edge that leads into a sink of the block.
            # Several such edges may share the same source (one block output feeding several
            # successors), so the sources are de-duplicated while keeping their order.
            out_nodes = list(dict.fromkeys(u for (u, _), d in edges_graph.out_degree if d == 0))
            if len(out_nodes) != 1:
                raise RuntimeError(f"Unexpected output nodes when decomposing conv {conv_layer.full_name} {out_nodes}")
            idx = self._model.flow.output_layer_order.index(old_name)
            self._model.flow.output_layer_order[idx] = out_nodes[0]

    def _build_conv_decomposition_flow(self, conv_layer: HailoConv, conv_decompose_cfg):
        fs, mini_convs, mini_ew_adds = self._create_layers(
            conv_layer,
            conv_decompose_cfg.sub_group_size,
            conv_decompose_cfg.allow_equlize_block,
        )

        self._update_precision_config(conv_layer, fs, mini_convs, mini_ew_adds, conv_decompose_cfg)

        entry_node = fs
        exit_node = mini_ew_adds[-1]
        # build flow
        conv_decomp_flow = nx.DiGraph()
        ### add all layers
        conv_decomp_flow.add_node(fs.full_name, layer=fs)
        for mini_conv in mini_convs:
            conv_decomp_flow.add_node(mini_conv.full_name, layer=mini_conv)
        for mini_ew_add in mini_ew_adds:
            conv_decomp_flow.add_node(mini_ew_add.full_name, layer=mini_ew_add)

        ### add edges fs -> mini_convs
        for i, mini_conv in enumerate(mini_convs):
            conv_decomp_flow.add_edge(fs.full_name, mini_conv.full_name, output_index=i)

        ### add edges mini_convs -> mini_ew_adds
        ellement_to_add = mini_convs
        add_ellements = mini_ew_adds

        while len(ellement_to_add) > 0:
            mini_conv_0 = ellement_to_add.pop(0)
            mini_conv_1 = ellement_to_add.pop(0)
            mini_ew_add = add_ellements.pop(0)
            conv_decomp_flow.add_edge(mini_conv_0.full_name, mini_ew_add.full_name, input_index=0)
            conv_decomp_flow.add_edge(mini_conv_1.full_name, mini_ew_add.full_name, input_index=1)
            if len(add_ellements) > 0:
                ellement_to_add.append(mini_ew_add)

        return conv_decomp_flow, entry_node, exit_node

    def _create_layers(self, conv_layer, subgroup_size, allow_equlize_block):
        conv_dict_params = conv_layer.export_weights()
        conv_type = HailoConv if isinstance(conv_layer, HailoConv) else BaseHailoConv
        no_subgroups = conv_layer.conv_op.kernel.shape[-2] // subgroup_size

        # create fs for input to mini convs
        fs = HailoFeatureSplitter(
            name=conv_layer.full_name + "_decompose_feature_splitter",
            split_sizes=[subgroup_size] * int(no_subgroups),
            groups=conv_layer.groups,
        )

        # create mini convs
        mini_convs = [
            conv_type(
                name=conv_layer.full_name + f"_subgroup_{i}",
                filters=conv_layer.conv_op.kernel.shape[-1],
                kernel_size=conv_layer.conv_op.kernel.shape[:-2],
                padding=conv_layer.conv_op.padding.value,
                stride_align=conv_layer.conv_op.stride_align.value,
                strides=conv_layer.conv_op.strides,
                groups=conv_layer.conv_op.groups,
                group_sizes=conv_layer.conv_op.group_sizes,
                activation=ActivationType.LINEAR.value,
                transpose_output_width_features=conv_layer.transpose_output_width_features,
                spatial_flatten_output=conv_layer.conv_op.spatial_flatten_output,
                dilation_rate=conv_layer.conv_op.dilation_rate,
                zp_comp_add=False,
                logger=conv_layer._logger,
            )
            for i in range(int(no_subgroups))
        ]

        self._modifications_meta_data.append(
            conv_layer.full_name,
            SplitTracker(
                mini_convs=[mini_conv.full_name for mini_conv in mini_convs],
                ew_add_after=True,
            ),
        )

        # set weights for mini convs
        conv_dict_params = conv_layer.export_weights()
        for group_i, mini_conv in enumerate(mini_convs):
            mini_conv_dict_params = conv_dict_params.copy()
            mini_conv_dict_params[f"{mini_conv.full_name}/kernel:0"] = conv_dict_params["kernel"].copy()[
                :, :, subgroup_size * group_i : subgroup_size * (group_i + 1), :
            ]
            mini_conv_dict_params[f"{mini_conv.full_name}/bias:0"] = conv_dict_params["bias"].copy() / no_subgroups
            mini_conv_dict_params["padding_const_value"] = conv_dict_params["padding_const_value"]
            mini_conv_lparams = LayerParams(mini_conv_dict_params, mini_conv.full_name)
            mini_conv.import_weights(mini_conv_lparams)

        # set_scale_by_kernel_only
        set_scale_by_kernel_only = conv_layer.conv_op.set_scale_by_kernel_only
        for mini_conv in mini_convs:
            mini_conv.conv_op.set_scale_by_kernel_only = set_scale_by_kernel_only

        # set precision_split_zp
        precision_split_zp = conv_layer.precision_split_zp
        for mini_conv in mini_convs:
            mini_conv.precision_split_zp = precision_split_zp
            if not allow_equlize_block:
                self._model_config.equalization.layers[mini_conv.full_name] = LayerEqualizationConfig(policy="disabled")

        # create mini elementwise adds
        num_ew_adds = no_subgroups - 1
        mini_ew_adds = [
            HailoElementwiseAdd(
                name=conv_layer.full_name + f"_ew_add_{i}", activation=ActivationType.LINEAR.value, input_repeats=None
            )
            for i in range(num_ew_adds - 1)
        ]
        mini_ew_adds.append(
            HailoElementwiseAdd(
                name=conv_layer.full_name + f"_ew_add_{num_ew_adds-1}",
                activation=conv_layer.act_op.act_name,
                input_repeats=None,
            )
        )

        if not allow_equlize_block:
            for mini_ew_add in mini_ew_adds:
                self._model_config.equalization.layers[mini_ew_add.full_name] = LayerEqualizationConfig(
                    policy="disabled"
                )

        return fs, mini_convs, mini_ew_adds

    def _update_precision_config(self, conv_layer, fs, mini_convs, mini_ew_adds, conv_decompose_cfg):
        """
        set the precision (and weights clipping) config of the layers that replace conv_layer

        - feature splitter: its default precision config
        - mini convs: the original conv's precision config, with a8_w4_a8 / a8_w8_a8 replaced by
          a8_w4 / a8_w8. when sort_channels_by_stats is set, the last mini conv instead gets its default
          config with a8_w8 precision mode and the original bias mode
        - mini convs also inherit the original conv's weights clipping config, if one exists
        - ew adds: their default precision config with precision mode conv_decompose_cfg.pm_ew_adds

        every layer gets its own copy of the config it inherits, so that a later in-place change of one
        layer's config cannot leak into its siblings; the original conv's config entry is left untouched

        Args:
            conv_layer: the conv layer being decomposed (its configs are the source)
            fs: the created feature splitter
            mini_convs: the created mini convs
            mini_ew_adds: the created element-wise adds
            conv_decompose_cfg: decomposition config of the layer
        """
        import copy  # local import: only this method needs it

        self.set_precision_config(fs, precision_cfg=None, use_dfault=True)

        # private copy: the adjustments below must not modify the original conv's config object
        base_cfg = copy.deepcopy(self._model_config.precision_config.layers[conv_layer.full_name])
        if base_cfg.precision_mode == PrecisionMode.a8_w4_a8:
            base_cfg.precision_mode = PrecisionMode.a8_w4
        elif base_cfg.precision_mode == PrecisionMode.a8_w8_a8:
            base_cfg.precision_mode = PrecisionMode.a8_w8

        for mini_conv in mini_convs[:-1]:
            self.set_precision_config(mini_conv, precision_cfg=copy.deepcopy(base_cfg), use_dfault=False)

        if conv_decompose_cfg.sort_channels_by_stats:
            last_conv_cfg = mini_convs[-1].get_default_precision_config()
            last_conv_cfg.precision_mode = PrecisionMode.a8_w8
            last_conv_cfg.bias_mode = base_cfg.bias_mode
        else:
            last_conv_cfg = copy.deepcopy(base_cfg)

        self.set_precision_config(mini_convs[-1], precision_cfg=last_conv_cfg, use_dfault=False)

        if conv_layer.full_name in self._model_config.weights_clipping.layers:
            weights_clipping_cfg = self._model_config.weights_clipping.layers[conv_layer.full_name]
            for mini_conv in mini_convs:
                self._model_config.weights_clipping.layers[mini_conv.full_name] = copy.deepcopy(weights_clipping_cfg)

        for mini_ew_add in mini_ew_adds:
            ew_add_cfg = mini_ew_add.get_default_precision_config()
            ew_add_cfg.precision_mode = conv_decompose_cfg.pm_ew_adds
            self.set_precision_config(mini_ew_add, precision_cfg=ew_add_cfg, use_dfault=False)

    def _add_input_and_output_to_flow(self, conv_decomp_flow, conv_layer, entry_node, exit_node):
        # add input and output to flow
        pred_node = self._model.layers[self._model.flow.predecessors_sorted(conv_layer.full_name)[0]]
        conv_decomp_flow.add_node(pred_node.full_name, layer=None)
        output_index = self._model.flow.successors_sorted(pred_node.full_name).index(conv_layer.full_name)
        conv_decomp_flow.add_edge(pred_node.full_name, entry_node.full_name, input_index=0, output_index=output_index)

        succ_nodes_names = self._model.flow.successors_sorted(conv_layer.full_name)

        for i, succ_node_name in enumerate(succ_nodes_names):
            conv_decomp_flow.add_node(succ_node_name, layer=None)
            input_index = self._model.flow.predecessors_sorted(succ_node_name).index(conv_layer.full_name)
            conv_decomp_flow.add_edge(exit_node.full_name, succ_node_name, output_index=0, input_index=input_index)

    def set_precision_config(self, layer, precision_cfg, use_dfault=False):
        if use_dfault:
            precision_cfg = layer.get_default_precision_config()

        layer.verify_config(precision_cfg)
        layer.import_precision_config(precision_cfg, self.optimization_target)
        layer.hw_arch = self._model_config.precision_config.target
        self._model_config.precision_config.layers[layer.full_name] = precision_cfg

    def reorder_weights_by_act_stats(self):
        """
        this function reorders the weights of the conv layers based on the activation statistics
        in order to have high activation multipled by 8-bit weights
        there is a jira to generalize this function.
        https://hailotech.atlassian.net/browse/SDK-57193
        """
        scope = next(iter(self._model.layers.keys())).split("/")[0]
        if scope.startswith("qwen2"):
            block_num = int(scope.split("qwen2_block")[1])
            self._logger.debug(f"block_num: {block_num} re-orging weights")
            stats_collector = StatsCollector(
                self._model,
                self._model_config,
                logging.DEBUG,
                self._dataset,
                layers_to_handle=[f"{scope}/ew_mult7"],
                logger=self._logger,
            )
            stats_collector.run()

            ind_sorted = np.argsort(self._model.layers[f"{scope}/ew_mult7"].get_output_stats()[0].max)
            k5 = self._model.layers[f"{scope}/conv5"].export_weights()
            k6 = self._model.layers[f"{scope}/conv6"].export_weights()
            k7 = self._model.layers[f"{scope}/conv7"].export_weights()

            indices_key = self._modifications_meta_data.add_modification_param("indices", ind_sorted)
            self._modifications_meta_data.append(
                f"{scope}/conv5",
                GatherTracker(
                    indices_key=indices_key,
                    apply_on_input=False,
                ),
            )
            self._modifications_meta_data.append(
                f"{scope}/conv6",
                GatherTracker(
                    indices_key=indices_key,
                    apply_on_input=False,
                ),
            )
            self._modifications_meta_data.append(
                f"{scope}/conv7",
                GatherTracker(
                    indices_key=indices_key,
                    apply_on_input=True,
                ),
            )

            k5["kernel"] = k5["kernel"][:, :, :, ind_sorted]
            k6["kernel"] = k6["kernel"][:, :, :, ind_sorted]
            k5["bias"] = k5["bias"][ind_sorted]
            k6["bias"] = k6["bias"][ind_sorted]
            k7["kernel"] = k7["kernel"][:, :, ind_sorted, :]

            self._model.layers[f"{scope}/conv7"].import_weights(k7)
            self._model.layers[f"{scope}/conv6"].import_weights(k6)
            self._model.layers[f"{scope}/conv5"].import_weights(k5)
        else:
            raise ValueError(f"scope {scope} is not supported for reordering weights")
