import logging
from typing import Dict

import numpy as np
from pydantic.v1 import Field

from hailo_model_optimization.acceleras.hailo_layers.base_hailo_none_nn_core_layer import BaseHailoNonNNCoreLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_conv import BaseHailoConv
from hailo_model_optimization.acceleras.hailo_layers.hailo_element_wise_mult_on_mac import HailoElementwiseMultOnMac
from hailo_model_optimization.acceleras.hailo_layers.hailo_layer_norm import HailoLayerNorm
from hailo_model_optimization.acceleras.hailo_layers.hailo_matmul import HailoMatmul
from hailo_model_optimization.acceleras.hailo_layers.hailo_precision_split import HailoPrecisionSplitPixels
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerWeightsClippingConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ZP_FEED_REPEAT,
    EncodingMatchType,
    MatmulCorrectionType,
    OpStates,
)
from hailo_model_optimization.algorithms.algorithm_base import AlgoResults
from hailo_model_optimization.algorithms.equiv_matching.matching_algo import MatchingAlgo
from hailo_model_optimization.algorithms.neg_exponent_fixer.neg_exp_fixer import NegExponentFixer
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm

# When True, CreateHWParamsWithMatch._info_debug prints scale diagnostics for layers whose
# full name contains "equalization_consumer_out". Left False for normal runs.
DEBUG = False


class CreateHWParamsResults(AlgoResults):
    """Results container of CreateHWParamsWithMatch: scale shifts recorded per layer."""

    # Maps a layer's full name to the scale shift recorded for it (see `_add_result`).
    shifts_by_layer: Dict[str, int] = Field({}, description="Negative slope exponent correction shift")


class CreateHWParamsWithMatch(OptimizationAlgorithm):
    def __init__(self, model, model_config, logger_level, logger=None):
        """Initialize the "Translate Parameters" algorithm and its helper matching algorithm.

        Args:
            model: model forwarded to the base algorithm and to the matching algorithm.
            model_config: model configuration forwarded to both algorithms.
            logger_level: logger level forwarded to the base algorithm.
            logger: optional logger forwarded to the base algorithm.
        """
        super().__init__(
            model,
            model_config,
            name="Translate Parameters",
            logger_level=logger_level,
            logger=logger,
        )
        # The helper matching algorithm is always created with the DEBUG logger level.
        self._matching_algo = MatchingAlgo(model, model_config, logging.DEBUG)

        # Bookkeeping that is filled in while the algorithm runs.
        self._weights_clipping_statistics = {}
        self._input_scale_updated_layers = []
        self._results = CreateHWParamsResults()

    def _setup(self):
        """Run the setup step of the internal matching algorithm."""
        self._matching_algo._setup()

    def should_skip_algo(self):
        """Always return False: this algorithm never reports itself as skippable."""
        return False

    def get_algo_config(self):
        """Return the ``weights_clipping`` section of the model configuration."""
        return self._model_config.weights_clipping

    def log_config(self):
        """Intentionally a no-op: nothing is logged for this algorithm's configuration."""
        pass

    @staticmethod
    def get_scale_in_out(layer, inverse=False):
        input_scale = np.array(layer.input_scales[0])
        output_scale = np.array(layer.output_scales[0])
        if len(np.unique(input_scale)) == 1:
            sc_in = input_scale[0]
        else:
            index_group = len(input_scale) // 3
            sc_in = np.array([input_scale[0], input_scale[index_group], input_scale[-1]])

        if len(np.unique(output_scale)) == 1:
            sc_out = output_scale[0]
        else:
            index_group = len(output_scale) // 3
            sc_out = np.array([output_scale[0], output_scale[index_group], output_scale[-1]])
        ratio = sc_out / sc_in
        if inverse:
            ratio = sc_in / sc_out
        print(layer.full_name, "scale_in", sc_in, "scale_out", sc_out, "ratio", ratio)
        return sc_in, sc_out

    @staticmethod
    def get_scale_in_out_ew_mult(layer):
        if len(layer.input_scales) == 1:
            input_scale_accs = layer.input_scales[0] ** 2
        else:
            input_scale_accs = layer.input_scales[0] * layer.input_scales[1]
        input_scale = np.array(input_scale_accs)

        output_scale = np.array(layer.output_scales[0])

        if len(np.unique(input_scale)) == 1:
            sc_in = input_scale[0]
        else:
            index_group = len(input_scale) // 3
            sc_in = np.array([input_scale[0], input_scale[index_group], input_scale[-1]])

        if len(np.unique(output_scale)) == 1:
            sc_out = output_scale[0]
        else:
            index_group = len(output_scale) // 3
            sc_out = np.array([output_scale[0], output_scale[index_group], output_scale[-1]])
        ratio = sc_out / sc_in
        print(layer.full_name, "scale_acc", sc_in, "scale_out", sc_out, "ratio", ratio)
        return sc_in, sc_out

    def _info_debug(self, lname, layer):
        if ("equalization_consumer_out" in layer.full_name) and DEBUG:
            model_name = lname.split("/")[0]
            layer_name_no_scope = lname.split("/")[1]
            layer_norm_name = ("_").join(layer_name_no_scope.split("_")[-2:])
            # "equalization_source_"
            # "shortcut1_low_",
            # "shortcut2_high_",
            # "ew_mullt_inter_",
            # "concat_layer_",
            # "square1_",
            # "square2_"
            print()
            print()

            layer_source = self._model.layers[f"{model_name}/equalization_source_{layer_norm_name}"]
            shortcut1_low = self._model.layers[f"{model_name}/shortcut1_low_{layer_norm_name}"]
            shortcut2_high = self._model.layers[f"{model_name}/shortcut2_high_{layer_norm_name}"]

            square1_low = self._model.layers[f"{model_name}/square1_{layer_norm_name}"]
            square2_high = self._model.layers[f"{model_name}/square2_{layer_norm_name}"]
            ew_mullt_inter = self._model.layers[f"{model_name}/ew_mult_inter_{layer_norm_name}"]

            normalization_nudge = self._model.layers[f"{model_name}/normalization_nudge_{layer_norm_name}"]
            equalization_consumer = self._model.layers[f"{model_name}/equalization_consumer_out_{layer_norm_name}"]

            print("layer_name", layer.full_name)
            print("#####################################")

            _ = self.get_scale_in_out(layer_source)
            _ = self.get_scale_in_out(normalization_nudge, inverse=True)
            _ = self.get_scale_in_out(equalization_consumer, inverse=True)
            print("#####################################")
            print()

            s_in_low, s_out_low = self.get_scale_in_out(shortcut1_low)
            s_in_high, s_out_high = self.get_scale_in_out(shortcut2_high)

            print("ratio_low_high", s_out_low / s_out_high)

            s_in_square1_low, s_out_square1_low = self.get_scale_in_out_ew_mult(square1_low)
            s_in_square2_high, s_out_square2_high = self.get_scale_in_out_ew_mult(square2_high)
            s_in_ew_mullt_inter, s_out_ew_mullt_inter = self.get_scale_in_out_ew_mult(ew_mullt_inter)

            print("ratio_high_low", s_out_square2_high / s_out_square1_low)
            print("ratio_high_high", s_out_square2_high / s_out_square2_high)
            print("ratio_high_inter", s_out_square2_high / s_out_ew_mullt_inter)

            print()

            successor_square1_low = self._model.layers[self._model.flow.successors_sorted(square1_low.full_name)[0]]
            s_in_pc_0, s_out_pc_0 = self.get_scale_in_out(successor_square1_low, inverse=True)

            successor_square2_high = self._model.layers[self._model.flow.successors_sorted(square2_high.full_name)[0]]
            s_in_pc_1, s_out_pc_1 = self.get_scale_in_out(successor_square2_high, inverse=True)

            successor_ew_mullt_inter = self._model.layers[
                self._model.flow.successors_sorted(ew_mullt_inter.full_name)[0]
            ]
            s_in_pc_2, s_out_pc_2 = self.get_scale_in_out(successor_ew_mullt_inter, inverse=True)

            print("ratio_high_low1 preccccc", s_out_pc_1 / s_out_pc_0)
            print("ratio_high_high2 preccccc", s_out_pc_1 / s_out_pc_1)
            print("ratio_high_inter3 preccccc", s_out_pc_1 / s_out_pc_2)

            index = normalization_nudge.input_scales[0].shape[0] // 3
            s1 = normalization_nudge.input_scales[0][0]
            s2 = normalization_nudge.input_scales[0][index]
            s3 = normalization_nudge.input_scales[0][-1]
            # print(s1,s2,s3)
            print("ratio_s2/s1", s2 / s1)
            print("ratio_s2/s2", s2 / s2)
            print("ratio_s2/s3", s2 / s3)

            index = normalization_nudge.output_scale.shape[0] // 3
            s1 = normalization_nudge.output_scale[0]
            s2 = normalization_nudge.output_scale[index]
            s3 = normalization_nudge.output_scale[-1]
            # print(s1,s2,s3)
            print("ratio_s2/s1_out", s2 / s1)
            print("ratio_s2/s2_out", s2 / s2)
            print("ratio_s2/s3_out", s2 / s3)

            print()

    def _create_hw_params_component(self, matching_component_group):
        """
        Creates hw params to a matching component and fixes negative slope exponent if needed

        Args:
            matching_component_group: component with layers for which the hw params will be created

        Returns:
            Boolean, if the hw params creation should be restarted.
            It should be restarted when there's scale matching

        """
        toposorted_nodes = self.equiv_match.get_sorted_u_nodes_in_componenets_group(matching_component_group)
        for lname in toposorted_nodes:
            layer = self._model.layers[lname]
            if isinstance(layer, BaseHailoNonNNCoreLayer):
                continue
            layer.enable_lossy()
            self._model_config.weights_clipping.layers.setdefault(lname, LayerWeightsClippingConfig.get_default())
            layer_clip_cfg = self._model_config.weights_clipping.layers[lname]
            force_shift = self._model_config.translation_config.layers[lname].force_shift
            hw_shifts = [force_shift] if force_shift is not None else force_shift
            layer.create_hw_params(layer_clip_cfg, self.optimization_target, hw_shifts=hw_shifts)
            self._info_debug(lname, layer)

            weight_limvals = layer.get_weights_clipping()
            if weight_limvals is not None:
                self._weights_clipping_statistics[f"{lname}/clip_values"] = weight_limvals
            retry_negative_exp_list = self._hanlde_negative_exponent(layer, matching_component_group)
            retry_shift_delta_list = self._handle_shift_delta(layer)
            retry_list = retry_negative_exp_list + retry_shift_delta_list
            if len(retry_list) > 0:
                return retry_list
        return []

    def _hanlde_negative_exponent(self, layer, matching_component_group) -> list:
        """Run the negative exponent fixer on ``layer`` and re-match the component group if needed.

        Args:
            layer: layer whose negative exponent is being fixed.
            matching_component_group: component group the layer is processed as part of.

        Returns:
            ``[matching_component_group]`` (with its source possibly replaced) when a positive
            scale shift was applied and re-matching the group reports a scale match;
            otherwise an empty list.
        """
        fixer = NegExponentFixer(
            self._model,
            self._model_config,
            layer.full_name,
            logger=self._logger,
            logger_level=self._logger_level,
        )
        fixer.run()
        layer_fix = fixer.get_results().layer_fix
        fixed_layer_name = layer_fix.layer_name

        # The fix landed on a different layer: take care of the equivflow accordingly.
        if layer.full_name != fixed_layer_name:
            matching_component_group = self.equiv_match.replace_component_group_source(
                matching_component_group,
                layer.full_name,
                fixed_layer_name,
            )
            for successor in self._model.flow.successors_sorted(fixed_layer_name):
                old_node = (layer.full_name, successor)
                new_node = (fixed_layer_name, successor)
                self.equiv_match.replace_node(old_node, new_node)

        if not layer_fix.scale_shift > 0:
            return []

        self._add_result(layer, layer_fix.scale_shift)

        # The looked-up layer is not used afterwards; the lookup is kept so that a missing
        # layer still fails at this point.
        fixed_layer = self._model.layers[fixed_layer_name]  # noqa: F841

        # Propagate the scale shift to every consumer layer of the component group.
        for consumer_name in self.equiv_match.consumer_layers_groups(matching_component_group):
            self._model.layers[consumer_name].update_scale_scalar_dof(layer_fix.scale_shift)

        # Assuming topological order, a changed source should affect all its
        # successors before they are handled, so the group is matched again here.
        match_type = self._matching_algo.match_components_group(matching_component_group)
        if match_type == EncodingMatchType.SCALE_MATCH:
            return [matching_component_group]
        return []

    def _handle_shift_delta(self, layer) -> list:
        shift_delta = layer.shift_delta
        # TODO: calculate input expand

        self._logger.debug(f"{layer.full_name} shift_delta {shift_delta}")

        if shift_delta is None:
            return []
        fix_shift = 2 ** np.max(shift_delta)
        if fix_shift == 1:
            return []
        if layer.full_name in self._input_scale_updated_layers:
            return []

        groups_components = self.equiv_match.get_groups_components()
        input_shift_delta = self._calculate_input_shift_delta(layer, fix_shift, groups_components)
        if input_shift_delta == 1:
            return []
        layers_to_update = set()
        comp_group_to_update = []
        for comp_group in groups_components:
            if layer.full_name in self.equiv_match.consumer_layers_groups(comp_group):
                layers_to_update = layers_to_update.union(self.equiv_match.layers_group(comp_group))
                comp_group_to_update.append(comp_group)

        comp_to_update = []
        for comp_group in groups_components:
            for layer in layers_to_update:
                if layer in self.equiv_match.source_layers_group(comp_group):
                    if np.any(
                        [
                            isinstance(self._model.layers[l_name], HailoLayerNorm)
                            for l_name in self.equiv_match.layers_group(comp_group)
                        ],
                    ):
                        return []
                    if comp_group not in comp_to_update:
                        comp_to_update.append(comp_group)

        source_handeld = set()
        for comp_group in comp_group_to_update:
            for source_layer_name in self.equiv_match.source_layers_group(comp_group):
                source_layer = self._model.layers[source_layer_name]
                if source_layer_name not in source_handeld:
                    source_layer.set_output_scale(source_layer.output_scale * input_shift_delta, index=0)
                source_handeld.add(source_layer_name)
            self._matching_algo.match_components_group(comp_group)

        return comp_to_update

    def _calculate_input_shift_delta(self, layer, fix_shift, groups_components):
        self._input_scale_updated_layers.append(layer.full_name)
        # if there is a input in the sources, return 1
        for comp in groups_components:
            if layer.full_name in self.equiv_match.consumer_layers_groups(comp):
                for name in self.equiv_match.source_layers_group(comp):
                    if name in self._model.flow.input_nodes:
                        return 1
                    if isinstance(self._model.layers[name], HailoLayerNorm):
                        return 1
                for name in self.equiv_match.layers_group(comp):
                    if isinstance(self._model.layers[name], HailoPrecisionSplitPixels):
                        self._logger.warning(
                            f"input_shift_delta for {layer.full_name} isn't supported when the previous layer is a "
                            f"precision split layer (shift delta: {fix_shift}).\n"
                            f"This might cause some degradation as a result.\n"
                            f"To avoid this, please consider defusing the layer input into more groups, "
                            f"or removing the conv_a16_w4 decomposition command."
                        )
                        return 1
        if isinstance(layer, BaseHailoConv):
            if layer.conv_op.set_scale_by_kernel_only:
                return fix_shift
            weights_bits = layer.conv_op.weight_lossy_elements.kernel.bits
            input_bits = layer.conv_op.input_lossy_elements[0].bits
            input_fix = np.sqrt(fix_shift * 2 ** (input_bits - weights_bits))
            input_fix = np.minimum(input_fix, fix_shift)
            return input_fix
        elif isinstance(layer, HailoElementwiseMultOnMac):
            return fix_shift
        return 1

    @property
    def equiv_match(self):
        """The ``equiv_match`` object exposed by the internal matching algorithm."""
        return self._matching_algo.equiv_match

    def _add_result(self, layer, fix_shift):
        """Record ``fix_shift`` in the results under the layer's full name."""
        self._results.shifts_by_layer[layer.full_name] = fix_shift

    def _run_int(self):
        """Match and create hw params for every component group, retrying groups when asked to.

        Component groups are processed from the front of a work list. Groups returned for
        retry by ``_create_hw_params_component`` are pushed to the front of the list unless
        they are already pending. Once the list is empty, output scales are updated,
        zero-point compensation is configured and the model is marked as supporting the
        QUANTIZED state.
        """
        pending_groups = self.equiv_match.get_groups_components()
        while pending_groups:
            current_group = pending_groups.pop(0)
            self._matching_algo.match_components_group(current_group)
            for retry_group in self._create_hw_params_component(current_group):
                if retry_group in pending_groups:
                    continue
                pending_groups.insert(0, retry_group)

        self._matching_algo.update_output_scales()
        self.set_calculate_zp_comp()
        self._model.add_supported_state(OpStates.QUANTIZED)

    def set_calculate_zp_comp(self):
        for lname in self._model.flow.toposort():
            layer = self._model.layers[lname]
            if isinstance(layer, HailoMatmul) and np.any(layer.input_zero_points[0] != 0):
                if layer.zp_correction_type == MatmulCorrectionType.ZP_COMP_WEIGHTS:
                    conv_to_correct_name = self._model.flow.predecessors_sorted(lname)[1]
                    conv_to_correct = self._model.layers[conv_to_correct_name]
                    self._logger.debug(
                        f"set out_zp_comp_groups to be {layer.groups} on layer "
                        f"{conv_to_correct_name} because of {lname}",
                    )
                    conv_to_correct.conv_op.out_zp_comp_groups = layer.groups
                    conv_to_correct.conv_op.feed_repeat = ZP_FEED_REPEAT
                    layer.matmul_op.feed_repeat = ZP_FEED_REPEAT
                elif layer.zp_correction_type in [
                    MatmulCorrectionType.ZP_COMP,
                    MatmulCorrectionType.ZP_COMP_BLOCK,
                    MatmulCorrectionType.ZP_COMP_BLOCK_2,
                    MatmulCorrectionType.ZP_COMP_BLOCK_3,
                ]:
                    pass  # matmul_op.feed_repeat was allready update!
                elif layer.zp_correction_type == MatmulCorrectionType.ZP_COMP_NONE:
                    pass
                else:
                    raise ValueError(f"zp_correction_type {layer.zp_correction_type} is not supported")

    def finalize_global_cfg(self, algo_config):
        """Intentionally a no-op: ``algo_config`` is left untouched."""
        pass

    def _get_valid_layer_cfg(self, lname, cfg):
        """Return ``cfg`` unchanged; ``lname`` is currently not used."""
        # TODO: add indication if weights clipping is supported
        return cfg

    def export_statistics(self):
        """Return the collected weights clipping values, keyed by ``"<layer name>/clip_values"``."""
        return self._weights_clipping_statistics
