#!/usr/bin/env python
import itertools

import numpy as np

from hailo_model_optimization.acceleras.hailo_layers.hailo_const import HailoConst
from hailo_model_optimization.acceleras.hailo_layers.hailo_depthwise import HailoDepthwise
from hailo_model_optimization.acceleras.hailo_layers.hailo_element_wise_add import HailoElementwiseAdd
from hailo_model_optimization.acceleras.hailo_layers.hailo_io import HailoInputLayer
from hailo_model_optimization.acceleras.lossy_elements.quant_element import MACDataQuantElement
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerEqualizationConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ActivationType,
    EqualizationMode,
    EqualizationPolicy,
    FeaturePolicy,
    LayerEquivType,
    LayerType,
    QuantizationAlgorithms,
    ThreeWayPolicy,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import EqualizationError
from hailo_model_optimization.acceleras.utils.stats_export import get_equalization_stats
from hailo_model_optimization.algorithms.equalization.factors_calculator import FactorsCalculator
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm

# Shorthand for the equalization member of QuantizationAlgorithms.
# NOTE(review): not referenced in the code visible here - presumably imported by other modules; confirm.
EQUALIZATION_EQUIV = QuantizationAlgorithms.equalization


class Equalization(OptimizationAlgorithm):
    """
    Optimization algorithm that equalizes the network params, one equivalence set at a time.

    For every equivalence set that is not skipped, factors are obtained from ``FactorsCalculator`` and applied by
    re-scaling the output scales of the set's source layers; the model is then asked to enforce its encoding
    constraints along the edges of the set (see ``_apply_factors``).

    Most of this code was copied from the legacy implementation - ``_apply_factors`` is the main new function.
    """

    # NOTE(review): neither constant is referenced in the code visible here - presumably used elsewhere; confirm.
    EPSILON = 1e-30
    DEFAULT_MAX_ACTIVATION_SCALE = 16

    def __init__(self, model, model_config, logger_level, **kwargs):
        """
        Create the algorithm object and reset all of its bookkeeping.

        Args:
            model: the model to equalize; its ``flow`` graph is kept for predecessor look-ups.
            model_config: the model optimization configuration, forwarded to the base class.
            logger_level: logging level, forwarded to the base class.
            **kwargs: extra keyword arguments, forwarded unchanged to the base class.

        """
        super().__init__(model, model_config, name="Equalization", logger_level=logger_level, **kwargs)

        # handles on the graph and on the algorithm configuration
        self._hailo_model_flow = model.flow
        self._configuration = self.get_algo_config()

        # activations that do not prevent equalization
        self._activation_for_equalization = {ActivationType.LINEAR, ActivationType.RELU, ActivationType.LEAKY}

        # results, filled while equivalence sets are processed
        self._calc_noise = {}
        self._factors = {}
        self._was_equalized = False

        # equivalence sets are collected later, by _init_equiv_sets
        self._equiv_sets = None
        self._skipped_equiv_sets = None

        # per-layer flag, switched to True once the layer was re-scaled as a source of an equivalence set
        self._layers_info = {f"{layer}/source": False for layer in self._model.flow.toposort()}

    @property
    def was_equalized(self):
        """Whether ``_equalize_equiv_set`` has run for at least one equivalence set (``False`` until then)."""
        return self._was_equalized

    def get_config_by_layer(self):
        """
        Map every layer of the model to its equalization configuration.

        Returns:
            dict keyed by the model's layer names; each value is the layer's entry in the algorithm
            configuration, or ``LayerEqualizationConfig.get_default()`` when the layer has no entry.

        """
        configured_layers = self.get_algo_config().layers
        return {
            name: configured_layers.get(self.get_layer_name_in_config(layer), LayerEqualizationConfig.get_default())
            for name, layer in self._model.layers.items()
        }

    def _is_valid_equalization_policy(self, layer_name):
        """
        Tell whether the equalization policy permits equalizing the given layer.

        The layer's own policy is used, unless it is ``allowed`` - in that case the global policy of the
        algorithm configuration decides.

        Args:
            layer_name: name of the layer, as keyed in ``self._model.layers``.

        Returns:
            bool: True only when the effective policy is ``enabled``; False otherwise (including the case where
            the global policy itself is ``allowed``).

        """
        # Resolve only this layer's configuration instead of rebuilding the mapping of all layers on every call
        # (this method is invoked once per source layer of every equivalence set).
        layer = self._model.layers[layer_name]
        layer_cfg = self.get_algo_config().layers.get(
            self.get_layer_name_in_config(layer),
            LayerEqualizationConfig.get_default(),
        )

        effective_policy = layer_cfg.policy
        if effective_policy == EqualizationPolicy.allowed:
            effective_policy = self._configuration.policy

        # Explicit boolean for every policy value (no implicit None fall-through).
        return effective_policy == EqualizationPolicy.enabled

    def equalize_model(self):
        """
        iteratively go over all equiv_sets in self._equiv_sets and equalize the layers.
        Note: this function must come after  should_skip_algo function so that self._equiv_sets is not None

        """
        if self._equiv_sets is None:
            raise EqualizationError(
                "tried to equalize model before initializing the self._equiv_sets "
                "(may need to run should_skip_algo first ",
            )

        for equiv_set in self._equiv_sets:
            self._equalize_equiv_set(equiv_set)

        if len(self._skipped_equiv_sets) > 0:
            self._summarize_skip_info(self._skipped_equiv_sets)

    def calc_equalization_noise_diff(self):
        """
        Log the figures stored in ``self._calc_noise``: one debug line per equivalence set, followed by one
        summary line holding the three totals.

        Each stored entry is the 3-tuple written by ``_equalize_equiv_set``
        (``sol_snr_value``, ``initial_snr_value``, ``no_algo_snr_value``), reported here as the figures of the
        configured mode, of min-based equalization and of no equalization respectively.
        Nothing is logged when the configured mode is ``min_based`` or when no figures were stored.
        """
        # TODO this function is copied from legacy but still not -used
        nothing_to_report = self._configuration.mode == EqualizationMode.min_based or len(self._calc_noise) == 0
        if nothing_to_report:
            return

        total_algo = total_min_based = total_no_eq = 0
        for equiv_set, (algo_noise, min_based_noise, no_eq_noise) in self._calc_noise.items():
            total_algo += algo_noise
            total_min_based += min_based_noise
            total_no_eq += no_eq_noise
            self._logger.debug(
                f"noise for equiv_set. {equiv_set.source.layer.full_name} is {algo_noise:.3f}, "
                f"{min_based_noise:.3f}, {no_eq_noise:.3f}",
            )

        self._logger.important(
            f"all noise for {self._configuration.mode.value} "
            f"{total_algo:.3f}, all noise for min based  "
            f"{total_min_based:.3f},all noise for no equalization {total_no_eq:.3f}",
        )

    def _equalize_equiv_set(self, equiv_set):
        """
        Equalize a single equivalence set.

        The statistics of the set's producers and consumers are collected, ``FactorsCalculator`` turns them into
        a solution, and the solution's factors are applied with ``_apply_factors`` (which re-scales the source
        layers and then lets the model enforce its encoding constraints over the set).
        The factors are recorded in ``self._factors`` (keyed by the set's info string) and the solution's three
        SNR values in ``self._calc_noise`` (keyed by the set itself).

        Args:
            equiv_set: the equivalence set to equalize.

        """
        debug = self._logger.debug
        debug("equalization of equiv set: ")
        self._was_equalized = True
        debug(equiv_set.equiv_set_info())

        stats, max_concat = self._data_collection(equiv_set)
        layer_configs = self.get_config_by_layer()
        solution = FactorsCalculator(
            equiv_set,
            stats,
            max_concat,
            self._configuration,
            layer_configs,
        ).get_factors()

        # keep the outcome for later reporting, then apply it
        self._factors[equiv_set.equiv_set_info()] = solution.factor_solution
        self._calc_noise[equiv_set] = (solution.sol_snr_value, solution.initial_snr_value, solution.no_algo_snr_value)
        debug(f"equalization is done for:\n{equiv_set.equiv_set_info()}")
        self._apply_factors(equiv_set, solution.factor_solution)

    def enforce_equiv_set_encoding(self, equiv_set):
        edges_list = list(self._iterate_edges(equiv_set))  # get edges based on subflow based on equiv_set
        self._model.enforce_constraints(edges_list)

    def _iterate_edges(self, equiv_set):
        equiv_set_flow = equiv_set.equiv_set_flow
        for lname in equiv_set_flow.toposort():
            for successor_name in equiv_set_flow.successors(lname):
                yield lname, successor_name

    def _data_collection(self, equiv_set):
        """
        Gather everything the factors calculation needs, so that it does not depend on the graph.

        Args:
            equiv_set: the equivalence set whose producers and consumers are inspected.

        Returns:
            A 2-tuple ``(all_data, max_concat)``:
              - all_data: dict with the keys ``LayerEquivType.consumer`` and ``LayerEquivType.producer``; each
                value maps a layer name to the result of ``get_equalization_stats`` for that layer. When a layer
                name appears more than once under the same type, only its first occurrence is used.
              - max_concat: the value returned by ``calculate_max_concat_from_equiv_set``.

        """
        max_concat = self.calculate_max_concat_from_equiv_set(equiv_set)
        collected = {LayerEquivType.consumer: {}, LayerEquivType.producer: {}}

        for member in equiv_set.producers + equiv_set.consumers:
            role = member.type_of_layer
            stats_of_role = collected[role]
            member_name = member.layer_name
            if member_name in stats_of_role:
                continue  # already collected for this role
            stats_of_role[member_name] = get_equalization_stats(
                member.layer,
                role == LayerEquivType.consumer,
                member.input_index,
            )

        return collected, max_concat

    def _apply_factors(self, equiv_set, output_factors):
        """
        Applying the decisions on the ‘equalization degree of freedom’ - without changing weights.
        Modifies interlayer tensor encodings so that the eventual numeric weights and activation will fit better to
        their grid and lose less info - which is the whole point of the algorithm -
        in consistent and hw-constraints-observant way.
        Currently, we don't change to ZP, leaving scalar for easy support in current (May2022) kernels;
        the factor-optimizing code assumed that.
        In the future, can consider “relaxing” per-channel ZP for further range optimization for cases s.a.
        10<ch1<20, 16<ch2<20

        Note that:
        multiplying the native_kernel by output_factors is equivalent to multiplying the kernel_scale by
        (1/output_factors)  and this is equivalent to multiplying the output_scale by (1/output_factors).
        hence we will denote:
            re_scale_output_factors = 1 / output_factors
        we get:
            output_scale*re_scale_output_factors ==>>> kernel_scale*(1/re_scale_output_factors) ====
            kernel_native*(1/(1/re_scale_output_factors)  == kernel_native*output_factors as needed.

        Args:
            equiv_set: the equiv set
            output_factors: the output_factor

        """
        re_scale_output_factors = 1 / output_factors

        for equiv_layer in equiv_set.sources:
            self._layers_info[f"{equiv_layer.layer.full_name}/source"] = True
            layer = self._model.layers[equiv_layer.layer.full_name]
            type_of_layer = equiv_layer.type_of_layer
            indices = np.array(equiv_layer.layer_indices)
            padded_factors = np.ones_like(layer.output_scale)
            padded_factors[indices] = re_scale_output_factors[equiv_layer.set_indices]

            # update output scales
            layer.set_output_scale(layer.output_scale * padded_factors, 0)
            layer.update_io_ratio()
            layer.update_eq_vec_out(padded_factors)

            self._logger.debug(f"equalizing {type_of_layer} - {equiv_layer.layer.full_name}")

        # after updating the source, we will update all the rest
        self.enforce_equiv_set_encoding(equiv_set)

    def _apply_factors_explicit(self, equiv_set, output_factors):
        """
        is saved only for debugiging
        """
        re_scale_output_factors = 1 / output_factors

        for equiv_layer in equiv_set.sources:
            layer = self._model.layers[equiv_layer.layer.full_name]
            type_of_layer = equiv_layer.type_of_layer
            indices = np.array(equiv_layer.layer_indices)
            padded_factors = np.ones_like(layer.output_scale)
            padded_factors[indices] = re_scale_output_factors[equiv_layer.set_indices]

            # update output scales
            layer.set_output_scale(layer.output_scale * padded_factors, 0)
            layer.update_io_ratio()

            self._logger.debug(f"equalizing {type_of_layer} - {equiv_layer.layer.full_name}")

        for equiv_layer in itertools.chain(equiv_set.consumers, equiv_set.transparents, equiv_set.activations):
            # set all input_scales in equivset
            layer = self._model.layers[equiv_layer.layer.full_name]
            indices = np.array(equiv_layer.layer_indices)
            index = equiv_layer.input_index

            padded_factors = np.ones_like(layer.input_scales[index])
            padded_factors[indices] = re_scale_output_factors[equiv_layer.set_indices]

            # update input scales
            layer.set_input_scale(layer.input_scales[index] * padded_factors, index)
            layer.enforce_io_encoding()

        for equiv_layer in equiv_set.cc_aggregators:
            # update the input scales and output scales for both transparents and activation activations
            layer = self._model.layers[equiv_layer.layer.full_name]
            # TODO: SDK-35765 we need to find the correct input index to fix
            index = equiv_layer.input_index
            start_channel = sum(shape[-1] for shape in layer.input_shapes[:index])
            indices = np.array(np.array(equiv_layer.layer_indices) - start_channel)

            padded_factors = np.ones_like(layer.input_scales[index])
            padded_factors[indices] = re_scale_output_factors[equiv_layer.set_indices]
            layer.set_input_scale(layer.input_scales[index] * padded_factors, index)

            # update output scales
            layer.enforce_io_encoding()

        for equiv_layer in equiv_set.ew_bouncers:
            # update the input scales[1] and output scales
            layer = self._model.layers[equiv_layer.layer.full_name]
            indices = np.array(equiv_layer.layer_indices)
            padded_factors = np.ones_like(layer.input_scales[1])
            padded_factors[indices] = re_scale_output_factors[equiv_layer.set_indices]

            # update input scales
            layer.set_input_scale(layer.input_scales[1] * padded_factors, 1)

            # update output scales
            layer.set_output_scale(layer.output_scale * padded_factors, 0)
            layer.enforce_io_encoding()
            layer.update_io_ratio()

    def _summarize_skip_info(self, skipped_equiv_sets):
        all_unsupported_activations = set()
        all_unsupported_layers = set()

        sources_with_unsupported_activations = set()
        sources_with_unsupported_layers = set()
        sources_with_invalid_policy = set()

        for equiv_set in skipped_equiv_sets:
            sources_names = {layer_name for layer_name in equiv_set.source_layers}
            unsupported_activations = self._get_unsupported_activations(equiv_set)
            invalid_policy = not all(
                self._is_valid_equalization_policy(layer_name) for layer_name in equiv_set.source_layers
            )
            unsupported_layers = {
                (equiv_layer.layer.full_name, equiv_layer.layer.full_name) for equiv_layer in equiv_set.unsupported
            }
            if invalid_policy:
                sources_with_invalid_policy.update(sources_names)
            elif unsupported_activations:
                all_unsupported_activations |= unsupported_activations
                sources_with_unsupported_activations.update(sources_names)
            elif unsupported_layers:
                all_unsupported_layers |= unsupported_layers
                sources_with_unsupported_layers.update(sources_names)

        if sources_with_unsupported_layers:
            self._logger.debug(
                f"Skipped Equalization of layers {sources_with_unsupported_layers} "
                f"because of unsupported layers: {all_unsupported_layers}.",
            )
        if sources_with_unsupported_activations:
            self._logger.debug(
                f"Skipped Equalization of layers {sources_with_unsupported_activations} "
                f"because of unsupported activations: {all_unsupported_activations}.",
            )
        if sources_with_invalid_policy and self._configuration.policy == EqualizationPolicy.enabled:
            self._logger.debug(
                f"Skipped Equalization of layers {sources_with_invalid_policy} "
                f"because their Equalization policy has been set to disabled in the model script.",
            )

    @property
    def _should_equalize_inputs(self):
        """
        Whether equivalence sets whose source is a model input may be equalized.

        Two configurations affect the answer:
        1. an explicit ``equalize_inputs`` policy (enabled / disabled) - it is followed as is.
        2. an ``equalize_inputs`` policy of ``allowed`` - the global ``input_encoding_vector`` setting decides.
        """
        inputs_policy = self._configuration.equalize_inputs
        if inputs_policy == ThreeWayPolicy.allowed:
            # no explicit request - follow the input_encoding_vector config
            return self._model_config.globals.input_encoding_vector == FeaturePolicy.enabled
        return inputs_policy == ThreeWayPolicy.enabled

    @property
    def _should_equalize_outputs(self):
        """
        Whether equivalence sets that reach model outputs may be equalized.

        Two configurations affect the answer:
        1. an explicit ``equalize_outputs`` policy (enabled / disabled) - it is followed as is.
        2. an ``equalize_outputs`` policy of ``allowed`` - the global ``output_encoding_vector`` setting decides.
        """
        outputs_policy = self._configuration.equalize_outputs
        if outputs_policy == ThreeWayPolicy.allowed:
            # no explicit request - follow the output_encoding_vector config
            return self._model_config.globals.output_encoding_vector == FeaturePolicy.enabled
        return outputs_policy == ThreeWayPolicy.enabled

    def _should_skip_equalization(self, equiv_set):
        """
        Decide whether an equivalence set must be left un-equalized, logging every reason at debug level.

        The set is skipped when at least one of the following holds:
          - a source layer has an equalization policy that is not (effectively) enabled
          - a source, elementwise bouncer or standalone activation has an unsupported activation
          - the set contains unsupported layers
          - a depthwise source shares set channels with a RELU6 / RELU1 activation
          - the mode is ``kernel_equalization`` and the kernel-only checks ask to skip
          - a source is a model input while inputs equalization is off, or a source is a const layer
          - the set has outputs while outputs equalization is off

        Args:
            equiv_set: the equivalence set to examine; ``equiv_set.source_layers`` holds layer names.

        Returns:
            bool: True when the set should be skipped.

        """
        unsupported_activations = self._get_unsupported_activations(equiv_set)
        # Keep the offending names (not just a yes/no answer) so that the report below can list them.
        invalid_policy_layers = [
            layer_name
            for layer_name in equiv_set.source_layers
            if not self._is_valid_equalization_policy(layer_name)
        ]

        source_input_layers = [
            layer_name
            for layer_name in equiv_set.source_layers
            if isinstance(self._model.layers[layer_name], HailoInputLayer)
        ]
        source_input_const = [
            layer_name
            for layer_name in equiv_set.source_layers
            if isinstance(self._model.layers[layer_name], HailoConst)
        ]
        skip_inputs = len(source_input_layers) > 0 and not self._should_equalize_inputs
        skip_const_inputs = len(source_input_const) > 0

        is_dw_with_relu6 = self._is_dw_with_relu6(equiv_set)

        # The kernel-only checks are evaluated only in kernel_equalization mode; otherwise they are all off.
        has_unsupported_layers = pure_8b_equiv_sets = multi_sources = skip_only_kernels_equalization = False
        if self._configuration.mode == EqualizationMode.kernel_equalization:
            (
                has_unsupported_layers,
                pure_8b_equiv_sets,
                multi_sources,
                skip_only_kernels_equalization,
            ) = self._should_skip_only_kernel_equalization(equiv_set)

        skip_outputs = bool(equiv_set.outputs) and not self._should_equalize_outputs
        # bool(): always hand a real boolean to the caller (the bare or-chain could return a list)
        should_skip = bool(
            invalid_policy_layers
            or unsupported_activations
            or equiv_set.unsupported
            or is_dw_with_relu6
            or skip_only_kernels_equalization
            or skip_inputs
            or skip_const_inputs
            or skip_outputs
        )
        if not should_skip:
            return should_skip

        self._logger.debug(f"Skipping Equalization of equiv set with sources: {equiv_set.equiv_set_info()}.")
        if skip_inputs or skip_const_inputs:
            # report const sources as well - they too cause the skip
            self._logger.debug(f"Unsupported input layer as source {source_input_layers + source_input_const}")

        if bool(equiv_set.unsupported):
            unsupported_layers = [
                (equiv_layer.layer.full_name, equiv_layer.layer.full_name) for equiv_layer in equiv_set.unsupported
            ]
            self._logger.debug(f"Unsupported layers {unsupported_layers}")
        if len(unsupported_activations) > 0:
            self._logger.debug(f"Unsupported activations {unsupported_activations}")
        if invalid_policy_layers:
            # source_layers holds names, so the layer object and its configuration are looked up by name;
            # only the layers whose policy is invalid are reported.
            config_by_layer = self.get_config_by_layer()
            message = []
            for layer_name in invalid_policy_layers:
                policy = config_by_layer[layer_name].policy
                policy_name = getattr(policy, "value", policy)
                layer = self._model.layers[layer_name]
                # purely informative - do not let a missing attribute break the skip decision
                precision_config = getattr(layer, "precision_config", None)
                quantization_groups = getattr(precision_config, "quantization_groups", None)
                message.append(
                    f"Layer {layer_name}, Policy {policy_name}, Quantization Groups {quantization_groups}",
                )
            self._logger.debug("Invalid policy layers: \n{}".format("\n".join(message)))
        if is_dw_with_relu6:
            self._logger.debug("Equalization of dw with relu6 is not supported")

        if skip_only_kernels_equalization:
            if has_unsupported_layers:
                # list the layers that are NOT supported (same predicate as in
                # _should_skip_only_kernel_equalization)
                unsupported_layers = []
                for equiv_layer in equiv_set.producers + equiv_set.consumers:
                    if self._is_supported_layer_of_only_kernels(equiv_layer.layer):
                        continue
                    layer_op = equiv_layer.layer.op
                    unsupported_layers.append(
                        (equiv_layer.layer.full_name, getattr(layer_op, "full_name", layer_op)),
                    )
                self._logger.debug(f"Unsupported layers of k-equalization {unsupported_layers}")

            if pure_8b_equiv_sets:
                self._logger.debug("Equalization with no 4-bit layers is been configured to be skipped")

            if multi_sources:
                self._logger.debug("Equalization is configured to skip non-1:1")
        return should_skip

    def _should_skip_only_kernel_equalization(self, equiv_set):
        """
        Evaluate the skip conditions that are specific to the kernel-equalization mode.

        Args:
            equiv_set: the equivalence set to examine.

        Returns:
            A 4-tuple ``(has_unsupported_layers, pure_8b_equiv_sets, multi_sources, skip)``:
              - has_unsupported_layers: a producer / consumer fails ``_is_supported_layer_of_only_kernels``
              - pure_8b_equiv_sets: ``skip_8b_to_8b`` is configured and no producer / consumer kernel lossy
                element equals ``MACDataQuantElement(4)``
              - multi_sources: ``skip_multi_source`` is configured and the set has more than one consumer or
                at least one elementwise bouncer
              - skip: the three conditions above, or-ed together

        """
        # this function is copied from legacy but yet not used
        candidates = [member.layer for member in (equiv_set.producers + equiv_set.consumers)]
        configuration = self._configuration

        support_flags = [self._is_supported_layer_of_only_kernels(layer) for layer in candidates]
        has_unsupported_layers = not all(support_flags)

        pure_8b_equiv_sets = configuration.skip_8b_to_8b
        if pure_8b_equiv_sets:
            kernel_is_4bit = [
                (layer.conv_op.conv_op.weight_lossy_elements.kernel == MACDataQuantElement(4)) for layer in candidates
            ]
            pure_8b_equiv_sets = not any(kernel_is_4bit)

        multi_sources = configuration.skip_multi_source
        if multi_sources:
            multi_sources = len(equiv_set.consumers) > 1 or len(equiv_set.ew_bouncers) > 0

        skip_only_kernels_equalization = has_unsupported_layers or pure_8b_equiv_sets or multi_sources
        return has_unsupported_layers, pure_8b_equiv_sets, multi_sources, skip_only_kernels_equalization

    def _is_supported_layer_of_only_kernels(self, layer):
        """
        Tell whether a layer can take part in kernel-only equalization.

        Args:
            layer: the layer to check; its ``op`` is compared against ``LayerType``.

        Returns:
            bool: True for CONV layers, and for BATCH_NORM layers unless ``skip_sbn_layers`` is configured.

        """
        # this function is copied from legacy but yet not used
        is_conv = layer.op == LayerType.CONV
        if is_conv:
            return True
        return bool(layer.op == LayerType.BATCH_NORM and not self._configuration.skip_sbn_layers)

    @staticmethod
    def _is_dw_with_relu6(equiv_set):
        """
        Tell whether a depthwise source shares set channels with a RELU6 / RELU1 activation.

        Args:
            equiv_set: the equivalence set to examine.

        Returns:
            bool: True when some producer or standalone activation whose activation is RELU6 or RELU1 has at
            least one set index in common with the ``HailoDepthwise`` sources of the set.

        """
        # set indices covered by depthwise sources
        dw_set_indices = set()
        for source in equiv_set.sources:
            if isinstance(source.layer, HailoDepthwise):
                dw_set_indices.update(source.set_indices)

        clipped_activations = (ActivationType.RELU6, ActivationType.RELU1)
        clipped_members = [
            member
            for member in itertools.chain(equiv_set.producers, equiv_set.activations)
            if member.layer.get_activation_name() in clipped_activations
        ]
        return any(len(set(member.set_indices) & dw_set_indices) > 0 for member in clipped_members)

    def _get_unsupported_activations(self, equiv_set):
        unsupported_source_activations = self._get_unsupported(equiv_set.sources)
        unsupported_ewb_activations = self._get_unsupported(equiv_set.ew_bouncers)
        unsupported_standalone_activation = self._get_unsupported(equiv_set.activations)

        return unsupported_source_activations | unsupported_ewb_activations | unsupported_standalone_activation

    def _get_unsupported(self, list_of_layers):
        return {
            (layer_equiv.layer.full_name, layer_equiv.layer.get_activation_name())
            for layer_equiv in list_of_layers
            if (
                (layer_equiv.layer.get_activation_name() not in self._activation_for_equalization)
                and layer_equiv.layer.has_activation
            )
        }

    def _get_dense_predecessor(self, dense):
        preds = [i for i in self._hailo_model_flow.predecessors(dense)]
        if len(preds) != 1:
            raise EqualizationError(f"Dense layer '{dense.original_names[0]}' had more than 1 predecessor")
        return preds[0]

    @staticmethod
    def calculate_max_concat_from_equiv_set(equiv_set):
        # TODO this function is copied from legacy but not used - will be removed in the future
        """
        replace simple "layer-wise max" (the one we want to avoid increasing) by maxing over all "concat companions",
        because during "scale matching" (that happens after Equalization) ranges will anyway all be synced to
        the largest one.
        """
        # Note that we take only the max value of layer. The only case that is not good enough is
        # where the min value range absolute value is much bigger than the max value range (this case is very rare)
        if len(equiv_set.concat_layers) != 0:
            EqualizationError("concat layers still not implemented in equalization")

    def get_algo_config(self):
        """Return the ``equalization`` section of the model configuration."""
        return self._model_config.equalization

    def _setup(self):
        """
        Prepare the run: apply the force-transparent requests, collect the equivalence sets (splitting them
        into the ones to equalize and the skipped ones) and call ``info_config()`` on the configuration.
        """
        # transparency must be set before the equivalence sets are built
        # NOTE(review): presumably because it affects how the sets are formed - confirm.
        self._force_transparent()
        self._init_equiv_sets()
        self._configuration.info_config()

    def _get_valid_layer_cfg(self, lname, cfg):
        """Return ``cfg`` unchanged (``lname`` is not used)."""
        return cfg

    def _run_int(self):
        """Run the algorithm: equalize every equivalence set collected in ``self._equiv_sets``."""
        self.equalize_model()
        # we must end equalization with scale matching again (it is not performed in this method).

    def _force_transparent(self):
        """
        Mark as transparent every layer whose configuration enables ``force_transparent``.

        Only elementwise-add layers are supported; for any other layer type a warning is logged and the request
        is ignored.
        """
        for lname, layer_config in self._configuration.layers.items():
            if layer_config.force_transparent != ThreeWayPolicy.enabled:
                continue

            layer = self._model.layers[lname]
            if isinstance(layer, HailoElementwiseAdd):
                layer.transparent = True
                continue

            self._logger.warning(
                f"Currently force transparent is supported only for elementwise add layer. ignoring layer {lname}."
            )

    def _init_equiv_sets(self):
        """
        Collect the equivalence sets of the model and split them in two lists.

        The sets come, in order, from ``self._model.iter_equiv_sets``. A set for which
        ``_should_skip_equalization`` answers True goes to ``self._skipped_equiv_sets``; every other set goes to
        ``self._equiv_sets``. Both lists are reset before the iteration starts.
        """
        equiv_iterator = self._model.iter_equiv_sets(QuantizationAlgorithms.equalization)
        self._skipped_equiv_sets, self._equiv_sets = [], []
        for equiv_set in equiv_iterator:
            # TODO: Ideally, even when stooped or skip conditions are met, we might be able to equalize the equiv class
            #       partially. Check which channels are affected by the condition, and equalize the rest.
            # a skipped set is only recorded - the iteration goes on with the following sets
            destination = self._skipped_equiv_sets if self._should_skip_equalization(equiv_set) else self._equiv_sets
            destination.append(equiv_set)

    def should_skip_algo(self):
        """Always return False; skipping is decided per equivalence set in ``_should_skip_equalization``."""
        return False

    def finalize_global_cfg(self, config):
        """No-op: this algorithm makes no change to the given global configuration."""
        # Nothing to do here
        pass

    def export_statistics(self):
        """
        Return the per-layer equalization flags.

        Returns:
            dict keyed by ``"<layer>/source"``; a value is True once ``_apply_factors`` re-scaled that layer as a
            source of an equivalence set, and False otherwise.

        """
        return self._layers_info
