"""
Module in charge of searching mix precision configurations
This module implements Implements Mix Precision search an algorithm ...

"""

import logging
from typing import List, Tuple

import pandas as pd
from tqdm import tqdm

from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_conv import HailoConv
from hailo_model_optimization.acceleras.hailo_layers.hailo_conv_add import HailoConvAdd
from hailo_model_optimization.acceleras.hailo_layers.hailo_dense import HailoDense
from hailo_model_optimization.acceleras.model.hailo_model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.model_optimization_config.mo_config import ModelOptimizationConfig
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_model import MixPrecisionSearchConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import ComprecisionMetric, PrecisionMode
from hailo_model_optimization.algorithms.mixed_precision.mix_precision_optimizer import (
    OptimizationSolution,
    SegmentsCreator,
    mix_presion_solver_factory,
)
from hailo_model_optimization.algorithms.mixed_precision.mix_precision_utils import InferResultsManager
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm


class MixPrecisionSearch(OptimizationAlgorithm):
    """
    Search for a mixed-precision bit configuration of a network.

    Each supported layer (see ``SUPPORTED_MIX_PRECISION``) is quantized on its own to every
    precision mode in ``precisions_mode`` while the rest of the model stays native. The SNR
    between the native outputs and the resulting outputs measures that layer's sensitivity.
    Solvers then choose a precision per layer for each compression marker in the config.

    Args:
        model: the Hailo model to analyze.
        model_config: model optimization config; its ``mix_precision_search`` section holds
            this algorithm's settings.
        unbatched_data_set: dataset of ``(inputs, label)`` samples; batched in ``_setup``
            (assumed to be a ``tf.data.Dataset`` — TODO confirm).
        logger_level: logging level forwarded to the base algorithm.
    """

    # Layer types eligible for a per-layer precision choice.
    SUPPORTED_MIX_PRECISION = {HailoConv, HailoConvAdd, HailoDense}
    # Precision modes tried on every candidate layer during the sensitivity analysis.
    precisions_mode = (PrecisionMode.a16_w16_a16, PrecisionMode.a8_w8_a16, PrecisionMode.a8_w4_a16)
    _config: MixPrecisionSearchConfig

    def __init__(
        self,
        model: HailoModel,
        model_config: ModelOptimizationConfig,
        unbatched_data_set,
        logger_level=logging.INFO,
        **kwargs,
    ):
        super().__init__(model, model_config, "Mix Precision Search", logger_level=logger_level, **kwargs)

        # Holds the last result returned by _run_int (empty until the algorithm runs).
        self.place_holder = pd.DataFrame()
        self._config = self.get_algo_config()
        self._unbatched_data_set = unbatched_data_set
        # Save model parameters on self
        self._network_name = model.model_name

    def get_algo_config(self):
        """Return the ``mix_precision_search`` section of the model optimization config."""
        return self._model_config.mix_precision_search

    def finalize_global_cfg(self, algo_config):
        """
        Finalize the algorithm's config. (values that are not layer specific)
        Can include values verification, fetching data from the other algo's config, etc...

        Currently a no-op for this algorithm.
        """

    def should_skip_algo(self) -> bool:
        """
        Decide whether to skip the algorithm based on its configuration.

        This algorithm is never skipped.
        """
        return False

    def _setup(self):
        """Build the inference feed: ``dataset_size`` samples in batches of ``batch_size``."""
        self.data_feed_cb = self._unbatched_data_set.take(self._config.dataset_size).batch(self._config.batch_size)

    def _run_int(self):
        """
        Run the full search.

        1. Collect the candidate layers.
        2. Measure each layer's sensitivity (SNR) for every precision mode.
        3. Solve for a precision assignment per compression marker.
        4. Turn every solution into a per-layer precision config.

        Returns:
            dict with ``"results"`` (list of ``(solution, config)`` pairs) and ``"analysis"``
            (the sensitivity DataFrame, with ``snr_val`` capped at ``snr_cap``).
        """
        analyze_layers = self.create_candidates_layers()
        results_df = self.create_sensitivity_list(analyze_layers)
        # Keep an uncapped copy: create_soluions caps "snr_val" in place on results_df.
        self.analyze_results = results_df.copy()
        solutions = self.create_soluions(results_df)
        res = self.attach_presicion_config(solutions, results_df)
        self.place_holder = res
        return res

    ################
    # Main Functions
    ################

    def create_candidates_layers(self) -> List[Tuple[str, BaseHailoLayer]]:
        """
        Create the list of layers to run mixed precision on.

        Returns:
            ``(layer_name, layer)`` pairs in topological order, keeping only layers whose
            type is in ``SUPPORTED_MIX_PRECISION``.
        """
        supported_types = tuple(self.SUPPORTED_MIX_PRECISION)
        layers = self._model.layers
        return [
            (lname, layers[lname])
            for lname in self._model.flow.toposort()
            if isinstance(layers[lname], supported_types)
        ]

    def create_sensitivity_list(self, analyze_layers: List[Tuple[str, BaseHailoLayer]]) -> pd.DataFrame:
        """
        Analyze each layer on its own.

        1. Infer on the native model (the reference outputs).
        2. For each layer and each precision mode:
            - Reset the model to native, then put only the chosen layer in lossy mode at that precision.
            - Record the layer's BOPs and infer the lossy model.
        3. Compute the SNR between the native and the quantized outputs.

        Args:
            analyze_layers: ``(layer_name, layer)`` pairs, as returned by ``create_candidates_layers``.

        Returns:
            The DataFrame produced by ``InferResultsManager.calculate_snr``.
        """
        results_manager = InferResultsManager()
        model = self._model

        # Set precision mode to Native and Infer
        model.set_native()
        model.enable_internal_encoding()

        self.infer_model_update_results(results_manager)

        for lname, layer in tqdm(analyze_layers, desc="Analising Layers"):
            for precision_mode in self.precisions_mode:
                # Reset every layer to native so only `layer` contributes quantization noise.
                model.set_native()
                ops = self.set_layer_precision(layer, precision_mode)
                results_manager.set_ops(ops, lname, precision_mode, ComprecisionMetric.BOPS)
                self.infer_model_update_results(results_manager, lname, precision_mode)

        # Calculate SNR and clean the infer results
        results = results_manager.calculate_snr()

        results_manager.clean_results()
        return results

    def create_soluions(self, sensitivity: pd.DataFrame) -> List[OptimizationSolution]:
        """
        Create one solution per compression limit from the layers' sensitivity.

        Note: ``sensitivity["snr_val"]`` is capped at ``snr_cap`` IN PLACE, so the caller's
        DataFrame is modified.

        Args:
            sensitivity: A data frame with the sensitivity information of the model.

        Returns:
            One ``OptimizationSolution`` per ops limit derived from ``compresions_markers``.
        """
        snr_cap = self._config.snr_cap
        sensitivity["snr_val"] = sensitivity["snr_val"].apply(lambda x: min(x, snr_cap))
        ops_limits = SegmentsCreator(self._logger).load_problem(sensitivity).run(self._config.compresions_markers)
        solutions = []
        for op_limit in ops_limits:
            solver = mix_presion_solver_factory(self._config.optimizer)
            solutions.append(solver(self._logger).load_problem(sensitivity).run(op_limit))
        return solutions

    def attach_presicion_config(
        self,
        solutions: List[OptimizationSolution],
        analysis: pd.DataFrame,
    ) -> dict:
        """
        Convert every solution into a per-layer precision config.

        Args:
            solutions: solver outputs; each ``sol.solution`` has ``layer_name`` and ``precision`` columns.
            analysis: the sensitivity DataFrame, returned as-is.

        Returns:
            ``{"results": [(solution, {layer_name: {"precision_mode": reduced_mode}}), ...],
            "analysis": analysis}``
        """
        results = []
        for sol in solutions:
            config = {
                row["layer_name"]: {"precision_mode": PrecisionMode(row["precision"]).reduce()}
                for _, row in sol.solution.iterrows()
            }
            results.append((sol, config))

        return {"results": results, "analysis": analysis}

    # #################
    # Helper Functions
    # #################

    def infer_model_update_results(self, results_namager: InferResultsManager, lname=None, precision_mode=None):
        """
        Infer the model on the data feed and record the outputs one image at a time.

        Args:
            results_namager: collector that stores the outputs.
            lname: name of the layer currently quantized (used only with ``precision_mode``).
            precision_mode: precision applied to ``lname``. When ``None``, the outputs are
                recorded as the native reference.
        """
        import tensorflow as tf

        outputs_names = self._model.flow.output_nodes

        # A fresh tf.function on every call: the model's precision is changed between calls
        # (see create_sensitivity_list), so a previously traced graph should not be reused
        # (presumably the reason it is local — confirm).
        @tf.function
        def run_model(inputs):
            return self._model(inputs)

        batched_outputs = []
        for inputs, _ in self.data_feed_cb:
            outputs = run_model(inputs)
            if len(outputs_names) == 1:
                # A single-output model returns a bare tensor; normalize to a list.
                outputs = [outputs]
            batched_outputs.append(outputs)

        # Split every batch into single images (slice i:i+1 keeps the batch dimension).
        for batch_outputs in batched_outputs:
            batch_size = batch_outputs[0].shape[0]
            for i in range(batch_size):
                single_image_outputs = [output[i : i + 1] for output in batch_outputs]
                for out_layer, tensor in zip(outputs_names, single_image_outputs):
                    if self._model.layers[out_layer].num_outputs == 1:
                        tensor = [tensor]
                    if precision_mode is not None:
                        results_namager.update_precision(
                            tensor,
                            output_name=out_layer,
                            layer_name=lname,
                            precision=precision_mode,
                        )
                    else:
                        results_namager.update_native(tensor, output_name=out_layer)

    def set_layer_precision(self, layer: BaseHailoLayer, precision: PrecisionMode):
        """
        Put ``layer`` in lossy mode at ``precision`` and rebuild its quantization parameters.

        Args:
            layer: the layer to quantize.
            precision: the precision mode to apply.

        Returns:
            The layer's BOPs at the new precision (``layer.get_bops()``).
        """
        config = self._model_config.precision_config.layers[layer.full_name]
        weights_clipping = self._model_config.weights_clipping.layers[layer.full_name]
        # NOTE(review): this mutates the shared model config in place; after the search each
        # analyzed layer's config keeps the last precision mode tried — confirm this is intended.
        config.precision_mode = precision
        layer.import_precision_config(config, self.optimization_target)
        layer.fully_native = False
        layer.enable_lossy()

        layer.create_io_encoding_candidates()
        layer.enforce_io_encoding()
        layer.create_hw_params(weights_clipping, self.optimization_target)
        # TODO Layer negative exponent might need to be address on future pr
        layer.enforce_internal_encoding()
        return layer.get_bops()
