"""
This module implements the bias correction algorithm, using the naive block heuristic to divide the model into blocks.
"""

import logging
from functools import partial
from typing import Dict

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.hailo_layers.base_acceleras_layer import BaseAccelerasLayer
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_conv_quant_weight_group import HailoConvQuantWeightGroup
from hailo_model_optimization.acceleras.model.hailo_model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.model.hailo_model.model_flow import ModelFlow
from hailo_model_optimization.acceleras.model_optimization_config.mo_config import ModelOptimizationConfig
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerBiasCorrectionConfig
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_model import GlobalBiasCorrectionConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import BiasCorrectionPolicy, FeaturePolicy
from hailo_model_optimization.algorithms.block_by_block.block_by_block import BlockByBlock
from hailo_model_optimization.algorithms.block_by_block.block_heuristic import naive_blocks
from hailo_model_optimization.algorithms.block_by_block.cache_utils import dataset_from_cache
from hailo_model_optimization.algorithms.dali_utils.dataset_util import tf_unpad_input


class BiasCorrection(BlockByBlock):
    """
    Block-by-block variant of the bias correction algorithm.

    The model is divided into blocks with the naive block heuristic (see ``get_blocks``). For each block, the
    layer feeding the block's first output node is the correction target: the algorithm computes the
    per-channel mean difference between the quantized block output and the native block output, and
    subtracts that difference from the target layer's bias.
    """

    def __init__(
        self,
        model: HailoModel,
        fp_model: HailoModel,
        model_config: ModelOptimizationConfig,
        dataset: tf.data.Dataset,
        *args,
        work_dir=None,
        logger_level=logging.DEBUG,
        logger=None,
        **kwargs,
    ):
        super().__init__(
            model,
            fp_model,
            model_config,
            "Bias Correction",
            dataset,
            *args,
            work_dir=work_dir,
            logger_level=logger_level,
            logger=logger,
            **kwargs,
        )
        # Per-layer success flags, flipped to True in `core_logic` and returned by `export_statistics`.
        self._layers_info = {f"{layer}/successfully_run": False for layer in self._model.flow.toposort()}

    def _setup(self):
        """Resolve the per-layer configuration, check storage usage, and log the dataset size."""
        config_by_layer = self.get_config_by_layer()
        algo_cfg = self.get_algo_config()
        algo_cfg.layers = config_by_layer

        self.check_storage_usage(dali_cache=False)
        self._logger.info(f"Using dataset with {algo_cfg.calibset_size} entries for {self._name}")

    def get_dataset_size(self):
        """Return the number of calibration entries used by the algorithm."""
        return self.get_algo_config().calibset_size

    def get_batch_size(self):
        """Return the batch size used when iterating the cached data."""
        return self.get_algo_config().batch_size

    def is_compressed_cache(self) -> bool:
        """Return True when cache compression is enabled in the algorithm config."""
        return self.get_algo_config().cache_compression == FeaturePolicy.enabled

    def export_statistics(self):
        """Return the per-layer ``<layer>/successfully_run`` flags."""
        return self._layers_info

    @staticmethod
    def is_correctable(acceleras_layer: BaseAccelerasLayer, fp_model) -> bool:
        """
        Check if a layer's bias can be corrected.

        Args:
            acceleras_layer: acceleras layer to check.
            fp_model: the full-precision model; the layer must also exist there (matched by ``full_name``).

        Returns:
            True if the layer is a ``BaseHailoLayer`` present in ``fp_model``, has exactly one bias op and
            exactly one output, and that bias op reports itself as correctable.
        """
        if not isinstance(acceleras_layer, BaseHailoLayer):
            return False
        if acceleras_layer.full_name not in fp_model.layers:
            return False
        bias_ops = list(acceleras_layer.get_bias_ops())
        if len(bias_ops) != 1:
            return False
        if acceleras_layer.num_outputs != 1:
            return False
        return bias_ops[0].is_correctable

    @classmethod
    def resolve_policy(
        cls,
        default_policy: BiasCorrectionPolicy,
        layer_policy: BiasCorrectionPolicy,
    ) -> FeaturePolicy:
        """
        Resolve the global policy together with a per-layer policy into the effective layer policy.

        A layer policy of ``allowed`` defers to ``default_policy``; any other layer policy wins.
        """
        if layer_policy == BiasCorrectionPolicy.allowed:
            policy = FeaturePolicy(default_policy.value)
        else:
            policy = FeaturePolicy(layer_policy.value)
        return policy

    def should_correct(
        self,
        layer,
    ) -> bool:
        """
        Check if a layer's bias should be corrected or not,
        based on layer support, layer configuration, and algo configuration.
        """
        cfg = self.get_algo_config()
        layer_cfg = cfg.layers.get(layer.full_name)

        if not self.is_correctable(layer, self._fp_model):
            return False

        if layer_cfg is None:
            layer_cfg = LayerBiasCorrectionConfig()

        policy = self.resolve_policy(cfg.policy, layer_cfg.policy)
        return policy == FeaturePolicy.enabled

    def get_blocks(self) -> Dict[str, ModelFlow]:
        """Split the model into blocks using the naive heuristic, keyed on correctable layers."""
        return naive_blocks(self._model, partial(self.is_correctable, fp_model=self._fp_model))

    def get_algo_config(self):
        """Return the bias correction section of the model optimization config."""
        return self._model_config.bias_correction

    def should_skip_algo(self):
        """Skip the algorithm when bias correction is not enabled anywhere."""
        return not self.has_bias_correction()

    def has_bias_correction(self):
        """
        Check if the model has bias correction enabled or any layer with bias correction enabled.
        """
        cfg = self.get_algo_config()
        default_policy = cfg.policy
        any_enabled = any(bias_config.policy == BiasCorrectionPolicy.enabled for bias_config in cfg.layers.values())
        return any_enabled or (default_policy == BiasCorrectionPolicy.enabled)

    def get_config_by_layer(self) -> Dict[str, LayerBiasCorrectionConfig]:
        """
        Get the LayerBiasCorrectionConfig of each correctable layer in the model, keyed by full layer name.

        Layers without an explicit configuration get ``LayerBiasCorrectionConfig.get_default()``.
        """
        layers_config = self.get_algo_config().layers
        cfg_by_layer = dict()
        for layer in self._model.layers.values():
            if not self.is_correctable(layer, self._fp_model):
                continue
            cfg_lname = self.get_layer_name_in_config(layer)
            cfg_by_layer[layer.full_name] = layers_config.get(cfg_lname, LayerBiasCorrectionConfig.get_default())
        return cfg_by_layer

    def _run_int(self):
        """Run the comparative block-by-block flow, correcting each block before it is quantized."""
        try:
            self._comperative_run(pre_quant_cb=self.core_logic)
        except KeyboardInterrupt:
            # Deliberately swallowed: corrections applied so far are kept, and only a warning is logged.
            self._logger.warning("Bias correction has been terminated by the user, proceed at your own peril")

    def core_logic(self, block_model: HailoModel, native_results, quant_input):
        """
        Correct the bias of the output layer of the block.
        This function also converts the block to lossy mode.
        """
        # The correction target is the (first) predecessor of the block's first output node.
        corrected_lname = block_model.flow.predecessors_sorted(block_model.flow.output_nodes[0])[0]
        corrected_layer = block_model.layers[corrected_lname]

        if not self.should_correct(corrected_layer):
            return

        self.correct_bias_from_cache(
            block_model,
            corrected_layer,
            native_results,
            quant_input,
            self.get_dataset_size(),
            self.get_batch_size(),
        )

        self._layers_info[f"{corrected_layer.full_name}/successfully_run"] = True

    @classmethod
    def correct_bias_from_cache(
        cls,
        block_model: HailoModel,
        target_layer: BaseAccelerasLayer,
        native_data: Dict[str, str],
        quant_data: Dict[str, str],
        data_count: int,
        batch_size: int,
    ):
        """
        Subtract the mean (quantized - native) output difference from ``target_layer``'s bias.

        Args:
            block_model: the block to run; it is fed the cached quantized inputs.
            target_layer: layer whose bias is updated in place (followed by ``enforce_internal_encoding``).
            native_data: cache description of the native block outputs.
            quant_data: cache description of the quantized block inputs.
            data_count: number of cached entries to read.
            batch_size: batch size used while iterating the cache.

        Raises:
            ValueError: if the cache yields no samples, so no mean difference can be computed.
        """
        block_outputs = block_model.flow.output_nodes
        block_inputs = block_model.flow.input_nodes

        native_data, result_shape = dataset_from_cache(block_outputs, native_data, data_count)
        quant_data, inputs_shape = dataset_from_cache(block_inputs, quant_data, data_count)

        axes = cls._get_bias_reduce_axes(target_layer)

        @tf.function
        def get_mean_diff(input_batch, result_batch):
            # Per-sample mean over all non-feature axes, summed over the batch (axis 0).
            native_result = tf_unpad_input(result_batch, result_shape)
            inp = tf_unpad_input(input_batch, inputs_shape)
            quant_result = block_model(inp)
            diff = quant_result - native_result
            mean_diff = tf.reduce_mean(diff, axis=axes)
            # Accumulate in float64 to limit precision loss over many batches.
            mean_diff = tf.cast(mean_diff, tf.float64)
            return tf.reduce_sum(mean_diff, axis=0)

        total = 0
        diff_sum = 0
        dataset = zip(quant_data.batch(batch_size), native_data.batch(batch_size))
        for quant_input_batch, native_result_batch in dataset:
            diff_sum += get_mean_diff(quant_input_batch, native_result_batch)
            total += quant_input_batch.shape[0]

        if total == 0:
            # Without this guard the division below would fail with an opaque ZeroDivisionError.
            raise ValueError(
                f"Bias correction of layer {target_layer.full_name} received no samples "
                f"(data_count={data_count}); cannot compute the mean output difference"
            )

        mean_diff_per_channel = tf.cast(diff_sum / total, tf.float32).numpy()
        current_bias = target_layer.bias.numpy()

        # The correction compares the native layer output with the quantized layer output, which has
        # `out_channels` channels. For HailoConvQuantWeightGroup the accumulator has
        # `out_channels * quantization_weight_groups` channels that are later summed per output channel,
        # so only the first `out_channels` bias entries are corrected; this applies the correction to
        # each output channel exactly once.
        if isinstance(target_layer, HailoConvQuantWeightGroup):
            num_channels = mean_diff_per_channel.shape[0]
            corrected_part = current_bias[:num_channels] - mean_diff_per_channel
            assigned_bias = np.concatenate((corrected_part, current_bias[num_channels:]))
        else:
            assigned_bias = current_bias - mean_diff_per_channel
        target_layer.bias.assign(assigned_bias)
        target_layer.enforce_internal_encoding()

    @staticmethod
    def _get_bias_reduce_axes(acceleras_layer: BaseHailoLayer):
        """
        Get the axes to reduce over, i.e. every non-batch output axis except the features axis.

        The features axis is the last one, or the second to last when ``transpose_width_features`` is set.
        """
        features_axis = -2 if (acceleras_layer.transpose_width_features) else -1
        # Non-batch axis indices 1..rank-1; np.delete removes the *position* `features_axis` from this list.
        ndim = np.arange(len(acceleras_layer.output_shape[1:])) + 1
        return np.delete(ndim, features_axis)

    def finalize_global_cfg(self, algo_config: GlobalBiasCorrectionConfig):
        """Fill unset sizes from the calibration config and validate them when the algorithm will run."""
        # TODO: is there a better way to get the dataset length?
        if algo_config.calibset_size is None:
            algo_config.calibset_size = self._model_config.calibration.calibset_size
        if algo_config.batch_size is None:
            algo_config.batch_size = self._model_config.calibration.batch_size
        if not self.should_skip_algo():
            self.check_dataset_length(algo_config, "calibset_size", self._dataset)
            self.check_batch_size(algo_config, "calibset_size", "batch_size")

    def _get_valid_layer_cfg(self, lname, cfg: LayerBiasCorrectionConfig) -> LayerBiasCorrectionConfig:
        """
        Return ``cfg`` for correctable layers, and an empty dict for non-correctable ones.

        NOTE(review): returning ``{}`` (not a LayerBiasCorrectionConfig) presumably tells the base class to
        drop the layer's config — confirm against the caller.
        """
        if not self.is_correctable(self._model.layers[lname], self._fp_model):
            cfg = {}
        return cfg
