import logging
from typing import Dict, Tuple

import networkx as nx

from hailo_model_optimization.acceleras.hailo_layers.base_hailo_conv_add import BaseHailoConvAdd
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_none_nn_core_layer import BaseHailoNonNNCoreLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_io import HailoInputLayer
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerTranslationConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import ActivationType, FeaturePolicy
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasConfigurationError
from hailo_model_optimization.algorithms.equiv_matching.matching_algo import MatchingAlgo
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm


class CreateIOEncoding(OptimizationAlgorithm):
    """
    Create the io encoding of the layers necessary for inferring a quantized model.

    It creates io scale and zero point for layers with an independent output scale (e.g. conv based layers).
    Additionally, this class handles the ``force_range_{in,out}`` commands, and finds the independent scale of
    each layer.
    """

    # Forced range per source layer name. Values are stored exactly as passed by the caller
    # (they may be lists or tuples), so comparisons go through ``_ranges_differ``.
    forced_range_by_source_layer: Dict[str, Tuple[float, float]]
    # For each source layer, the (latest) target layer whose force-range command set its range.
    # Only used to report conflicting force-range commands.
    source_layer_dependants: Dict[str, str]

    def __init__(self, model, model_config, logger_level, logger=None):
        super().__init__(model, model_config, name="Create Encoding", logger_level=logger_level, logger=logger)
        # Provides the equivalence component groups (``equiv_match``) and matches each group after its
        # encoding is created (see ``_run_int``). Its logger level is fixed to DEBUG.
        self._matching_algo = MatchingAlgo(model, model_config, logging.DEBUG)

    def _setup(self):
        """Prepare the translation config, import it into the model and reset the forced-range bookkeeping."""
        retval = super()._setup()
        self._matching_algo._setup()
        algo_config = self.get_algo_config()
        # NOTE: temporary hotfix for the faulty defaults removal.
        # max_elementwise_feed_repeat can only be applied to conv-and-add layers; clear it elsewhere.
        for lname, lcfg in algo_config.layers.items():
            if lcfg.max_elementwise_feed_repeat is not None and not isinstance(
                self._model.layers[lname],
                BaseHailoConvAdd,
            ):
                lcfg.max_elementwise_feed_repeat = None
        # Every model layer gets a translation config entry (default when none was given).
        for layer in self._model.layers:
            algo_config.layers.setdefault(layer, LayerTranslationConfig.get_default())
        # Import the translation config after it has been handled.
        self._model.import_config(
            self._model_config,
            force_translation=True,
        )
        self.forced_range_by_source_layer = {}
        self.source_layer_dependants = {}
        return retval

    def should_skip_algo(self):
        return False

    def get_algo_config(self):
        return self._model_config.translation_config

    def log_config(self):
        pass

    @property
    def equiv_match(self):
        """The equivalence matching object of the underlying ``MatchingAlgo``."""
        return self._matching_algo.equiv_match

    def _run_int(self):
        """Create encodings group by group, then enforce io encoding on the NN-core output layers."""
        for matching_component_group in self.equiv_match.get_groups_components():
            self._find_range_of_component_group(matching_component_group)
            self._create_encoding_component_group(matching_component_group)
            self._matching_algo.match_components_group(matching_component_group, training=False)
        for output_layer in self._model.flow.output_nodes:
            layer = self._model.layers[output_layer]
            # Non NN-core layers are skipped.
            if isinstance(layer, BaseHailoNonNNCoreLayer):
                continue
            layer.enforce_io_encoding()

    def _find_range_of_component_group(self, component_group):
        """
        Collect the forced ranges that apply to the source layers of ``component_group``.

        Ranges are gathered, in order, from: the source layers' ``force_range_out``; the transparent
        (neither source nor consumer) layers' ``force_range_in``/``force_range_out``; the consumer layers'
        ``force_range_in``; and finally ``[0, 1]`` for sigmoid layers feeding an end node.
        Results are recorded in ``self.forced_range_by_source_layer``.

        Raises:
            AccelerasConfigurationError: if two commands force different ranges on the same source layer.
        """
        source_layers = self.equiv_match.source_layers_group(component_group)
        toposorted_nodes = self.equiv_match.get_sorted_u_nodes_in_componenets_group(component_group)
        consumer_layers = self.equiv_match.consumer_layers_groups(component_group)
        algo_config = self.get_algo_config()
        # TODO: do we need to get per layer default for translation cfg?
        # update range based on source layers
        for lname in source_layers:
            layer_cfg = algo_config.layers[lname]
            if layer_cfg.force_range_out is not None:
                self._update_desired_forced_range(lname, [lname], layer_cfg.force_range_out)
            elif layer_cfg.input_normalization == FeaturePolicy.enabled and isinstance(
                self._model.layers[lname], HailoInputLayer
            ):
                input_lossy_element = self._model.layers[lname].get_input_lossy_elements()[0]
                if input_lossy_element.bits == 8 and not input_lossy_element.signed:
                    # Currently computed but not applied (feature disabled).
                    new_range = [input_lossy_element.min_value, input_lossy_element.max_value]
                    # TODO: uncomment this when we want to enable this feature. SDK-51539
                    # self._update_desired_forced_range(lname, [lname], new_range)

        # update range based on transparent layers
        for lname in toposorted_nodes:
            if lname in source_layers or lname in consumer_layers:
                continue
            layer_cfg = algo_config.layers[lname]
            # TODO - if layer is feature multiplier, make sure 1 is in range
            if layer_cfg.force_range_in is not None:
                self._update_desired_forced_range(
                    lname, source_layers, layer_cfg.force_range_in, layer_cfg.force_range_index
                )
            if layer_cfg.force_range_out is not None:
                self._update_desired_forced_range(lname, source_layers, layer_cfg.force_range_out)

        # update range based on consumer layers
        for lname in consumer_layers:
            layer_cfg = algo_config.layers[lname]
            if layer_cfg.force_range_in is not None:
                self._update_desired_forced_range(
                    lname, source_layers, layer_cfg.force_range_in, layer_cfg.force_range_index
                )

        # update range of output layers with sigmoid activation to be [0,1]
        output_layers_in_group = set(toposorted_nodes) & set(self._model.flow.output_layer_order)
        # Hoisted: the set of end nodes does not change while iterating.
        end_nodes = set(self._model.flow.get_end_nodes())
        for lname in source_layers + consumer_layers:
            layer = self._model.layers[lname]
            successors = set(self._model.flow.successors(lname))
            if (
                not isinstance(layer, BaseHailoNonNNCoreLayer)
                and layer.get_activation_name() is ActivationType.SIGMOID
                and output_layers_in_group
                and lname not in self.forced_range_by_source_layer
                and successors & end_nodes
                and len(successors) == 1
            ):
                # TODO If we want to support more succesors we need to check how
                # TODO this change will affect all the component Group
                range_min, range_max = layer.get_output_limvals()[0]
                # Only force when the collected range does not already span [0, 1].
                if range_max < 1.0 or range_min > 0.0:
                    self._logger.info(
                        f"Output layer {lname} with sigmoid activation was detected. "
                        f"Forcing its output range to be [0, 1] (original range was [{range_min}, {range_max}])."
                    )

                    self._update_desired_forced_range(lname, source_layers, [0.0, 1.0])

    def _create_encoding_component_group(self, component_group):
        """Create the output encoding candidates of every NN-core source layer in ``component_group``."""
        for source_layer_name in self.equiv_match.source_layers_group(component_group):
            source_layer = self._model.layers[source_layer_name]
            if isinstance(source_layer, BaseHailoNonNNCoreLayer):
                continue
            # None when no range was forced on this source layer.
            current_range = self.forced_range_by_source_layer.get(source_layer_name)
            layer_cfg = self.get_algo_config().layers.get(source_layer.full_name, None)
            source_layer.enforce_io_encoding()  # either do nothing or apply the "output_scale_scalar_dof"
            source_layer.create_output_encoding_candidates(current_range, translation_config=layer_cfg)

    def _update_force_range_strong(self, target_layer, source_layer, new_range):
        """Set ``strong_force_range`` on ``source_layer`` unless ``target_layer`` requested a weak force range."""
        target_layer_cfg = self.get_algo_config().layers[target_layer]
        if target_layer_cfg.weak_force_range_out == FeaturePolicy.enabled:
            return
        self._model.layers[source_layer].strong_force_range = new_range

    @staticmethod
    def _ranges_differ(range_a, range_b):
        """
        Return True when two ranges differ element-wise.

        Comparing via ``tuple`` makes ``[0.0, 1.0]`` and ``(0.0, 1.0)`` equal; a plain ``!=`` between a list
        and a tuple is always True, and between numpy arrays it is ambiguous.
        """
        return tuple(range_a) != tuple(range_b)

    def _update_desired_forced_range(self, target_layer, source_layers, new_range, force_range_index=None):
        """
        Record ``new_range`` as the forced range of the source layers connected to the target layer.

        Args:
            target_layer: name of the layer carrying the force-range command.
            source_layers: candidate source layer names; only those with a graph path to the real target are
                updated (a source layer passed as its own target is included).
            new_range: the range to force.
            force_range_index: when given, the real target is the ``force_range_index``-th entry of
                ``target_layer``'s sorted predecessors instead of ``target_layer`` itself.

        Raises:
            AccelerasConfigurationError: if a connected source layer already has a different forced range.
        """
        if force_range_index is not None:
            real_target = self._model.flow.predecessors_sorted(target_layer)[force_range_index]
        else:
            real_target = target_layer
        for source_layer in self.find_connected_source_layers(real_target, source_layers):
            if source_layer in self.forced_range_by_source_layer and self._ranges_differ(
                self.forced_range_by_source_layer[source_layer], new_range
            ):
                previous_target = self.source_layer_dependants[source_layer]
                raise AccelerasConfigurationError(
                    f"Force range of layers {target_layer} and {previous_target} have conflicting ranges",
                )
            self.forced_range_by_source_layer[source_layer] = new_range
            # Taking the latest layer each time. The error won't print all the conflicting layers
            self.source_layer_dependants[source_layer] = target_layer
            self._update_force_range_strong(target_layer, source_layer, new_range)

    def find_connected_source_layers(self, target_layer, source_layers):
        """Return, in input order, the ``source_layers`` that have a path to ``target_layer`` in the model flow."""
        return [source_layer for source_layer in source_layers if nx.has_path(self._model.flow, source_layer, target_layer)]

    def _deprication_warning(self):
        """Warn about layers whose ``force_range_out`` (from ``meta``) now uses the strong behavior."""
        algo_config = self.get_algo_config()
        layers_force_range = [
            layer_n
            for layer_n, layer_cfg in algo_config.layers.items()
            if layer_cfg.meta is not None
            and layer_cfg.meta.get("force_range_out") is not None
            and layer_cfg.weak_force_range_out != FeaturePolicy.enabled
        ]
        if layers_force_range:
            self._logger.warning(
                "The force_range command has been used, notice that its behavior was "
                "changed on this version. The old behavior forced the range on the "
                "collected calibration set statistics, but allowed the range to change "
                "during the optimization algorithms.\n"
                "The new behavior forces the range throughout all optimization stages.\n"
                "The old method could be restored by adding the flag "
                "weak_force_range_out=enabled to the force_range command on the following "
                f"layers {layers_force_range}"
            )

    def finalize_global_cfg(self, algo_config):
        pass

    def _get_valid_layer_cfg(self, lname, cfg):
        """Drop ``max_elementwise_feed_repeat`` from ``cfg`` unless ``lname`` is a conv-and-add layer."""
        if (
            "max_elementwise_feed_repeat" in cfg
            and cfg["max_elementwise_feed_repeat"] is not None
            and not isinstance(self._model.layers[lname], BaseHailoConvAdd)
        ):
            # can't apply max_elementwise_feed_repeat when the layer is not conv and add
            cfg.pop("max_elementwise_feed_repeat", None)
        return cfg
