#!/usr/bin/env python

import copy
import itertools

import networkx as nx
import numpy as np
from past.utils import old_div

from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerEqualizationConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    EqualizationMode,
    EqualizationPolicy,
    LayerEquivType,
    PrecisionMode,
    QuantizationAlgorithms,
)
from hailo_sdk_client.numeric_translator.factors_calculator import FactorsCalculator
from hailo_sdk_client.sdk_backend.sdk_backend_exceptions import BackendQuantizationException
from hailo_sdk_common.hailo_nn.hn_definitions import ActivationType, LayerType
from hailo_sdk_common.logger.logger import default_logger

# Module-level alias of the equalization member of QuantizationAlgorithms.
EQUALIZATION_EQUIV = QuantizationAlgorithms.equalization


class ParamsEqualizer:
    """
    This class is used to equalize network params.

    The entry point is ``equalize_model``: it works on a deep copy of the given params and returns
    that equalized copy.
    """

    # NOTE(review): none of the three constants below is referenced by the code of this class that was
    # reviewed; presumably they are consumed elsewhere - confirm before changing or removing them.
    EPSILON = 1e-30
    RELU6_MINIMUM_SCALE = 0.7  # This number is chosen randomly and works
    DEFAULT_MAX_ACTIVATION_SCALE = 16

    def __init__(self, hailo_nn, configuration):
        """
        Args:
            hailo_nn: the network graph whose layers are going to be equalized.
            configuration: the equalization configuration (global mode / policy and per-layer entries).
        """
        # Inputs
        self._hailo_nn = hailo_nn
        self._configuration = configuration

        # State that is populated by equalize_model
        self._params = None
        self._equalized_params = None
        self._conv_layer_inference = None
        self._was_equalized = False

        # Bookkeeping that is filled while equiv sets are processed
        self._calc_noise = {}
        self._layers_factors = {}
        self._factors = {}

        # An activation outside of this set makes the equiv set that contains it get skipped
        self._activation_for_equalization = {
            ActivationType.linear,
            ActivationType.relu,
            ActivationType.leaky,
            ActivationType.relu6,
        }

        # Names of the layers whose precision mode reduces to a8_w4, in stable topological order
        self._layers_with_4bit_weights = []
        for layer in hailo_nn.stable_toposort():
            if layer.precision_config.precision_mode.reduce() == PrecisionMode.a8_w4:
                self._layers_with_4bit_weights.append(layer.name)

    @property
    def was_equalized(self):
        """bool: False on construction, True once at least one equiv set went through _equalize_equiv_set."""
        return self._was_equalized

    def _is_valid_equalization_policy(self, layer):
        """
        Resolve the effective Equalization policy of a layer.

        A layer without an entry in the configuration falls back to the default layer configuration.
        A layer policy of ``allowed`` defers to the global policy of the configuration.

        Args:
            layer: the layer whose policy is resolved (looked up by ``layer.name``).

        Returns:
            bool: True only when the effective policy is ``enabled``.
        """
        layer_config = self._configuration.layers.get(layer.name, LayerEqualizationConfig.get_default())
        effective_policy = layer_config.policy
        if effective_policy == EqualizationPolicy.allowed:
            effective_policy = self._configuration.policy
        # Always answer with a bool: a policy that is neither enabled nor disabled (e.g. ``allowed`` on both
        # levels) used to fall off the end of the function and yield None.
        return effective_policy == EqualizationPolicy.enabled

    def equalize_model(self, params, conv_layer_inference):
        """
        Equalize the params of the whole model, one graph component at a time.

        Args:
            params: the original model params. They are deep copied; the copy is the one being modified.
            conv_layer_inference: per-layer statistics, indexed by layer name. The output / input statistics
                of the equalized layers are updated in place while factors are applied.

        Returns:
            The equalized copy of ``params``.
        """
        self._configuration.info_config()

        self._equalized_params = copy.deepcopy(params)
        self._params = params
        self._conv_layer_inference = conv_layer_inference
        # Grouped convolutions are handled as dense kernels during equalization and are split back at the end
        self._join_group_convolutions(self._equalized_params)

        # The component count is only needed for the progress message; compute it once instead of on every
        # iteration (the f-string below is evaluated regardless of the log level).
        num_components = nx.number_weakly_connected_components(self._hailo_nn)
        for comp_number, component in enumerate(self._hailo_nn.components, start=1):
            # TODO: match end layers to component. Do we care? probably not.
            comp_start_layers = self._hailo_nn.get_start_layers_of_component(component)
            default_logger().debug(
                f"Starting Equalization of component {comp_number} out of {num_components}",
            )

            self._equalize_component(comp_start_layers)

        self._resplit_group_convolutions(self._equalized_params)

        return self._equalized_params

    def _equalize_component(self, start_layers):
        """
        Equalize every equivalence set that is reachable from the start layers of a single component.

        The equiv sets are visited in the order given by iter_equiv_sets. A set that has to be skipped
        does not stop the iteration - the following sets of the component are still equalized. Skipped sets
        without outputs are reported together once the iteration is over, and the noise summary is
        logged at the end.

        Equalization is done per component so that a problematic component cannot prevent the equalization
        of the other ones.
        """
        # TODO: Ideally, even when a skip condition is met, we might be able to equalize the equiv class
        #       partially. Check which channels are affected by the condition, and equalize the rest.
        skipped_without_outputs = []
        for equiv_set in self._hailo_nn.iter_equiv_sets(QuantizationAlgorithms.equalization, start_layers):
            if not self._should_skip_equalization(equiv_set):
                self._equalize_equiv_set(equiv_set)
            elif len(equiv_set.outputs) == 0:
                # Only the skipped sets that have no outputs take part in the summary
                skipped_without_outputs.append(equiv_set)

        if skipped_without_outputs:
            self._summarize_skip_info(skipped_without_outputs)

        self.calc_equalization_noise_diff()

    def calc_equalization_noise_diff(self):
        """
        Log the noise that was recorded per equiv set, and the totals over all the recorded equiv sets.

        Three values are reported each time: the noise of the configured mode, of the min-based solution and
        of no equalization at all. Nothing is logged in min-based mode or when no noise was recorded.
        """
        if self._configuration.mode == EqualizationMode.min_based or len(self._calc_noise) == 0:
            return

        total_current = 0
        total_min_based = 0
        total_unequalized = 0
        for equiv_set, (set_current, set_min_based, set_unequalized) in self._calc_noise.items():
            total_current += set_current
            total_min_based += set_min_based
            total_unequalized += set_unequalized
            default_logger().debug(
                f"noise for equiv_set. {equiv_set.source.layer.name} is {set_current:.3f}, "
                f"{set_min_based:.3f}, {set_unequalized:.3f}",
            )

        default_logger().important(
            f"all noise for {self._configuration.mode.value} "
            f"{total_current:.3f}, all noise for min based  "
            f"{total_min_based:.3f},all noise for no equalization {total_unequalized:.3f}",
        )

    def _equalize_equiv_set(self, equiv_set):
        """
        Equalize one equivalence class.

        The flow is:
         - collect the kernels and the statistics of the producers and the consumers of the set
         - let FactorsCalculator find the factors (and the noise of the compared solutions)
         - remember the factors and the noise values of the set
         - apply the factors to the params of all the producers and the consumers of the set

        The factors are applied to the producers and to the consumers together, so the equalized params
        are meant to stay consistent after every call.
        """
        default_logger().debug("equalization of equiv set: ")
        self._was_equalized = True
        default_logger().debug(equiv_set.equiv_set_info())

        collected_stats, concat_max = self._data_collection(equiv_set)
        calculator = FactorsCalculator(equiv_set, collected_stats, concat_max, self._configuration)
        factors, algo_noise, min_based_noise, unequalized_noise = calculator.get_factors()

        # Keep the results: the factors by the description of the set, the noise by the set itself
        self._factors[equiv_set.equiv_set_info()] = factors
        self._calc_noise[equiv_set] = (algo_noise, min_based_noise, unequalized_noise)
        default_logger().debug(f"equalization is done for:\n{equiv_set.equiv_set_info()}")

        self._apply_factors(equiv_set, factors)

    def _data_collection(self, equiv_set):
        """
        Collect all the data that is needed for equalization, in order to be independent from the graph.

        Args:
            equiv_set: the equivalence set whose producers and consumers are collected.

        Returns:
            tuple: ``(all_data, max_concat)`` where
             - all_data maps LayerEquivType.producer / LayerEquivType.consumer to a dict whose keys are the
               equiv-layer names and whose values are the stats dict of that layer (kernels, activation
               ranges, energies, axes mask, non-zero percent and number of weight bits)
             - max_concat is the result of calculate_max_concat_from_equiv_set
        """
        max_concat = self.calculate_max_concat_from_equiv_set(equiv_set)
        all_data = {LayerEquivType.consumer: {}, LayerEquivType.producer: {}}
        for equiv_layer in equiv_set.producers + equiv_set.consumers:
            layer_type = equiv_layer.type_of_layer
            stats_per_layer = all_data[layer_type]
            if equiv_layer.layer_name in stats_per_layer:
                # Already collected for this role. Checking first avoids deep copying the kernels (and
                # reshaping them) only to throw the result away.
                continue

            layer = equiv_layer.layer
            layer_inference = self._conv_layer_inference[layer.name]
            # Copies, so that whoever consumes the stats can't modify the params
            kernel = copy.deepcopy(self._equalized_params[layer.name].kernel)
            kernel_before = copy.deepcopy(self._params[layer.name].kernel)
            axes_to_max = layer.get_axes_mask(layer_type)
            kernel, axes_to_max = self._reshape_kernel_dense_if_needed(kernel, axes_to_max, layer, layer_type)

            # NOTE(review): self._layers_with_4bit_weights holds ``layer.name`` values, while the lookup below
            # uses ``equiv_layer.layer_name`` - confirm that the two names are always identical.
            number_bits = 4 if equiv_layer.layer_name in self._layers_with_4bit_weights else 8
            stats_per_layer[equiv_layer.layer_name] = {
                "full_layer_name": layer.name,
                "kernel": kernel,
                "kernel_before": kernel_before,
                "pre_activation_min": layer_inference["stats_min_pre_act_features_value"],
                "post_activation_max": layer_inference["stats_max_output_features_value"],
                "post_activation_min": layer_inference["stats_min_output_features_value"],
                "input_energy": layer_inference["stats_input_energy_features_value"],
                "output_energy": layer_inference["stats_output_energy_features_value"],
                "axes_to_max": axes_to_max,
                "non_zero_percent": layer_inference["stats_non_zero_percent_features_value"],
                "number_bits": number_bits,
            }

        return all_data, max_concat

    def _apply_factors(self, equiv_set, output_factors):
        """
        Apply the equalization factors to every producer and consumer of the equiv set.

        Producers have their kernel and bias multiplied by ``output_factors``; consumers have their kernel
        multiplied by the inverse factors. The collected statistics of the layers are rescaled as well,
        because a layer can take part in another equiv set later on:
         - producer: output energy (by factors**2) and output max / min (by factors)
         - consumer: input energy (by factors**2 of the producer side)

        Args:
            equiv_set: the equivalence set being equalized.
            output_factors: the factors of the set, indexed by the set indices of the equiv layers.
        """
        input_factors = old_div(1, output_factors)
        for equiv_layer in equiv_set.producers + equiv_set.consumers:
            layer_name = equiv_layer.layer.name
            layer_params = self._equalized_params[layer_name]

            type_of_layer = equiv_layer.type_of_layer
            axes_mask = np.array(equiv_layer.layer.get_axes_mask(type_of_layer), dtype=np.int32)
            # The zero entry of the mask marks the axis that is rescaled. The match is indexed explicitly,
            # because calling int() on a 1-element array is deprecated by NumPy.
            axis = int(np.where(axes_mask == 0)[0][0])
            # -1 lets the reshape of the factors infer the length of the rescaled axis
            axes_mask[axes_mask == 0] = -1

            factors_to_use = input_factors if type_of_layer == LayerEquivType.consumer else output_factors
            padded_factors_kernel, padded_factors_bias = self._get_kernel_and_bias_factors(
                equiv_layer,
                factors_to_use,
                axes_mask,
                layer_params.kernel,
                axis,
            )

            self._equalized_params[layer_name + "/kernel:0"] = layer_params.kernel * padded_factors_kernel
            if type_of_layer == LayerEquivType.producer:
                self._equalized_params[layer_name + "/bias:0"] = layer_params.bias * padded_factors_bias
                layer_stats = self._conv_layer_inference[layer_name]
                # if it is a producer it might be in the future a consumer and then we will need the
                # updated output_energy. (this is needed for SQNR calculation) note: its input energy will
                # not change
                layer_stats["stats_output_energy_features_value"] *= padded_factors_bias**2

                # is a layer that is a producer can be a producer in any other class? say in a split?
                layer_stats["stats_max_output_features_value"] *= padded_factors_bias
                layer_stats["stats_min_output_features_value"] *= padded_factors_bias

            if type_of_layer == LayerEquivType.consumer:
                # change the input_energy of a consumer because it might be a producer in the future.
                # The input of a consumer is the output of its producers, hence it scales with the
                # output (producer side) factors and not with the inverse factors that scale its kernel.
                padded_factors_input = self._get_bias_factors(equiv_layer, output_factors, layer_params.kernel, axis)

                layer_stats = self._conv_layer_inference[layer_name]
                input_energy = layer_stats["stats_input_energy_features_value"]
                shape_of_input = input_energy.shape
                if len(shape_of_input) > 1:
                    # Multi-dimensional stats are scaled in their flat form and restored afterwards. This
                    # branch used to scale by the inverse factors, unlike the one-dimensional branch.
                    scaled_energy = input_energy.flatten() * (padded_factors_input**2)
                    layer_stats["stats_input_energy_features_value"] = scaled_energy.reshape(shape_of_input)
                else:
                    layer_stats["stats_input_energy_features_value"] *= padded_factors_input**2
            default_logger().debug(f"equalizing {type_of_layer} - {equiv_layer.layer.name}")

    def _summarize_skip_info(self, skipped_equiv_sets):
        """
        Log one debug summary per skip reason for the given skipped equiv sets.

        Every equiv set is attributed to a single reason, checked in this order: invalid policy,
        unsupported activations, unsupported layers. The policy summary is logged only when the global
        policy is ``enabled``.
        """
        policy_sources = set()
        activation_sources = set()
        layer_sources = set()
        activations_found = set()
        layers_found = set()

        for equiv_set in skipped_equiv_sets:
            source_names = {layer.name for layer in equiv_set.source_layers}
            bad_activations = self._get_unsupported_activations(equiv_set)
            every_policy_valid = all(
                self._is_valid_equalization_policy(layer) for layer in equiv_set.source_layers
            )
            bad_layers = {
                (equiv_layer.layer.name, equiv_layer.layer.op.name) for equiv_layer in equiv_set.unsupported
            }

            if not every_policy_valid:
                policy_sources |= source_names
            elif bad_activations:
                activations_found.update(bad_activations)
                activation_sources |= source_names
            elif bad_layers:
                layers_found.update(bad_layers)
                layer_sources |= source_names

        if layer_sources:
            default_logger().debug(
                f"Skipped Equalization of layers {layer_sources} "
                f"because of unsupported layers: {layers_found}.",
            )
        if activation_sources:
            default_logger().debug(
                f"Skipped Equalization of layers {activation_sources} "
                f"because of unsupported activations: {activations_found}.",
            )
        if policy_sources and self._configuration.policy == EqualizationPolicy.enabled:
            default_logger().debug(
                f"Skipped Equalization of layers {policy_sources} "
                f"because their Equalization policy has been set to disabled in the model script.",
            )

    def _layer_names_to_graph_layers(self, layer_names):
        """
        Translate layer names to the layers of the graph.

        Args:
            layer_names: None, a single name or an iterable of names.

        Returns:
            None when ``layer_names`` is None. Otherwise the set of graph layers of the names whose scope
            (the part before the first "/") is one of the net scopes of the model; other names are ignored.
        """
        if layer_names is None:
            return None

        names = [layer_names] if isinstance(layer_names, str) else layer_names

        # Names are de-duplicated before the graph lookup
        names_in_model = set()
        for name in names:
            scope = name.partition("/")[0]
            if scope in self._hailo_nn.net_params.net_scopes:
                names_in_model.add(name)

        return {self._hailo_nn.get_layer_by_name(name) for name in names_in_model}

    def _should_skip_equalization(self, equiv_set):
        """
        Decide whether an equivalence set must be left untouched, and log the reasons at debug level.

        An equiv set is skipped when any of the following holds:
         - one of its source layers does not resolve to an enabled Equalization policy
         - one of its source / ew-bouncer / standalone activations is not supported
         - it contains unsupported layers
         - _is_dw_with_relu6 matches (depthwise source together with relu6 / relu1)
         - the mode is kernel equalization, and _should_skip_only_kernel_equalization rejects the set
         - its ``outputs`` collection is not empty (this reason is not logged)

        Args:
            equiv_set: the equivalence set to check.

        Returns:
            bool: True when the equiv set should not be equalized.
        """
        unsupported_activations = self._get_unsupported_activations(equiv_set)

        valid_equalization_policies = all(
            self._is_valid_equalization_policy(layer) for layer in equiv_set.source_layers
        )

        is_dw_with_relu6 = self._is_dw_with_relu6(equiv_set)

        # The kernel-equalization reasons exist only in that mode; default them so they are always bound
        has_unsupported_layers = pure_8b_equiv_sets = multi_sources = skip_only_kernels_equalization = False
        if self._configuration.mode == EqualizationMode.kernel_equalization:
            (
                has_unsupported_layers,
                pure_8b_equiv_sets,
                multi_sources,
                skip_only_kernels_equalization,
            ) = self._should_skip_only_kernel_equalization(equiv_set)

        # bool() keeps the return type stable: the ``or`` chain could otherwise evaluate to the (non-empty)
        # collection of unsupported layers.
        should_skip = bool(
            (not valid_equalization_policies)
            or unsupported_activations
            or equiv_set.unsupported
            or is_dw_with_relu6
            or skip_only_kernels_equalization
        )

        if should_skip:
            default_logger().debug(f"Skipping Equalization of equiv set with sources: {equiv_set.equiv_set_info()}.")
            if equiv_set.unsupported:
                unsupported_layers = [
                    (equiv_layer.layer.name, equiv_layer.layer.op.name) for equiv_layer in equiv_set.unsupported
                ]
                default_logger().debug(f"Unsupported layers {unsupported_layers}")
            if unsupported_activations:
                default_logger().debug(f"Unsupported activations {unsupported_activations}")
            if not valid_equalization_policies:
                message = []
                for layer in equiv_set.source_layers:
                    # Report the layers whose policy is invalid (the filter used to be inverted, so the
                    # valid layers were the ones listed)
                    if self._is_valid_equalization_policy(layer):
                        continue
                    policy_name = self._configuration.layers.get(
                        layer.name,
                        LayerEqualizationConfig.get_default(),
                    ).policy.name
                    quantization_groups = layer.precision_config.quantization_groups
                    if quantization_groups is None:
                        quantization_groups = 1
                    message.append(
                        f"Layer {layer.name}, Policy {policy_name}, Quantization Groups {quantization_groups}",
                    )
                default_logger().debug("Invalid policy layers: \n{}".format("\n".join(message)))
            if is_dw_with_relu6:
                default_logger().debug("Equalization of dw with relu6 is not supported")

            if skip_only_kernels_equalization:
                if has_unsupported_layers:
                    # Report the layers that kernel equalization does NOT support (this filter used to be
                    # inverted as well)
                    unsupported_layers = [
                        (equiv_layer.layer.name, equiv_layer.layer.op.name)
                        for equiv_layer in equiv_set.producers + equiv_set.consumers
                        if not self._is_supported_layer_of_only_kernels(equiv_layer.layer)
                    ]
                    default_logger().debug(f"Unsupported layers of k-equalization {unsupported_layers}")

                if pure_8b_equiv_sets:
                    default_logger().debug("Equalization with no 4-bit layers is been configured to be skipped")

                if multi_sources:
                    default_logger().debug("Equalization is configured to skip non-1:1")
        return should_skip or bool(equiv_set.outputs)

    def _should_skip_only_kernel_equalization(self, equiv_set):
        """
        Evaluate the skip reasons that are specific to kernel equalization.

        Returns:
            tuple: ``(has_unsupported_layers, pure_8b_equiv_sets, multi_sources, skip)`` where
             - has_unsupported_layers: a producer / consumer is rejected by _is_supported_layer_of_only_kernels
             - pure_8b_equiv_sets: skip_8b_to_8b is configured and no producer / consumer has 4-bit weights
             - multi_sources: skip_multi_source is configured and the set has more than one consumer or
               has ew bouncers
             - skip: the ``or`` of the three reasons
        """
        config = self._configuration
        rescaled_layers = [equiv_layer.layer for equiv_layer in (equiv_set.producers + equiv_set.consumers)]

        has_unsupported_layers = not all(
            self._is_supported_layer_of_only_kernels(layer) for layer in rescaled_layers
        )
        pure_8b_equiv_sets = config.skip_8b_to_8b and all(
            layer.name not in self._layers_with_4bit_weights for layer in rescaled_layers
        )
        multi_sources = config.skip_multi_source and (
            len(equiv_set.consumers) > 1 or len(equiv_set.ew_bouncers) > 0
        )

        return (
            has_unsupported_layers,
            pure_8b_equiv_sets,
            multi_sources,
            has_unsupported_layers or pure_8b_equiv_sets or multi_sources,
        )

    def _is_supported_layer_of_only_kernels(self, layer):
        """
        Returns:
            bool: True for a conv layer, and for a batch norm layer unless skip_sbn_layers is configured.
        """
        if layer.op == LayerType.conv:
            return True
        is_allowed_batch_norm = layer.op == LayerType.batch_norm and not self._configuration.skip_sbn_layers
        return bool(is_allowed_batch_norm)

    @staticmethod
    def _is_dw_with_relu6(equiv_set):
        dw_equiv_layers = filter(lambda source: source.layer.op == LayerType.dw, equiv_set.sources)
        unique_dw_set_indices = set()
        for equiv_layer in dw_equiv_layers:
            unique_dw_set_indices.update(equiv_layer.set_indices)

        layers_with_relu_n = [
            equiv_layer
            for equiv_layer in itertools.chain(equiv_set.producers, equiv_set.activations)
            if equiv_layer.layer.activation in [ActivationType.relu6, ActivationType.relu1]
        ]
        return any(len(set(equiv_layer.set_indices) & unique_dw_set_indices) > 0 for equiv_layer in layers_with_relu_n)

    def _get_unsupported_activations(self, equiv_set):
        """
        Find the activations of the equiv set that equalization does not support.

        The source layers, the layers of the ew bouncers and the layers of the standalone activations are
        checked against ``self._activation_for_equalization``.

        Returns:
            set: ``(layer name, activation name)`` tuples of the unsupported activations.
        """
        layers_to_check = itertools.chain(
            equiv_set.source_layers,
            (ewb.layer for ewb in equiv_set.ew_bouncers),
            (standalone.layer for standalone in equiv_set.activations),
        )
        return {
            (layer.name, layer.activation.name)
            for layer in layers_to_check
            if layer.activation not in self._activation_for_equalization
        }

    def _join_group_convolutions(self, params):
        """
        Replace the kernel of every grouped conv / deconv layer by an equivalent dense kernel.

        The dense kernel has ``groups`` times more input channels than the grouped one. It is zero
        everywhere except for the diagonal blocks, where block i connects the input channels of group i to
        the output channels of group i. The dense kernels are written into ``params``.
        """
        for hn_item in self._hailo_nn:
            is_grouped_conv = hn_item.op in [LayerType.conv, LayerType.deconv] and not hn_item.groups == 1
            if not is_grouped_conv:
                continue

            groups = hn_item.groups
            grouped_kernel = params[hn_item.name].kernel
            grouped_shape = grouped_kernel.shape
            dense_kernel = np.zeros(
                (grouped_shape[0], grouped_shape[1], grouped_shape[2] * groups, grouped_shape[3]),
                grouped_kernel.dtype,
            )
            in_group_size = dense_kernel.shape[2] // groups
            out_group_size = dense_kernel.shape[3] // groups
            for group_idx in range(groups):
                in_block = slice(in_group_size * group_idx, in_group_size * (group_idx + 1))
                out_block = slice(out_group_size * group_idx, out_group_size * (group_idx + 1))
                # All the input channels of the grouped kernel belong to the group of the output block
                dense_kernel[:, :, in_block, out_block] = grouped_kernel[:, :, :, out_block]

            params.set_layer_kernel(hn_item.name, dense_kernel)

    def _resplit_group_convolutions(self, params):
        """
        Convert the dense kernel of every grouped conv / deconv layer back to its grouped form.

        This is the inverse of _join_group_convolutions: only the diagonal blocks of the dense kernel are
        kept, and block i becomes the output channels of group i in the grouped kernel. The grouped kernels
        are written into ``params``.
        """
        for hn_item in self._hailo_nn:
            is_grouped_conv = hn_item.op in [LayerType.conv, LayerType.deconv] and not hn_item.groups == 1
            if not is_grouped_conv:
                continue

            groups = hn_item.groups
            dense_kernel = params[hn_item.name].kernel
            dense_shape = dense_kernel.shape
            in_group_size = dense_shape[2] // groups
            out_group_size = dense_shape[3] // groups
            grouped_kernel = np.zeros(
                (dense_shape[0], dense_shape[1], in_group_size, dense_shape[3]),
                dense_kernel.dtype,
            )
            for group_idx in range(groups):
                in_block = slice(in_group_size * group_idx, in_group_size * (group_idx + 1))
                out_block = slice(out_group_size * group_idx, out_group_size * (group_idx + 1))
                grouped_kernel[:, :, :, out_block] = dense_kernel[:, :, in_block, out_block]

            params.set_layer_kernel(hn_item.name, grouped_kernel)

    def _get_dense_predecessor(self, dense):
        """
        Get the single predecessor of a dense layer.

        Args:
            dense: the dense layer, as a node of the graph.

        Returns:
            The only predecessor of the layer.

        Raises:
            BackendQuantizationException: when the layer does not have exactly one predecessor.
        """
        preds = list(self._hailo_nn.predecessors(dense))
        if len(preds) != 1:
            # The check also catches a layer without predecessors, so report the actual count
            raise BackendQuantizationException(
                f"Dense layer '{dense.original_names[0]}' must have exactly 1 predecessor, found {len(preds)}",
            )
        return preds[0]

    def calculate_max_concat_from_equiv_set(self, equiv_set):
        """
        Replace the simple "layerwise max" (the one we want to avoid increasing) by the max over all the
        "concat companions", because "scale matching" (which happens after Equalization) syncs all of their
        ranges to the largest one anyway.

        Returns:
            The largest ``stats_max_output_features_value`` over the concat layers of the equiv set, or None
            when the set has no concat layers.
        """
        # Note that only the max value of each layer is used. The only case where this is not good enough is
        # when the absolute value of the min range is much bigger than the max range (this case is very rare)
        concat_layers = equiv_set.concat_layers
        if len(concat_layers) == 0:
            return None

        max_per_layer = []
        for src_name in concat_layers:
            layer_stats = self._conv_layer_inference[src_name]
            max_per_layer.append(np.max(layer_stats["stats_max_output_features_value"]))
        return np.max(max_per_layer)

    def _reshape_kernel_dense_if_needed(self, kernel, axes_to_max, layer, type_of_layer):
        """
        Handle the specific case of a dense consumer whose predecessor has a 4-dimensional output shape.

        In that case the kernel is reshaped to ``[width * height of the predecessor output, -1, output
        features of the dense layer]`` and the axes mask becomes ``[True, False, True]``. In any other case
        the kernel and the mask are returned as they are.

        Returns:
            tuple: ``(kernel, axes_to_max)``
        """
        is_dense_consumer = layer.op == LayerType.dense and type_of_layer == LayerEquivType.consumer
        if not is_dense_consumer:
            return kernel, axes_to_max

        pred = self._get_dense_predecessor(layer)
        if len(pred.output_shape) != 4:
            return kernel, axes_to_max

        reshaped_kernel = np.reshape(kernel, [pred.output_width * pred.output_height, -1, layer.output_features])
        return reshaped_kernel, [True, False, True]

    def _get_bias_factors(self, equiv_layer, factors, kernel, axis):
        """
        Build the one-dimensional factors vector of a layer.

        The vector is all ones, except for the layer indices of the equiv layer, which get the factors of the
        matching set indices. Its length is ``kernel.shape[axis]``. For a dense consumer whose predecessor has
        a 4-dimensional output shape, the vector is built for ``kernel.shape[0]`` divided by the area of the
        predecessor output, and is then tiled once per position of that area.

        Returns:
            numpy array of the padded factors, with the dtype of the kernel.
        """
        indices = np.array(equiv_layer.layer_indices)

        def _ones_with_factors(length):
            padded = np.ones(length, kernel.dtype)
            padded[indices] = factors[equiv_layer.set_indices]
            return padded

        is_dense_consumer = (
            equiv_layer.type_of_layer == LayerEquivType.consumer and equiv_layer.layer.op == LayerType.dense
        )
        if is_dense_consumer:
            pred = self._get_dense_predecessor(equiv_layer.layer)
            # layer has spatial properties, e.g. conv-to-dense
            if len(pred.output_shape) == 4:
                input_tensor_area = pred.output_width * pred.output_height
                features_before = old_div(kernel.shape[0], input_tensor_area)
                scales_before = _ones_with_factors(features_before)
                return np.tile(np.squeeze(scales_before), input_tensor_area)

        return _ones_with_factors(kernel.shape[axis])

    def _get_kernel_and_bias_factors(self, equiv_layer, factors, axes_mask, kernel, axis):
        """
        Returns:
            tuple: ``(kernel_factors, bias_factors)`` - the flat factors of _get_bias_factors, and the same
            factors reshaped to ``axes_mask`` so that they can be multiplied with the kernel.
        """
        flat_factors = self._get_bias_factors(equiv_layer, factors, kernel, axis)
        return np.reshape(flat_factors, axes_mask), flat_factors
