import networkx as nx
import numpy as np

from hailo_model_optimization.acceleras.hailo_layers.hailo_nms import HailoNMS
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    FeaturePolicy,
    IgnoreHwLimitationAssertionPolicy,
)
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm


class QuantChecker(OptimizationAlgorithm):
    """
    Check the statistics inside params_statistics and append them with warnings based on the quantization process.

    The checker is purely diagnostic:

    * ``check_bn`` (enabled by ``checker_cfg.batch_norm_checker``) compares the measured pre-activation
      distribution of the batch-norm layer closest to the model inputs with the distribution implied by its
      BN parameters, and warns when they diverge (e.g. when the calibration set isn't normalized).
    * ``_log_snr_info`` logs the SNR of the output layers (or of the NMS input when an NMS layer feeds an output).
    * ``_log_normalization_recommendation`` suggests input-normalization / range-forcing model-script commands.

    Statistics generated by the checks are returned by ``export_statistics``.
    """

    # Warning: this algorithm assumes the translation_config has already been handled (by create_io_encoding)

    # A BN check fails when the KL divergence in either direction exceeds this value.
    BN_THRESHOLD = 10
    # Lower clamp applied to both variances before the KL computation (avoids log(0) / division by zero).
    BN_EPSILON = 1e-3

    def __init__(self, model, model_config, logger_level: int, params_statistics, logger=None):
        super().__init__(
            model=model,
            model_config=model_config,
            name="Quantization Checker",
            logger_level=logger_level,
            logger=logger,
        )
        self._params_statistics = params_statistics
        self._generated_statistics = {}
        self._config = self.get_algo_config()

    def get_algo_config(self):
        return self._model_config.checker_cfg

    def log_config(self):
        pass

    def _setup(self):
        pass

    def _run_int(self):
        """Run the enabled checks, then log SNR and normalization information."""
        if self._config.batch_norm_checker:
            self.check_bn()
        self._log_info()

    def _get_first_layer_with_bn(self, bn_statistics):
        """
        Return the BN layer closest to any model input.

        Args:
            bn_statistics: mapping whose keys look like ``<a>/<b>/normalization_optimizer/...``; the layer name
                is taken as the first two ``/``-separated components of each key. Must be non-empty.

        Returns:
            The layer name with the smallest shortest-path distance (in edges) from any input node of the flow.
            Layers unreachable from every input get an infinite distance. Ties are broken by layer name so the
            result is deterministic (set iteration order of strings varies between processes).
        """
        layers = {"/".join(key.split("/")[:2]) for key in bn_statistics}
        flow = self._model.flow
        # BFS only from the input nodes; equivalent to slicing all_pairs_shortest_path_length, but cheaper.
        depths_from_inputs = [
            nx.single_source_shortest_path_length(flow, input_layer) if input_layer in flow else {}
            for input_layer in flow.input_nodes
        ]

        def sort_key(layer):
            depth = min((depths.get(layer, np.inf) for depths in depths_from_inputs), default=np.inf)
            return depth, layer

        return min(layers, key=sort_key)

    def check_bn(self):
        """
        Compare the measured pre-activation distribution of the first BN layer with the one its BN params imply.

        The first layer (closest to the inputs) that has ``normalization_optimizer`` statistics is checked. The
        check is skipped when the layer has a forward-folded batch norm (``pre_layer_batch_norm``), when the
        layer's translation config ignores HW limitation assertions, or when the layer does not report exactly
        one pre-activation stats object.

        Both distributions are treated as Gaussians and compared with the KL divergence in both directions. If
        either direction exceeds ``BN_THRESHOLD``, a warning is logged and the measured/expected values are
        stored in the generated statistics. Otherwise only ``<layer>/check_bn/failed = False`` is stored.
        """
        normalization_optimizer_str = "normalization_optimizer"
        bn_statistics = {k: v for k, v in self._params_statistics.items() if normalization_optimizer_str in k}
        if not bn_statistics:
            return
        first_bn_layer = self._get_first_layer_with_bn(bn_statistics)
        if self._model.layers[first_bn_layer].hn_element["params"].get("pre_layer_batch_norm", False):
            # skipping batch norm stats check for layers that contain a forward-folded batch norm
            return

        layer_translation_cfg = self._model_config.translation_config.layers[first_bn_layer]
        if layer_translation_cfg.ignore_hw_limitation_assertion == IgnoreHwLimitationAssertionPolicy.enabled:
            return

        # Strip the "<layer>/normalization_optimizer/" prefix, keeping e.g. "beta:0", "gamma:0", ...
        layer_prefix = f"{first_bn_layer}/{normalization_optimizer_str}/"
        layer_statistics = {k.split("/")[-1]: v for k, v in bn_statistics.items() if k.startswith(layer_prefix)}
        preact_stats_list = self._model.layers[first_bn_layer].get_preact_stats()
        if len(preact_stats_list) != 1:
            return
        preact_stats = preact_stats_list[0]

        # Measured moments. energy - mean**2 is the variance assuming `energy` is the second raw moment E[x^2].
        mean = preact_stats.mean
        var = preact_stats.energy - mean**2
        # Expected moments after BN: y = gamma * (x - mu) / sqrt(var + eps) + beta
        #   => E[y] = beta, Var[y] = gamma**2 * var / (var + eps)
        expected_mean = layer_statistics["beta:0"]
        moving_variance = layer_statistics["moving_variance:0"]
        expected_var = layer_statistics["gamma:0"] ** 2 * (
            moving_variance / (moving_variance + layer_statistics["epsilon:0"])
        )

        var = np.maximum(var, self.BN_EPSILON)
        expected_var = np.maximum(expected_var, self.BN_EPSILON)

        # Closed-form KL divergence between two Gaussians, written with variances:
        #   KL(N(m1, v1) || N(m2, v2)) = (log(v2 / v1) + (v1 + (m2 - m1)**2) / v2 - 1) / 2
        # kl = KL(measured || expected), kl2 = KL(expected || measured); each is the max over all elements.
        kl = np.max(np.log(expected_var / var) + (var + (expected_mean - mean) ** 2) / expected_var - 1) / 2
        kl2 = np.max(np.log(var / expected_var) + (expected_var + (mean - expected_mean) ** 2) / var - 1) / 2

        stats_prefix = f"{first_bn_layer}/check_bn/"
        if kl > self.BN_THRESHOLD or kl2 > self.BN_THRESHOLD:
            self._logger.warning(
                f"The measured distribution for layer {first_bn_layer} is different than expected. "
                f"This could happen when the calibration set isn't normalized.\n"
                f"To disable this warning add the following line to the model script:\n"
                f"model_optimization_config(checker_cfg, batch_norm_checker=False)",
            )
            self._generated_statistics[stats_prefix + "failed"] = True
            self._generated_statistics[stats_prefix + "kl:0"] = kl
            self._generated_statistics[stats_prefix + "kl2:0"] = kl2
            self._generated_statistics[stats_prefix + "mean:0"] = mean
            self._generated_statistics[stats_prefix + "expected_mean:0"] = expected_mean
            # NOTE: the "std" keys are kept for compatibility; the stored values are variances.
            self._generated_statistics[stats_prefix + "std:0"] = var
            self._generated_statistics[stats_prefix + "expected_std:0"] = expected_var
        else:
            self._generated_statistics[stats_prefix + "failed"] = False

    def _log_snr_info(self):
        """
        Log the per-output SNR from ``<layer>/layer_noise_analysis/snr`` statistics.

        If the predecessor of an output node is a ``HailoNMS`` layer, the NMS input layer is reported instead.
        Nothing is logged unless every reported layer has an SNR statistic.
        """
        flow = self._model.flow
        output_layers = []
        for output_node in flow.output_nodes:
            predecessor = flow.predecessors_sorted(output_node)[0]
            if isinstance(self._model.layers[predecessor], HailoNMS):
                # If we have NMS layer before the output, we want to log the SNR of the NMS input.
                output_layers.append(flow.predecessors_sorted(predecessor)[0])
            else:
                output_layers.append(output_node)

        available_keys = self._params_statistics.keys()
        # Guard against an empty output list: all([]) is True and would log a header with no rows.
        if not output_layers or not all(
            f"{lname}/layer_noise_analysis/snr" in available_keys for lname in output_layers
        ):
            return

        # log output SNR
        self._logger.info(
            "Output layers signal-to-noise ratio (SNR): measures the quantization noise (higher is better)",
        )
        for output_layer in output_layers:
            snr_key = f"{output_layer}/layer_noise_analysis/snr"
            self._logger.info(f"\t{output_layer} SNR:\t{self._params_statistics[snr_key][0].tolist():.4} dB")

    def _log_normalization_recommendation(self, threshold=32):
        """
        Log input-normalization advice for small (<= 4 channel), unsigned 8-bit input layers.

        Args:
            threshold: absolute tolerance used when comparing zero point and calibration min/max to their ideal
                values; the scale tolerance is ``threshold / 2**bits``.

        Three cases are reported:
            1. no in-net normalization, and scale/zero-point are close to (but not exactly) 1/0:
               suggest forcing the range.
            2. no in-net normalization, and scale/zero-point are far from 1/0: suggest adding a normalization layer.
            3. in-net normalization is used but the calibration range doesn't match the expected input range: warn.
        """
        MAX_OUTPUT_CHANNELS = 4
        for lname in self._model.flow.input_nodes:
            layer = self._model.layers[lname]
            input_lossy_element = layer.get_input_lossy_elements()[0]
            if (
                layer.output_shape[-1] > MAX_OUTPUT_CHANNELS
                or input_lossy_element.bits != 8
                or input_lossy_element.signed
            ):
                continue
            stats = layer.get_output_stats()[0]

            normalization_in_net = (
                self._model_config.translation_config.layers[lname].input_normalization == FeaturePolicy.enabled
            )
            need_quantization = np.any(layer.input_scale != 1.0) or np.any(layer.input_zero_point != 0)
            quantization_close = np.allclose(
                layer.input_scale, 1.0, rtol=0, atol=threshold / 2**input_lossy_element.bits
            ) and np.allclose(layer.input_zero_point, 0, rtol=0, atol=threshold)
            in_range = np.allclose(stats.min, input_lossy_element.min_value, rtol=0, atol=threshold) and np.allclose(
                stats.max, input_lossy_element.max_value, rtol=0, atol=threshold
            )

            # Convert to plain Python floats so the log shows e.g. (0.0, 255.0) rather than numpy scalar reprs.
            calibration_range = list(zip(np.atleast_1d(stats.min).tolist(), np.atleast_1d(stats.max).tolist()))
            desired_range = [input_lossy_element.min_value, input_lossy_element.max_value]

            if not normalization_in_net:
                if not need_quantization:
                    continue
                if quantization_close:
                    # case 1: User didn't use normalization_in_net but scale/zero_point are close to 1/0 (but not exact)
                    script_command = f"quantization_param({lname}, force_range_out={desired_range})"
                    reference = "Hailo Dataflow Compiler user guide / Model Optimization / Optimization Related Model Script Commands / quantization_param / force_range_out"
                    self._logger.info(
                        f"The calibration set indicates that the neural core receives values of range {calibration_range}.\n"
                        f"Consider forcing the range (using the {script_command} model script command) to be exactly {desired_range} to save CPU utilization on runtime.\n"
                        f"Refer to the user guide tutorial {reference} for details."
                    )
                else:
                    # case 2: User didn't use normalization_in_net and scale/zero_point are not close to 1/0
                    reference = "Hailo Dataflow Compiler user guide / Model Optimization / Optimization Related Model Script Commands / model_modification_commands / normalization"
                    self._logger.info(
                        f"The calibration set seems to not be normalized, because the values range is {calibration_range}.\n"
                        f"Since the neural core works in 8-bit (between {desired_range[0]} to {desired_range[1]}), a quantization will occur on the CPU of the runtime platform.\n"
                        f"Add a normalization layer to the model to offload the normalization to the neural core.\n"
                        f"Refer to the user guide {reference} for details."
                    )
            elif not in_range:
                # case 3: User used normalization_in_net but the range is not [0,255]
                self._logger.warning(
                    f"The expected calibration set should be {desired_range} when using an in-net normalization layer, but the range received is {calibration_range}."
                )

    def _log_info(self):
        """Log output SNR and input-normalization recommendations."""
        self._log_snr_info()
        self._log_normalization_recommendation()

    def should_skip_algo(self):
        return False

    def export_statistics(self):
        """Return the statistics generated by the checks (currently only ``<layer>/check_bn/...`` entries)."""
        return self._generated_statistics

    def finalize_global_cfg(self, algo_config):
        pass
