import logging
import os
from functools import partial
from typing import Dict, List, Set

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.bias_add_op import AddBiasOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.model_optimization_config.mo_config import (
    ModelOptimizationConfig,
)
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerBiasCorrectionConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import BiasCorrectionPolicy, FeaturePolicy
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasImplementationError
from hailo_model_optimization.algorithms.dali_utils.mock_dali_dataset import cache_list_to_dataset
from hailo_model_optimization.algorithms.layer_by_layer import LayerByLayer


class BiasCorrection(LayerByLayer):
    """
    This module implements the ibc algorithm in a layer by layer manner.

    It processes a single layer at a time and toggles it between native and quantized (lossy) states.
    The native output is used as the reference. The layer's bias is then shifted by the mean per-channel
    difference between its quantized output and that reference.
    LayerByLayer is an abstract class based on OptimizationAlgorithm, which provides basic logic for layer by layer
    iteration.
    """

    def __init__(
        self,
        model: HailoModel,
        config_params: ModelOptimizationConfig,
        dataset: tf.data.Dataset,
        work_dir=None,
        logger_level=logging.DEBUG,
        **kwargs,
    ):
        work_dir = work_dir if work_dir is not None else ".bias_correction"
        super().__init__(model, config_params, "Bias Correction", dataset, work_dir, logger_level, **kwargs)
        # Replace the per-layer config with one that contains only correctable layers,
        # keyed by each layer's full name (see get_config_by_layer).
        config_by_layer = self.get_config_by_layer()
        algo_cfg = self.get_algo_config()
        algo_cfg.layers = config_by_layer
        # Per-layer flag reported by export_statistics; flipped to True once a layer's bias is corrected.
        self._layers_info = {f"{layer}/successfully_run": False for layer in self._model.flow.toposort()}

    def export_statistics(self):
        """Return the per-layer ``"<layer>/successfully_run"`` flags collected during the run."""
        return self._layers_info

    def get_dataset_size(self):
        """Return the number of calibration samples used by the algorithm (``calibset_size``)."""
        cfg = self.get_algo_config()
        return cfg.calibset_size

    @property
    def warning_if_larger_dataset(self) -> bool:
        return False

    def get_config_by_layer(self) -> Dict[str, LayerBiasCorrectionConfig]:
        """
        Get LayerBiasCorrectionConfig for each correctable layer in the model, keyed by the layer's full name.
        """
        cfg = self.get_algo_config()
        config_by_layer = {
            layer.full_name: cfg.layers[self.get_layer_name_in_config(layer)]
            for layer in self._model.layers.values()
            if self.is_correctable(layer)
        }
        return config_by_layer

    @staticmethod
    def is_correctable(acceleras_layer):
        """
        Check if a layer's bias can be corrected.

        Args:
            acceleras_layer: acceleras layer you want to check

        Returns:
            bool: True only if the layer is a BaseHailoLayer with exactly one AddBiasOp
            and that op reports itself as correctable.
        """
        if not isinstance(acceleras_layer, BaseHailoLayer):
            return False
        bias_ops = [op for op in acceleras_layer.atomic_ops if isinstance(op, AddBiasOp)]
        if len(bias_ops) != 1:
            return False
        return bias_ops[0].is_correctable

    def should_correct(
        self,
        layer,
    ) -> bool:
        """
        Check if a layer's bias should be corrected.

        The decision is based on layer support (is_correctable), the layer configuration, and the algo
        configuration. A layer without an explicit config falls back to a default LayerBiasCorrectionConfig.
        """
        if not self.is_correctable(layer):
            return False

        cfg = self.get_algo_config()
        layer_cfg = cfg.layers.get(layer.full_name)
        if layer_cfg is None:
            layer_cfg = LayerBiasCorrectionConfig()

        policy = self.resolve_policy(cfg.policy, layer_cfg.policy)
        return policy == FeaturePolicy.enabled

    @classmethod
    def resolve_policy(
        cls,
        default_policy: BiasCorrectionPolicy,
        layer_policy: BiasCorrectionPolicy,
    ) -> FeaturePolicy:
        """
        Resolve the feature policy with per layer policy to get the actual layer policy.

        An ``allowed`` layer policy defers to the default policy; any other layer policy wins.
        """
        if layer_policy == BiasCorrectionPolicy.allowed:
            policy = FeaturePolicy(default_policy.value)
        else:
            policy = FeaturePolicy(layer_policy.value)
        return policy

    def should_skip_algo(self):
        """Skip the algorithm when no layer would have its bias corrected."""
        return not self.has_bias_correction()

    def has_bias_correction(self):
        """
        Return True if at least one configured layer is enabled.

        A layer counts as enabled if it is explicitly enabled, or if it is ``allowed``
        while the default policy is enabled.
        """
        cfg = self.get_algo_config()
        default_policy = cfg.policy
        # Collect once instead of walking cfg.layers twice.
        layer_policies = [bias_config.policy for bias_config in cfg.layers.values()]
        any_enabled = any(policy == BiasCorrectionPolicy.enabled for policy in layer_policies)
        any_allowed = any(policy == BiasCorrectionPolicy.allowed for policy in layer_policies)
        return any_enabled or (any_allowed and default_policy == BiasCorrectionPolicy.enabled)

    def _lbl_setup_logic(self):
        """Init interlayer results for lossy data."""
        self._lossy_interlayer_results = {}

    def _lbl_pre_layer_logic(self, acceleras_layer: BaseHailoLayer) -> None:
        """
        This logic is triggered before the first inference in lbl _run logic.
        Converts the layer to native for reference native inference.
        """
        acceleras_layer.disable_lossy()
        acceleras_layer.enforce_internal_encoding()

    @staticmethod
    def set_quant_layer(layer: BaseHailoLayer):
        """
        Sets layer in quantized mode.
        """
        layer.enable_lossy()
        layer.enforce_internal_encoding()

    def _lbl_post_layer_logic(
        self,
        acceleras_layer: BaseHailoLayer,
        curr_inputs_cache_list: List[str],
        curr_outputs_parent_cache: str,
        inferred_layers: Set[str],
    ):
        """
        This logic is triggered after the lbl logic infers the layer.

        Switches the layer to quantized mode and corrects its bias if needed, using the native
        output cache (``curr_outputs_parent_cache``) as reference. It then infers the layer in numeric
        mode and stores that output as the lossy input for downstream layers.
        """
        lname = acceleras_layer.full_name

        self.set_quant_layer(acceleras_layer)
        if self.should_correct(acceleras_layer):
            self.correct_bias(acceleras_layer, curr_outputs_parent_cache)
            self._layers_info[f"{lname}/successfully_run"] = True

        # Re-infer with the (possibly corrected) bias so downstream layers consume the post-correction output.
        curr_outputs_parent_cache_lossy = self._infer_numeric(acceleras_layer)
        self._lossy_interlayer_results[lname] = curr_outputs_parent_cache_lossy
        self._clean_results(self._lossy_interlayer_results, inferred_layers)

    def get_algo_config(self):
        """Return the ``bias_correction`` section of the model optimization config."""
        return self._model_config.bias_correction

    def correct_bias(self, acceleras_layer: BaseHailoLayer, output_reference_dir):
        """
        Infer the layer in numeric mode, compute the correction bias, and assign the update.

        The correction is the mean per-channel difference (quantized - reference). It is subtracted from
        the current bias, and the temporary pre-correction cache is deleted afterwards.

        Raises:
            AccelerasImplementationError: if the layer has more than one output.
        """
        if acceleras_layer.num_outputs != 1:
            raise AccelerasImplementationError("Correcting bias with multiple outputs is not defined")
        output_pre_correction_dir = self._infer_numeric(acceleras_layer)
        quant_cache = os.path.join(output_pre_correction_dir, self.get_cache_basename(0))
        ref_cache = os.path.join(output_reference_dir, self.get_cache_basename(0))
        count = self.get_dataset_size()
        dataset, _ = cache_list_to_dataset([quant_cache, ref_cache], count=count)
        # Each element pairs the quantized output (x[0]) with the reference output (x[1]).
        diff_dataset = dataset.map(lambda x: x[0] - x[1])
        axes_to_reduce = self.get_bias_reduce_axes(acceleras_layer, diff_dataset.element_spec.shape)
        diff_per_channel = self._compute_mean_diff_per_channel(diff_dataset, axes_to_reduce)
        acceleras_layer.bias.assign(acceleras_layer.bias.numpy() - diff_per_channel)
        acceleras_layer.enforce_internal_encoding()
        self._delete_cache(output_pre_correction_dir)

    def _infer_numeric(self, layer):
        """
        This function fetches the numeric (lossy) input dataset of the layer and infers the layer on it.

        Returns the output cache location produced by ``_infer_layer``.
        """
        inputs_cache_list_lossy = self._get_layer_inputs_cache_list(
            layer.full_name,
            self._lossy_interlayer_results,
        )
        # Use the same accessor as correct_bias so both inference paths agree on the sample count.
        count = self.get_dataset_size()
        return self._infer_layer(
            layer,
            inputs_cache_list_lossy,
            count=count,
        )

    @classmethod
    def get_bias_reduce_axes(cls, acceleras_layer: BaseHailoLayer, data_shape):
        """
        Return every axis index of ``data_shape`` except the features axis.

        The features axis is -1, or -2 when the layer's ``transpose_width_features`` is set.
        """
        if acceleras_layer.transpose_width_features:
            features_axis = -2
        else:
            features_axis = -1
        all_axes = np.arange(len(data_shape))
        return np.delete(all_axes, features_axis)

    def _compute_mean_diff_per_channel(self, diff_dataset, axes_to_reduce):
        """
        Compute the correction bias per channel.

        Args:
            diff_dataset: dataset whose elements are the element-wise difference (quantized - reference)
                of the layer's output.
            axes_to_reduce: axes averaged away within each element, leaving only the features axis.

        Returns:
            np.ndarray: mean of the per-element channel means.

        Raises:
            ValueError: if ``diff_dataset`` yields no elements. Without this check the mean would be NaN
                and would silently corrupt the layer's bias.
        """
        mean_func = partial(tf.reduce_mean, axis=axes_to_reduce)
        diff_per_channel_dataset = diff_dataset.map(mean_func)
        per_element_means = [element.numpy() for element in diff_per_channel_dataset]
        if not per_element_means:
            raise ValueError("Bias correction received an empty dataset; cannot compute a per-channel correction")
        diff_per_channel = np.array(per_element_means).mean(axis=0)
        return diff_per_channel
