from typing import Dict

import numpy as np

from hailo_model_optimization.acceleras.hailo_layers.hailo_conv import HailoConv
from hailo_model_optimization.acceleras.hailo_layers.hailo_depthwise import HailoDepthwise
from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.model_optimization_config.mo_config import (
    ModelOptimizationConfig,
    ZeroStaticChannelsConfig,
)
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerZeroStaticChannelsConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ActivationType,
    FeaturePolicy,
    LayerFeaturePolicy,
)
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm


class ZeroStaticChannelsAlgoError(Exception):
    """Base error for failures raised by the zero-static-channels algorithm (e.g. an unsupported layer type)."""


class ZeroStaticChannelsAlgo(OptimizationAlgorithm):
    """
    Algorithm to eliminate static channels (with low variance).

    Iterate over all layers and verify the output variance of each channel.
    In case an output channel is static (has low variance), we assign zero to the relevant output channel of the kernel
    and assign the mean to the bias. Therefore, the output of the layer is kept "as is" but the weights are easier to
    quantize.
    The goal is to eliminate outliers of the kernel in cases that are not necessary.

    The algorithm only processes layers whose exact type is in ``SUPPORTED_LAYERS`` and whose activation is listed in
    ``SUPPORTED_ACTIVATIONS`` (linear, the relu variants and clip).
    TODO - open algorithm for all activations - https://hailotech.atlassian.net/browse/SDK-32641
    """

    SUPPORTED_LAYERS = {HailoConv, HailoDepthwise}
    SUPPORTED_ACTIVATIONS = [
        ActivationType.RELU,
        ActivationType.RELU6,
        ActivationType.LINEAR,
        ActivationType.RELU1,
        ActivationType.CLIP,
    ]

    def __init__(self, model: HailoModel, model_config: ModelOptimizationConfig, logger_level, logger=None):
        """Initialize the base algorithm and cache the variance threshold (``eps``) from the algorithm config."""
        super().__init__(model, model_config, name="Zero Static Channels", logger_level=logger_level, logger=logger)
        self.eps = self._get_epsilon()

    def log_config(self):
        """Intentionally a no-op: this algorithm logs nothing about its configuration."""
        pass

    def _setup(self):
        """Replace the per-layer config mapping with one that has an entry for every layer of the model."""
        config_by_layer = self.get_config_by_layer()
        algo_cfg = self.get_algo_config()
        algo_cfg.layers = config_by_layer

    @staticmethod
    def _get_layer_stat(layer):
        """
        Return ``(variance, mean)`` of the first output of ``layer``.

        The variance is derived as ``energy - mean**2``, i.e. it presumes ``energy`` holds the mean of the squared
        outputs (E[x^2]) - TODO confirm against the stats collection code.
        """
        layer_stats = layer.get_output_stats()[0]
        layer_mean = layer_stats.mean
        layer_variance = layer_stats.energy - layer_mean**2
        return layer_variance, layer_mean

    def _kernel_channel_index(self, layer, channel_idx):
        """
        Build the index tuple selecting the kernel slice of ``channel_idx`` for the given layer type.

        Raises:
            ZeroStaticChannelsAlgoError: if ``layer`` is neither a HailoDepthwise nor a HailoConv.
        """
        every = slice(None)
        if isinstance(layer, HailoDepthwise):
            # NOTE(review): axis 2 is indexed with the output-channel index, which presumably assumes a
            # depth multiplier of 1 - confirm.
            return (every, every, channel_idx, every)
        if isinstance(layer, HailoConv):
            return (every, every, every, channel_idx)
        raise ZeroStaticChannelsAlgoError(
            f"Got unexpected layer type {type(layer)} but supported layers are {self.SUPPORTED_LAYERS}",
        )

    def _zero_static_channels(self, layer, channel_updates):
        """
        Zero the kernel slices of several channels of ``layer`` and overwrite their bias entries.

        The kernel and bias are read once and assigned back once, regardless of how many channels are updated.

        Args:
            layer: layer whose kernel and bias are modified.
            channel_updates: iterable of ``(channel_idx, new_bias_value)`` pairs.

        Returns:
            List with, for every pair in order, the maximal absolute kernel value of that channel before zeroing.

        Raises:
            ZeroStaticChannelsAlgoError: if the layer type is not supported (nothing is assigned in that case).
        """
        kernel = layer.get_kernel_np()
        bias = layer.get_bias_np()
        kernel_maxes = []
        for channel_idx, new_value in channel_updates:
            channel_slice = self._kernel_channel_index(layer, channel_idx)
            # Record the magnitude being removed before it is overwritten (used only for logging).
            kernel_maxes.append(np.max(np.abs(kernel[channel_slice])))
            # Scalar assignment broadcasts over the slice and keeps the kernel dtype.
            kernel[channel_slice] = 0
            bias[channel_idx] = new_value
        layer.kernel.assign(kernel)
        layer.bias.assign(bias)
        return kernel_maxes

    def _zero_static_channel(self, layer, channel_idx, new_value):
        """
        Zero the kernel slice of a single channel and set its bias to ``new_value``.

        Returns:
            The maximal absolute kernel value of the channel before it was zeroed.

        Raises:
            ZeroStaticChannelsAlgoError: if the layer type is not supported.
        """
        return self._zero_static_channels(layer, [(channel_idx, new_value)])[0]

    def _run_int(self):
        """Zero every channel whose output variance is below ``eps`` in all layers that are not skipped."""
        for layer_name in list(self._model.layers.keys()):
            if self.should_skip_layer(layer_name):
                continue
            layer = self._model.layers[layer_name]
            layer_variance, layer_mean = self._get_layer_stat(layer)
            static_channels = [
                (channel_idx, channel_variance, channel_mean)
                for channel_idx, (channel_variance, channel_mean) in enumerate(zip(layer_variance, layer_mean))
                if channel_variance < self.eps
            ]
            if not static_channels:
                # Leave layers without static channels completely untouched.
                continue
            # Update all static channels of the layer with a single kernel/bias read and assign.
            kernel_maxes = self._zero_static_channels(
                layer,
                [(channel_idx, channel_mean) for channel_idx, _, channel_mean in static_channels],
            )
            for (channel_idx, channel_variance, channel_mean), kernel_max in zip(static_channels, kernel_maxes):
                self._logger.debug(
                    f"Zero static channel {channel_idx} of layer {layer_name} with variance "
                    f"{channel_variance:.2f}, mean {channel_mean:.2f} and kernel maximum of "
                    f"{kernel_max:.2f}",
                )

    def get_algo_config(self) -> ZeroStaticChannelsConfig:
        """Return the zero-static-channels section of the model optimization config."""
        return self._model_config.zero_static_channels

    def _get_epsilon(self):
        """Return the variance threshold (``eps``) configured for the algorithm."""
        return self.get_algo_config().eps

    def get_config_by_layer(self) -> Dict[str, LayerZeroStaticChannelsConfig]:
        """
        Get LayerZeroStaticChannelsConfig for each layer in the model

        Layers without an explicit entry in the algorithm config get the default layer config.
        """
        layers_config = self.get_algo_config().layers
        return {
            layer.full_name: layers_config.get(layer.full_name, LayerZeroStaticChannelsConfig.get_default())
            for layer in self._model.layers.values()
        }

    def should_skip_algo(self):
        """Skip the whole algorithm unless it is enabled globally or for at least one layer."""
        return not self.has_zero_static()

    def finalize_global_cfg(self, algo_config):
        """Intentionally a no-op: there is nothing to finalize in the global config."""
        pass

    def _is_supported_layer(self, layer) -> bool:
        """Return True when both the layer type and its activation are handled by the algorithm."""
        # Exact type match (not isinstance): subclasses of the supported layers are not treated as supported.
        # The activation is only inspected once the type matched.
        return type(layer) in self.SUPPORTED_LAYERS and layer.act_op.act_name in self.SUPPORTED_ACTIVATIONS

    def _get_valid_layer_cfg(self, lname, cfg: LayerZeroStaticChannelsConfig) -> LayerZeroStaticChannelsConfig:
        """Return ``cfg`` for supported layers and the default layer config for unsupported ones."""
        if not self._is_supported_layer(self._model.layers[lname]):
            return LayerZeroStaticChannelsConfig.get_default()
        return cfg

    def has_zero_static(self):
        """Return True if the global policy is enabled or any per-layer policy is explicitly enabled."""
        algo_cfg = self.get_algo_config()
        default_policy = algo_cfg.policy

        any_enabled = any(
            layer_config.policy == LayerFeaturePolicy.enabled for layer_config in algo_cfg.layers.values()
        )

        return any_enabled or (default_policy == FeaturePolicy.enabled)

    @classmethod
    def resolve_policy(cls, default_policy: FeaturePolicy, layer_policy: LayerFeaturePolicy) -> FeaturePolicy:
        """
        Resolve the feature policy with per layer policy to get the actual layer policy

        A layer policy of ``allowed`` defers to ``default_policy``; any other layer policy overrides it.
        """
        if layer_policy == LayerFeaturePolicy.allowed:
            policy = FeaturePolicy(default_policy.value)
        else:
            policy = FeaturePolicy(layer_policy.value)
        return policy

    def should_skip_layer(self, layer_name):
        """
        Return True if ``layer_name`` must not be processed.

        A layer is skipped when its resolved policy is disabled, or when its type or activation is unsupported.
        """
        algo_cfg = self.get_algo_config()
        layer_cfg = algo_cfg.layers[layer_name]
        layer_policy = self.resolve_policy(algo_cfg.policy, layer_cfg.policy)
        layer = self._model.layers[layer_name]
        return layer_policy == FeaturePolicy.disabled or not self._is_supported_layer(layer)
