from copy import deepcopy
from typing import Any, Optional, Tuple

import numpy as np
from pydantic.v1 import BaseModel, Field

from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_standalone_activation import HailoStandaloneActivation
from hailo_model_optimization.acceleras.model.hailo_model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.model_optimization_config.mo_config import ModelOptimizationConfig
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerPrecisionConfig,
    LayerTranslationConfig,
    LayerWeightsClippingConfig,
)
from hailo_model_optimization.acceleras.utils.acceleras_definitions import BiasMode, OpStates, PrecisionMode
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm


class TransitionInfo(BaseModel):
    """
    Snapshot of the original layer's quantization parameters, taken before the layer is split.

    Every field is optional and defaults to ``None``; values of any type are accepted.
    """

    # Scale(s) at the input of the original layer's activation op.
    accumulator_scales: Optional[Any] = Field(
        None,
        description="The accumulator scales of the Original layer",
    )
    # Per-input scales of the original layer.
    input_scales: Optional[Any] = Field(
        None,
        description="The input scales of the Original layer",
    )
    # Per-output scales of the original layer.
    output_scales: Optional[Any] = Field(
        None,
        description="The output scales of the Original layer",
    )
    # Per-input zero points of the original layer.
    input_zp: Optional[Any] = Field(
        None,
        description="The input zero point of the Original layer",
    )
    # Per-output zero points of the original layer.
    output_zp: Optional[Any] = Field(
        None,
        description="The output zero point of the Original layer",
    )

    class Config:
        # Allow arbitrary (non-pydantic) types as field values.
        arbitrary_types_allowed = True


class LayerSplitter:
    """
    Class in charge of splitting a layer on a translated state into different
    layers also on translated state.

    The original layer (``a_layer``) keeps its linear part. Its activation is replaced by
    ``linear``, or by ``clip`` when ``auto_clip`` is set. A new ``HailoStandaloneActivation``
    layer (``b_layer``) carrying the original activation is added to the model on the edges
    from ``a_layer`` to its successors.
    """

    # Maps the original layer's precision mode to the pair
    # (precision of a_layer, precision of the new standalone activation b_layer).
    activation_mapping = {
        PrecisionMode.a16_w16_a16: (PrecisionMode.a16_w16_a16, PrecisionMode.a16_w16_a16),
        PrecisionMode.a8_w8_a8: (PrecisionMode.a8_w8_a16, PrecisionMode.a16_w16_a8),
        PrecisionMode.a8_w8_a16: (PrecisionMode.a8_w8_a16, PrecisionMode.a16_w16_a16),
        PrecisionMode.a16_w16_a8: (PrecisionMode.a16_w16_a16, PrecisionMode.a16_w16_a8),
        PrecisionMode.a8_w4_a8: (PrecisionMode.a8_w4_a16, PrecisionMode.a16_w16_a8),
    }

    def __init__(self, model: HailoModel, model_config: ModelOptimizationConfig, logger):
        self._model = model
        self._config = model_config
        self._logger = logger

    @property
    def optimization_target(self):
        """The optimization target of the wrapped model."""
        return self._model.optimization_target

    def split_layer(self, a_layer: BaseHailoLayer, *, auto_clip: bool = True, rank: int = 1) -> Tuple[str, str]:
        """
        Split ``a_layer`` into itself (without its activation) and a standalone activation layer.

        Args:
            a_layer: the layer to split. It is modified in place and stays in the model.
            auto_clip: when True, ``a_layer``'s activation becomes ``clip`` with the bounds from
                ``act_op.harmless_clipping()``. The propagated stats are also clipped to a range
                that respects the legal output-range limit (see ``_apply_legal_range_clip``).
                When False, the activation becomes ``linear``.
            rank: currently unused; kept for interface compatibility.

        Returns:
            The full names of ``a_layer`` and of the new activation layer.
        """
        # Snapshot everything that is about to be overwritten: stats, activation weights,
        # the clip range (computed from the original activation) and the quantization params.
        original_stats = a_layer.export_stats()
        act_op_weights = a_layer.act_op.export_weights()
        clip_range = self._apply_legal_range_clip(a_layer, auto_clip)

        transition_info = TransitionInfo(
            input_scales=a_layer.input_scales,
            output_scales=a_layer.output_scales,
            input_zp=a_layer.input_zero_points,
            output_zp=a_layer.output_zero_points,
            accumulator_scales=a_layer.activation_atomic_op.input_scale,
        )
        b_layer = self._create_successors(a_layer)
        a_layer = self._change_activation(a_layer, auto_clip)

        a_stats, b_stats = self._propagate_stats(original_stats, auto_clip=clip_range)

        a_layer.import_stats(a_stats)
        b_layer.import_stats(b_stats)
        b_layer.add_supported_state(OpStates.CALIBRATED)
        b_layer.import_weights(act_op_weights)

        a_precision_config, b_precision_config = self._propagate_precision(a_layer, b_layer)
        self._add_layer_to_model(a_layer, b_layer)
        self._translate_layers([(a_layer, a_precision_config), (b_layer, b_precision_config)], transition_info)
        return a_layer.full_name, b_layer.full_name

    def _create_successors(self, a_layer: BaseHailoLayer):
        """
        Create a new standalone activation layer with the same activation as ``a_layer``.

        The new layer is named ``{scope}/{block_name}ne_activation_{layer_name}``.
        """
        scope, layer_name = a_layer.full_name.split("/")
        block_name, layer_name = OptimizationAlgorithm.get_block_and_layer_names(layer_name)
        name = f"{scope}/{block_name}ne_activation_{layer_name}"
        b_layer = HailoStandaloneActivation(name=name, activation=a_layer.activation_atomic_op.act_name)
        return b_layer

    def _change_activation(self, a_layer: BaseHailoLayer, auto_clip: bool):
        """
        Change the activation of ``a_layer`` in place and return it.

        With ``auto_clip`` the activation becomes ``clip``, bounded by
        ``act_op.harmless_clipping()``. Otherwise it becomes ``linear``.
        """
        if auto_clip:
            min_clip, max_clip = a_layer.act_op.harmless_clipping()
            layer_params = {"clip_min": min_clip, "clip_max": max_clip}
            a_layer.act_op.create_act_name_and_func("clip")

            # NOTE: the _hn_element update can also be removed once to_hn works.
            a_layer._hn_element["params"]["activation"] = "clip"
            a_layer.act_op.import_weights(layer_params)
        else:
            a_layer.act_op.create_act_name_and_func("linear")
            a_layer._hn_element["params"]["activation"] = "linear"
        return a_layer

    def _propagate_stats(self, original_stats, auto_clip: Optional[Tuple[float, float]] = None):
        """
        Propagate the stats of the original layer to the two new layers.

        Stats keys are expected to look like ``stats/<tensor>/<stat>``. Only the ``<tensor>``
        component is remapped:

        * ``a_layer``: ``output_0`` takes the original ``preact`` stats.
        * ``b_layer``: ``input_0`` and ``preact`` take the original ``preact`` stats, and
          ``output_0`` keeps the original output stats.

        All other keys are copied as-is.

        Args:
            original_stats: stats dict exported from the original layer.
            auto_clip: optional ``(min, max)`` clip range. When given, the min/max stats of
                a_layer's ``output_0`` and b_layer's ``input_0``/``preact`` are intersected
                with it.

        Returns:
            ``(a_stats, b_stats)``. Every value is a fresh deep copy, so the two dicts share
            no arrays with each other or with ``original_stats``.
        """
        map_a = {"output_0": "preact"}
        map_b = {
            "input_0": "preact",
            "preact": "preact",
            "output_0": "output_0",
        }

        def create_stats(stats, key_map):
            res = {}
            for key in stats:
                parts = key.split("/")
                # Only the tensor-name component is remapped.
                parts[1] = key_map.get(parts[1], parts[1])
                # Copy per value: several keys may be sourced from the same original entry.
                res[key] = deepcopy(stats["/".join(parts)])
            return res

        a_stats = create_stats(original_stats, map_a)
        b_stats = create_stats(original_stats, map_b)

        if auto_clip is not None:
            min_clip, max_clip = auto_clip
            # Never widen beyond the observed stats; shape should be the same across all stats.
            min_clip = np.maximum(a_stats["stats/output_0/min"], min_clip)
            max_clip = np.minimum(a_stats["stats/output_0/max"], max_clip)
            targets = (
                (a_stats, "stats/output_0"),
                (b_stats, "stats/input_0"),
                (b_stats, "stats/preact"),
            )
            for stats, prefix in targets:
                stats[f"{prefix}/min"] = deepcopy(min_clip)
                stats[f"{prefix}/max"] = deepcopy(max_clip)
        return a_stats, b_stats

    def _propagate_precision(self, a_layer: BaseHailoLayer, b_layer: BaseHailoLayer):
        """
        Derive precision and translation configs for both layers from the original layer's config.

        Side effects:
            - ``a_layer``'s precision config in ``self._config`` is updated in place with the
              mapped precision mode.
            - Fresh translation configs are imported into both layers and registered in
              ``self._config.translation_config``.

        Returns:
            ``(a_precision_config, b_precision_config)``.

        Raises:
            KeyError: if the original precision mode is not in ``activation_mapping``.
        """
        a_precision_config = self._config.precision_config.layers[a_layer.full_name]
        original_trans_config = self._config.translation_config.layers[a_layer.full_name]
        a_precision, b_precision = self.activation_mapping[a_precision_config.precision_mode]

        # Setting precision for a_layer
        a_precision_config.precision_mode = a_precision
        b_precision_config = LayerPrecisionConfig.get_default()
        b_precision_config.precision_mode = b_precision
        b_precision_config.bias_mode = BiasMode.single_scale_decomposition
        b_precision_config.quantization_groups = a_precision_config.quantization_groups

        # b_layer inherits the output-side translation settings of the original layer.
        b_trans_conf = LayerTranslationConfig.get_default()
        b_trans_conf.activation_fit = original_trans_config.activation_fit
        b_trans_conf.force_range_out = original_trans_config.force_range_out
        b_trans_conf.null_channels_cutoff_factor = original_trans_config.null_channels_cutoff_factor
        b_layer.import_translation_config(b_trans_conf)

        # a_layer inherits the input-side / HW-related translation settings.
        a_trans_config = LayerTranslationConfig.get_default()
        a_trans_config.max_elementwise_feed_repeat = original_trans_config.max_elementwise_feed_repeat
        a_trans_config.max_bias_feed_repeat = original_trans_config.max_bias_feed_repeat
        a_trans_config.null_channels_cutoff_factor = original_trans_config.null_channels_cutoff_factor
        a_trans_config.force_range_in = original_trans_config.force_range_in
        a_trans_config.ignore_hw_limitation_assertion = original_trans_config.ignore_hw_limitation_assertion
        a_layer.import_translation_config(a_trans_config)

        self._config.translation_config.layers[b_layer.full_name] = b_trans_conf
        self._config.translation_config.layers[a_layer.full_name] = a_trans_config

        return a_precision_config, b_precision_config

    def _add_layer_to_model(self, a_layer: BaseHailoLayer, b_layer: BaseHailoLayer):
        """
        Add ``b_layer`` to the model on the edges from ``a_layer`` to its successors.

        The model's output shapes are then recomputed, and ``b_layer`` is built with
        ``a_layer``'s first output shape.
        """
        successors = self._model.flow.successors_sorted(a_layer.full_name)
        edges = [(a_layer.full_name, suc) for suc in successors]
        self._model.add_layer(b_layer, edges)
        shapes = [(None,) + shape for shape in self._model.get_input_shapes()]
        self._model.compute_output_shape(shapes)
        out_shape = a_layer.output_shapes[0]
        b_layer.build(out_shape)

    def _translate_layers(self, layers, transition_info: TransitionInfo):
        """
        Quantize the split pair so that it keeps the original layer's external encoding.

        a_layer receives the original input scales and zero points. b_layer receives the
        original output scales and zero points. The scales between the two layers are taken
        from a_layer's outputs. ``transition_info.accumulator_scales`` is not used here.

        Args:
            layers: ``[(a_layer, a_precision_config), (b_layer, b_precision_config)]``.
            transition_info: quantization parameters of the original layer.
        """
        (a_layer, _), (b_layer, _) = layers

        for index, (scale, zp) in enumerate(zip(transition_info.input_scales, transition_info.input_zp)):
            a_layer.set_input_scale(scale, index)
            a_layer.set_input_zero_point(zp, index)

        for index, (scale, zp) in enumerate(zip(transition_info.output_scales, transition_info.output_zp)):
            b_layer.set_output_scale(scale, index)
            b_layer.set_output_zero_point(zp, index)

        for layer, precision in layers:
            layer: BaseHailoLayer

            layer.import_precision_config(precision, self.optimization_target)
        a_layer.create_output_encoding_candidates()
        b_layer.create_input_encoding_candidates()

        # Fix the scales between the layers
        for index, scale in enumerate(a_layer.output_scales):
            b_layer.set_input_scale(scale, index)

        self.create_hw_params(layers, index=0)
        self.create_hw_params(layers, index=1)

    def create_hw_params(self, layers_info, index):
        """
        Create HW params for ``layers_info[index]`` and register its precision config.

        The layer's weights-clipping config defaults to ``LayerWeightsClippingConfig.get_default()``
        and is stored in ``self._config`` if missing. For the first layer (``index == 0``),
        ``update_shift`` is also applied, which may update the input scale of the second layer.
        """
        layer, precision = layers_info[index]
        layer.enable_lossy()
        clip_cfg = self._config.weights_clipping.layers.setdefault(
            layer.full_name,
            LayerWeightsClippingConfig.get_default(),
        )
        layer.create_hw_params(clip_cfg, self.optimization_target)
        layer.enforce_internal_encoding()
        if index == 0:
            self.update_shift(layer, layers_info[1][0])

        self._config.precision_config.layers[layer.full_name] = precision

    def update_shift(self, layer, layer_b):
        """
        Apply any extra shift required by ``layer``'s activation op and re-sync ``layer_b``.

        The required shift is the larger of ``act_op.get_offset_needed_shift()`` and
        ``max(-act_op.get_assigned_exponent())``. When it is positive, the layer's shift is
        updated, its encoding is re-enforced, and ``layer_b``'s first input scale is set to
        ``layer``'s output scale.
        """
        if layer.activation_atomic_op:
            act_op = layer.activation_atomic_op
            assigned_exp = act_op.get_assigned_exponent()
            shift_fix_offset = act_op.get_offset_needed_shift()
            shift_fix_exp = np.max(-assigned_exp)
            shift_fix = np.max([shift_fix_offset, shift_fix_exp])
            if shift_fix > 0:
                layer.update_negative_slope_exponent_shift(shift_fix)
                layer.enforce_internal_encoding()
                layer_b.set_input_scale(layer.output_scale, 0)

    def _apply_legal_range_clip(self, layer: BaseHailoLayer, auto_clip: bool) -> Optional[Tuple[float, float]]:
        """
        Compute the clip range applied to the propagated stats when ``auto_clip`` is set.

        The output range has a lower bound tied to the accumulator (input) scale:

            (s_in / s_out) * 2**15 (max input value) * 2**3 (shift) < 2**19
                [2**22 for the max risky value] (max value before clipping)
            ==> s_in * 2**(-1) < s_out = range_out / 2**bits_out
            ==> min_range_out > s_in * 2**(bits_out - 1)

        Here ``bits_out = 15``.

        Returns:
            ``None`` if ``auto_clip`` is False, or if the observed input range
            (``get_input_limvals(0)``) is not wider than the minimal legal range.

            Otherwise, a ``(min_clip, max_clip)`` range built as follows:

            1. Start from ``act_op.harmless_clipping()`` intersected with the observed range.
            2. If that is narrower than the minimal legal range, widen it by the missing
               amount. The missing amount is split between the two sides in proportion to
               how much each side was clipped.

            The widened range stays inside the observed range.
        """
        op = layer.activation_atomic_op

        min_out_range = np.max(op.input_scale * 2 ** (15 - 1))
        lim_min, lim_max = op.get_input_limvals(0)

        if auto_clip and (lim_max - lim_min > min_out_range):
            min_clip, max_clip = op.harmless_clipping()
            # Intersect with the observed range so the per-side clipped amounts are
            # non-negative. Without this, a harmless bound outside the observed range gives
            # a weight outside [0, 1] and widens the wrong side. The range left after the
            # stats intersection in _propagate_stats could then be shorter than min_out_range.
            min_clip = np.maximum(min_clip, lim_min)
            max_clip = np.minimum(max_clip, lim_max)

            d_l = min_clip - lim_min
            d_r = lim_max - max_clip
            residual = min_out_range - (max_clip - min_clip)
            if residual > 0:
                # d_l + d_r == (lim_max - lim_min) - (max_clip - min_clip) > residual > 0,
                # so alpha is well defined, lies in [0, 1], and the widened range stays
                # inside [lim_min, lim_max].
                alpha = d_l / (d_l + d_r)
                min_clip = min_clip - residual * alpha
                max_clip = max_clip + residual * (1 - alpha)
            return min_clip, max_clip

        return None
