import logging
from typing import Dict

import numpy as np
import tensorflow as tf
from numpy.typing import ArrayLike
from tqdm import tqdm

from hailo_model_optimization.acceleras.hailo_layers.base_hailo_none_nn_core_layer import BaseHailoNonNNCoreLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_io import HailoInputLayer, HailoOutputLayer
from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.model_optimization_config.mo_config import ModelOptimizationConfig
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_model import CheckerConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import FeaturePolicy
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasUnsupportedError
from hailo_model_optimization.acceleras.utils.flow_state.updater import modify_flow_state
from hailo_model_optimization.algorithms.lat_utils.lat_model import LATModel
from hailo_model_optimization.algorithms.lat_utils.lat_noise_metrics import SNRNoise
from hailo_model_optimization.algorithms.lat_utils.lat_utils import AnalysisMode
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm


class HailoQuantAnalyzer(OptimizationAlgorithm):
    """
    Algorithm for analyzing a network. The algorithm runs the model in quant (lossy) mode and in native mode and
    calculates the SNR at the investigated layers and at the model outputs.

    Analysis modes (``checker_cfg.analyze_mode``):
        * ``AnalysisMode.simple`` - a single full-quant inference compared against the native model.
        * ``AnalysisMode.advanced`` - the full-quant inference, followed by a layer-by-layer pass in which only one
          layer at a time is lossy and the noise it causes at each output layer is measured.

    Args:
        model - an acceleras model.
        model_config - a parsed alls file.
        unbatched_data_set - tf unbatched dataset object.
        native_model - optional model used as the native reference. Defaults to ``model`` itself.
        **kwargs - forwarded to ``OptimizationAlgorithm.__init__`` (e.g. a ``work_dir``, if the base class accepts
                   one — TODO confirm).

    """

    DEFAULT_TENSOR_DTYPE = np.float32
    DEFAULT_SAMPLES_MAX_COUNT = 1000
    DEFAULT_HISTOGRAM_BINS_COUNT = 512
    # Upper bound on how many samples are counted when the dataset size is not configured (see ``_setup``).
    MAX_DATASET_LENGTH = 1024

    def __init__(
        self,
        model: HailoModel,
        model_config: ModelOptimizationConfig,
        unbatched_data_set,
        native_model: HailoModel = None,
        **kwargs,
    ):
        super().__init__(model, model_config, "Layer Noise Analysis", logger_level=logging.INFO, **kwargs)
        self._logger.debug("Initializing Layer Noise Analysis Tool")

        # Without a dedicated reference model, the analyzed model itself serves as the native reference.
        if native_model is None:
            native_model = model
        self._native_model = native_model

        self._config = self.get_algo_config()

        # Save model parameters on self
        self._network_name = model.model_name

        # Initialize test parameters: IO layers and non-NN-core layers are not analyzed.
        self.analyze_layers = [
            lname
            for lname, layer in self._model.layers.items()
            if not isinstance(layer, (HailoInputLayer, HailoOutputLayer, BaseHailoNonNNCoreLayer))
        ]
        self.analysis_results = None
        self._statistics = None
        self._full_quant_net_statistics = None

        self.analyze_mode = self._config.analyze_mode

        # Setup variables (batch size is re-resolved in ``_setup``)
        self._unbatched_data_set = unbatched_data_set
        self._batch_size = self._config.batch_size

    def _run_int(self):
        """
        Run the configured analysis and build ``self._statistics`` / ``self.analysis_results``.

        GPU memory exhaustion aborts the analysis without statistics. A user interrupt (Ctrl-C) is treated as
        best effort: whatever was gathered before the interrupt is still turned into statistics.

        Raises:
            AccelerasUnsupportedError: if ``analyze_mode`` is neither simple nor advanced.
        """
        self._logger.debug("Starting Layer Noise Analysis Tool")

        # Analyze model
        try:
            if self.analyze_mode == AnalysisMode.simple:
                self.analyze_full_quant_net()
            elif self.analyze_mode == AnalysisMode.advanced:
                self.analyze_full_quant_net()
                self.analyze_layer_by_layer()
            else:
                raise AccelerasUnsupportedError(f"Analyze mode {self.analyze_mode.full_name} is not supported.")
        except tf.errors.ResourceExhaustedError:
            self._logger.warning(
                "GPU memory has been exhausted. Layer Noise Analysis will not generate statistics. Try either: "
                "1) Lower batch size using the command: model_optimization_config(checker_cfg, batch_size=1). "
                "2) Disable the algorithm using the command: model_optimization_config(checker_cfg, policy=disabled). "
                "3) Force using CPU by setting the CUDA_VISIBLE_DEVICES environment variable to a non-existent GPU",
            )
            return
        except KeyboardInterrupt:
            # Best effort: fall through and export whatever statistics were collected so far.
            self._logger.warning("Layer Noise Analysis Tool cut by the user, statistics params may not be generated.")

        self._statistics = self._create_statistics()
        self.analysis_results = self._create_analysis_data()

    def _outputs_by_layer(self, model, outputs):
        """
        Map each output layer name to its part of ``outputs``, as returned by ``model(inputs)``.

        The raw result is wrapped in a list when the model has a single output layer, and each layer's tensor is
        wrapped in a list when that layer (looked up on ``model``) has a single output, so every entry can be
        indexed per output.
        """
        if self._num_of_outputs == 1:
            outputs = [outputs]
        per_layer = {}
        for out_layer, tensor in zip(self._output_layers_names, outputs):
            if model.layers[out_layer].num_outputs == 1:
                tensor = [tensor]
            per_layer[out_layer] = tensor
        return per_layer

    def analyze_layer_by_layer(self):
        """
        Analyzing each layer at a time.
        1. Inferring on native model.
        2. For each layer:
            - Set only chosen layer to lossy mode.
            - Infer lossy model.
            - Get SNR by comparing the native results with the quantized results (on logits).

        The model's flow state is restored, and the progress bar closed, even if the analysis is interrupted.
        """
        self._logger.debug(
            f"Starting Layer by Layer Analysis, running on {self._num_of_images} pics with batch size "
            f"{self._batch_size}",
        )

        native_sample = {}
        numeric_sample = {}

        # Snapshot the flow state so it can be restored once the analysis ends (successfully or not).
        model_state = self._model.export_flow_state()

        self._native_model.set_native()
        self._model.set_native()

        pbar = tqdm(
            total=len(self.analyze_layers) * self._num_of_images // self._batch_size,
            dynamic_ncols=True,
            unit="iterations",
            desc="Layer-by-Layer Analysis",
        )

        try:
            for inputs, _ in self.data_feed_cb:
                # Get all native results for analyzed layers
                native_sample.update(self._outputs_by_layer(self._native_model, self._native_model(inputs)))

                for lname in self.analyze_layers:
                    pbar.update(1)
                    # Make only the current layer lossy; all other layers remain native.
                    self._model.layers[lname].fully_native = False
                    self._model.layers[lname].enable_lossy()

                    # Infer model
                    numeric_sample.update(self._outputs_by_layer(self._model, self._model(inputs)))

                    # Calc noise at every output, attributed to the current layer
                    for out_layer in self._output_layers_names:
                        self.noise_results[out_layer].update(
                            native_sample[out_layer],
                            numeric_sample[out_layer],
                            lname,
                        )

                    # Return to native
                    self._model.layers[lname].fully_native = True
        finally:
            pbar.refresh()  # flush the fine pbar state
            pbar.close()
            # Restore also on KeyboardInterrupt / ResourceExhaustedError, which ``_run_int`` handles and survives.
            self._model.import_flow_state(model_state)

    def _combine_lat_models_results(self, results):
        """
        Group per-output LAT metrics into one array per metric key.

        ``results`` keys look like ``"<key>:<output_index>"``. Entries sharing the same ``<key>`` are stacked in
        output-index order (``0 .. num_outputs - 1``) into a single ``np.array``. The layer whose ``num_outputs`` is
        used is named by the first two ``/``-separated components of ``<key>``.
        """
        combine_results = {}
        keys = {k.split(":")[0] for k in results.keys()}
        for key in keys:
            num_outputs = self._model.layers["/".join(key.split("/", 2)[:2])].num_outputs
            combine_results[key] = np.array([results[f"{key}:{i}"] for i in range(num_outputs)])
        return combine_results

    def analyze_full_quant_net(self):
        """
        Single inference of the model.
        1. Infer native model.
        2. Set all layers to lossy mode.
        3. Infer lossy model.
        4. Calculate SNR on each of the chosen layers (Only activations SNR are being calculated).

        The combined metrics are stored in ``self._full_quant_net_statistics``; it is left untouched if inference
        is interrupted.
        """
        self._logger.debug(
            f"Starting Analyze full quant Analysis, running on {self._num_of_images} pics with batch "
            f"size {self._batch_size}",
        )

        shapes = [(self._batch_size,) + shape for shape in self._model.get_input_shapes()]
        self._model.compute_output_shape(shapes)  # make sure each layer has output_shapes
        self._model.set_lossy()

        pbar = tqdm(
            total=self._num_of_images // self._batch_size,
            dynamic_ncols=True,
            unit="iterations",
            desc="Full Quant Analysis",
        )

        try:
            with modify_flow_state(
                self._model,
                self._config.custom_infer_config,
                self._logger,
            ) as model:
                lat_model = LATModel(model, self._native_model)

                # LATModel has non-XLA-compatible metrics, hence jit_compile=False
                @tf.function(jit_compile=False, reduce_retracing=True)
                def predict_function(data):
                    return lat_model(data, training=False)

                lat_model.build(self.data_feed_cb)
                lat_model.compile()

                for inputs, _ in self.data_feed_cb:
                    predict_function(inputs)
                    pbar.update(1)

                pbar.refresh()  # flush the fine pbar state
                pbar.close()
                metrics_result = lat_model.metrics_result()
        finally:
            # No-op on the normal path (tqdm.close is idempotent); releases the bar if inference raised.
            pbar.close()
        self._full_quant_net_statistics = self._combine_lat_models_results(metrics_result)

    def _create_analysis_data(self):
        """
        Build the user-facing analysis report from ``self._statistics``.

        Returns:
            dict with ``activations_histogram``, ``snr_per_layer``, ``sampled_tensors`` and ``noise_results``
            entries. Only layers that have the relevant statistics are included, and values are converted with
            ``tolist()``.
        """
        noise_results = {}
        activations_histogram = {
            "native": {},
            "numeric": {},
        }
        sampled_tensors = {}
        snr_per_layer = {}

        if self.analyze_mode == AnalysisMode.advanced:
            for out_layer in self._output_layers_names:
                noise_results[out_layer] = {}
                for layer in self.analyze_layers:
                    noise_key = layer + "/noise_results/" + out_layer
                    if noise_key in self._statistics:
                        noise_results[out_layer][layer] = self._statistics[noise_key][0].tolist()

        for layer in self.analyze_layers:
            if layer + "/sample/native" in self._statistics:
                activations_histogram["native"][layer] = {
                    "hist": self._statistics[layer + "/histogram/hist"][0].tolist(),
                    "bin_edges": self._statistics[layer + "/histogram/bin_edges"][0].tolist(),
                }

                sampled_tensors[layer] = {
                    "native": self._statistics[layer + "/sample/native"][0].tolist(),
                    "numeric": self._statistics[layer + "/sample/numeric"][0].tolist(),
                    "channel": self._statistics[layer + "/sample/channel"][0].tolist(),
                }

            if layer + "/snr" in self._statistics:
                snr_per_layer[layer] = self._statistics[layer + "/snr"][0].tolist()

        data = {
            "activations_histogram": activations_histogram,
            "snr_per_layer": snr_per_layer,
            "sampled_tensors": sampled_tensors,
            "noise_results": noise_results,
        }
        return data

    def get_algo_config(self):
        """
        Return the current algorithm configuration (``model_config.checker_cfg``).
        """
        return self._model_config.checker_cfg

    def _setup(self):
        """
        Resolve the dataset size and batch size, build the batched data feed, and prepare result containers.

        The number of images is truncated down to a multiple of the batch size so every batch is full.
        """
        self.analyze_mode = self._config.analyze_mode
        # Init data
        if self._config.dataset_size is None:
            # Count samples, capped at MAX_DATASET_LENGTH.
            self._num_of_images = (
                self._unbatched_data_set.take(self.MAX_DATASET_LENGTH).reduce(0, lambda x, _: x + 1).numpy()
            )
        else:
            self._num_of_images = self._config.dataset_size
        self._batch_size = (
            self._config.batch_size
            if self._config.batch_size is not None
            else self._model_config.calibration.batch_size
        )
        if self._num_of_images % self._batch_size != 0:
            new_num_of_images = self._num_of_images - self._num_of_images % self._batch_size
            self._logger.info(
                f"Using only {new_num_of_images} images for calibration instead of {self._num_of_images}, "
                f"because {self._num_of_images} is not an integer product of the batch size "
                f"{self._batch_size}.",
            )
            self._num_of_images = new_num_of_images
        self.data_feed_cb = self._unbatched_data_set.take(self._num_of_images).batch(self._batch_size)

        # Results variables
        self._output_layers_names = self._model.flow.output_nodes
        self._num_of_outputs = len(self._output_layers_names)

        # One SNR accumulator per output layer; only the advanced (layer-by-layer) mode fills them.
        self.noise_results = {}
        if self.analyze_mode == AnalysisMode.advanced:
            for out_layer in self._output_layers_names:
                self.noise_results[out_layer] = SNRNoise(self.analyze_layers)

    def should_skip_algo(self):
        """
        Skip the analysis only when its policy is explicitly set to ``FeaturePolicy.disabled``.
        """
        return self._config.policy == FeaturePolicy.disabled

    def log_config(self):
        """
        Intentionally a no-op: the necessary information is logged by ``_run_int`` when needed.
        """

    def _create_statistics(self) -> Dict[str, ArrayLike]:
        """
        Format value such that result will be of type Dict[str, ArrayLike], and each key will start with layer name.

        Full-quant statistics are skipped when they were never produced (e.g. the user interrupted the analysis
        before the full-quant pass finished).
        """
        params_statistics = {}

        if self._full_quant_net_statistics is not None:
            params_statistics.update(self._full_quant_net_statistics)

        if self.analyze_mode == AnalysisMode.advanced:
            for out_layer in self._output_layers_names:
                noise_results = self.noise_results[out_layer].get()
                for layer, value in noise_results.items():
                    params_statistics[layer + "/noise_results/" + out_layer] = value

        return params_statistics

    def export_statistics(self):
        """
        Return the statistics dictionary built by ``_create_statistics`` (``None`` if the analysis did not run or
        aborted on GPU memory exhaustion).
        """
        return self._statistics

    def finalize_global_cfg(self, algo_config: CheckerConfig):
        """
        Fill unset ``dataset_size`` / ``batch_size`` from the calibration config, then validate both.
        """
        # Is there a better way to get dataset length?
        if algo_config.dataset_size is None:
            algo_config.dataset_size = self._model_config.calibration.calibset_size
        if algo_config.batch_size is None:
            algo_config.batch_size = self._model_config.calibration.batch_size
        self.check_dataset_length(algo_config, "dataset_size", self._unbatched_data_set)
        self.check_batch_size(algo_config, "dataset_size", "batch_size")
