"""
This module implements the AdaRound algorithm for an acceleras model
https://arxiv.org/abs/2004.10568
"""

import logging
import shutil
from dataclasses import dataclass
from typing import Dict, Optional, Tuple

import tensorflow as tf

from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_conv import HailoConv
from hailo_model_optimization.acceleras.hailo_layers.hailo_conv_add import HailoConvAdd
from hailo_model_optimization.acceleras.hailo_layers.hailo_conv_quant_weight_group import HailoConvQuantWeightGroup
from hailo_model_optimization.acceleras.hailo_layers.hailo_dense import HailoDense
from hailo_model_optimization.acceleras.hailo_layers.hailo_depthwise import HailoDepthwise
from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.model.hailo_model.model_flow import ModelFlow
from hailo_model_optimization.acceleras.model_optimization_config.mo_config import ModelOptimizationConfig
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerAdaRoundConfig
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_model import AdaRoundConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    AdaRoundMode,
    FeaturePolicy,
    LayerFeaturePolicy,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasResourceError, AdaRoundError
from hailo_model_optimization.algorithms.ada_round.adaround_training_model import AdaRoundTrainModel
from hailo_model_optimization.algorithms.ada_round.loss_utils import AdaroundV2Callbacks, l2_loss
from hailo_model_optimization.algorithms.bias_correction.bias_correction_v2 import BiasCorrection
from hailo_model_optimization.algorithms.block_by_block.block_by_block import BlockByBlock
from hailo_model_optimization.algorithms.block_by_block.block_heuristic import naive_blocks
from hailo_model_optimization.algorithms.block_by_block.cache_utils import get_cache_list, get_layer_cache_dir
from hailo_model_optimization.algorithms.dali_utils import dali_train_dataset


@dataclass
class ResolvedLayerTrainConfig:
    """
    Training parameters of a single layer after resolution.

    Instances are built by ``AdaRound.get_resolved_layer_train_config``: every field is read
    from the layer configuration, and when the layer value is ``None`` the value of the
    same-named field of the algorithm configuration is used instead.
    """

    # Forwarded to AdaRoundTrainModel when the layer's train model is created
    train_bias: bool
    # Number of training epochs; converted to an iteration count before fitting
    epochs: int
    # Multiplied by `epochs` to obtain the number of warm-up epochs, and also forwarded to fit
    warmup: float
    # Forwarded to the train model's fit; presumably the regularization weight - TODO confirm
    weight: float
    # Forwarded to the train model's fit; presumably the (start, end) range of beta - TODO confirm
    b_range: Tuple[int, int]
    # Forwarded to the train model's fit
    decay_start: float
    # Number of entries requested from the cached data when the layer's train dataset is built
    dataset_size: int
    # Recorded in the layer statistics
    # NOTE(review): train_layer batches the dataset with the algorithm-level batch size - confirm
    batch_size: int


class AdaRound(BlockByBlock):
    """
    This class implements adaround algorithm for an acceleras model

    Args:
        model: Acceleras model to apply the algo to
        config_params: Model config
        dataset: data for the layers' train
        work_dir: path for files cache

    """

    # TODO: fix support for deconv layer
    TRAINABLE_LAYERS = {HailoConv, HailoDense, HailoConvAdd, HailoDepthwise, HailoConvQuantWeightGroup}

    def __init__(
        self,
        model: HailoModel,
        fp_model: HailoModel,
        model_config: ModelOptimizationConfig,
        dataset: tf.data.Dataset,
        *args,
        is_disable_internal_encoding: bool = True,
        work_dir: Optional[str] = None,
        logger_level=logging.DEBUG,
        logger=None,
        **kwargs,
    ):
        """
        Initialise the block-by-block base class under the name "Adaround" and
        create the per-layer statistics with every layer marked as not run yet.

        All arguments are forwarded unchanged to the base class constructor.
        """
        # TODO: add temperature, maybe control gamma and zeta, learning config
        base_kwargs = dict(
            kwargs,
            is_disable_internal_encoding=is_disable_internal_encoding,
            work_dir=work_dir,
            logger_level=logger_level,
            logger=logger,
        )
        super().__init__(model, fp_model, model_config, "Adaround", dataset, *args, **base_kwargs)
        self.compiled_loss = None
        # One "<layer>/successfully_run" flag per layer, in topological order
        run_flags = (f"{layer}/successfully_run" for layer in self._model.flow.toposort())
        self._layers_info = dict.fromkeys(run_flags, False)

    def _setup(self):
        layers = {layer.full_name: self.get_layer_name_in_config(layer) for layer in self._model.layers.values()}
        algo_cfg = self.get_algo_config()
        config_by_layer = self.get_config_by_layer(algo_cfg, layers)
        algo_cfg.layers = config_by_layer
        self.check_storage_usage(dali_cache=True)
        self._logger.info(f"Using dataset with {algo_cfg.dataset_size} entries for {self._name}")
        self._logger.info(f"Using dataset with {algo_cfg.bias_correction_count} entries for bias correction")

    def export_statistics(self):
        return self._layers_info

    def get_dataset_size(self):
        return self.get_algo_config().dataset_size

    def get_batch_size(self):
        return self.get_algo_config().batch_size

    def is_compressed_cache(self) -> bool:
        """Return True when the cache compression policy of the algorithm is enabled"""
        compression_policy = self.get_algo_config().cache_compression
        return compression_policy == FeaturePolicy.enabled

    def get_algo_config(self):
        """Return the current algorithm configuration"""
        return self._model_config.adaround

    def should_skip_algo(self):
        """
        Checks if the algorithm is enabled to any layer
        Return:
            boolean value that indicates if adaround is applied to at least one layer
        """
        for layer in self._model.layers.values():
            should_train = self.should_train(layer)
            if should_train:
                return False
        return True

    def get_blocks(self) -> Dict[str, ModelFlow]:
        """
        Split the model into blocks with the naive heuristic, selecting the layers
        that should be trained or whose bias can be corrected.
        """

        def is_handled_layer(candidate):
            # The bias-correction check is only evaluated for layers that are not trained
            if self.should_train(candidate):
                return True
            return BiasCorrection.is_correctable(candidate, self._fp_model)

        return naive_blocks(self._model, is_handled_layer)

    def _run_int(self):
        try:
            self._comperative_run(pre_quant_cb=self.core_logic)
        except KeyboardInterrupt:
            self._logger.warning("Adaround has been terminated by the user, proceed at your own peril")

    def core_logic(
        self,
        block_model: HailoModel,
        native_results: Dict[str, str],
        quant_results: Dict[str, str],
    ):
        """
        Per-block step of the adaround algorithm.

        The target layer is the first sorted predecessor of the block's first output node:
        - when it should be trained, adaptive rounding is applied to it
        - otherwise, when bias training is configured and its bias is correctable,
          its bias is corrected from the cached results
        - otherwise nothing is done and no statistics are recorded
        """
        algo_cfg = self.get_algo_config()
        first_output = block_model.flow.output_nodes[0]
        target_lname = block_model.flow.predecessors_sorted(first_output)[0]
        target_layer = block_model.layers[target_lname]

        # Both checks are evaluated up front; at most one action is taken below
        trainable = self.should_train(target_layer)
        correctable = BiasCorrection.is_correctable(target_layer, self._fp_model)
        if not trainable and not (algo_cfg.train_bias and correctable):
            return

        if trainable:
            self.train_layer(block_model, target_lname, native_results, quant_results)
            train_type = "adaround"
        else:
            BiasCorrection.correct_bias_from_cache(
                block_model,
                target_layer,
                native_results,
                quant_results,
                algo_cfg.bias_correction_count,
                algo_cfg.batch_size,
            )
            train_type = "correct_bias"

        stats = self._layers_info
        stats[f"{target_lname}/train_type"] = train_type
        stats[f"{target_lname}/successfully_run"] = True
        stats[f"{target_lname}/block_layers"] = list(block_model.flow.nodes)

    def train_layer(self, block_model: HailoModel, target_layer: str, native_data, quant_data, verbose=1):
        """
        Train a single layer using the adaround algorithm

        Args:
            block_model: model of the block that contains the layer
            target_layer: name of the layer to train
            native_data: cached native results, used to build the training targets
            quant_data: cached quantized results, used to build the training inputs
            verbose: currently unused, the fit is always called with verbose=0

        Raises:
            AccelerasResourceError: when the training exhausts the GPU memory

        The layer's dali cache directory is removed when the method exits,
        whether the training succeeded or raised.
        """
        inputs_cache_list = get_cache_list(block_model.flow.input_nodes, quant_data)
        outputs_cache_list = get_cache_list(block_model.flow.output_nodes, native_data)
        config = self.get_algo_config()
        # TODO [optional]: move the layer config resolution to get_config_by_layer logic
        resolved_cfg = self.get_resolved_layer_train_config(config.layers[target_layer])
        train_model = AdaRoundTrainModel(block_model, [target_layer], resolved_cfg.train_bias, self._logger)
        optimizer = tf.keras.optimizers.Adam(learning_rate=config.learning_rate)
        train_model.compile(
            optimizer=optimizer,
            loss=l2_loss,
            metrics=[l2_loss],
            run_eagerly=config.eager,
        )
        train_cache = self.create_cache_dir("train")
        dali_cache = get_layer_cache_dir(train_cache, target_layer)
        try:
            # NOTE(review): the dataset is batched with the algorithm-level batch size and the
            # iteration counts below use the algorithm-level dataset/batch sizes, while
            # resolved_cfg carries per-layer values for both - confirm this is intended
            dataset, x_shape, y_shape = dali_train_dataset(
                inputs_cache_list,
                outputs_cache_list,
                batch_size=config.batch_size,
                count=resolved_cfg.dataset_size,
                base_dir=dali_cache,
                shuffle=config.shuffle,
                seed=config.seed,
            )

            total_train_iters = self._epochs_to_iterations(resolved_cfg.epochs)
            warmup_epochs = resolved_cfg.epochs * resolved_cfg.warmup
            warmup_iters = self._epochs_to_iterations(warmup_epochs)
            callbacks = AdaroundV2Callbacks(total_train_iters, warmup_iters, log_samples=config.log_samples)
            try:
                train_model.fit(
                    dataset,
                    total_train_iters,
                    resolved_cfg.b_range,
                    resolved_cfg.decay_start,
                    resolved_cfg.warmup,
                    resolved_cfg.weight,
                    x_shape,
                    y_shape,
                    callbacks=callbacks,
                    verbose=0,
                )
            except tf.errors.ResourceExhaustedError as err_msg:
                self._logger.debug(err_msg)
                raise AccelerasResourceError(
                    f"GPU memory has been exhausted. Please try to use {self._name} with "
                    f"lower batch size or run on CPU.",
                ) from err_msg
            except KeyboardInterrupt:
                # An interrupted training keeps whatever was trained so far
                self._logger.warning(
                    f"Training of layer {target_layer} has been terminated by the user, proceed at your own peril",
                )
            self._save_logged_samples(target_layer, callbacks.history, resolved_cfg)
            block_model.set_lossy()
        finally:
            # The train cache of the layer must not be left on disk when the training fails
            shutil.rmtree(dali_cache, ignore_errors=True)

    @staticmethod
    def _get_layer_cfg_field(layer_config, field, algo_config):
        value = getattr(layer_config, field)
        if value is not None:
            return value
        value = getattr(algo_config, field)
        return value

    def _save_logged_samples(self, target_layer, logged_samples, layer_cfg: ResolvedLayerTrainConfig):
        for sample_name, sample_data in logged_samples.items():
            self._layers_info[f"{target_layer}/history/{sample_name}"] = sample_data
        self._layers_info[f"{target_layer}/history/epochs"] = layer_cfg.epochs
        self._layers_info[f"{target_layer}/history/dataset_size"] = layer_cfg.dataset_size
        self._layers_info[f"{target_layer}/history/batch_size"] = layer_cfg.batch_size
        self._layers_info[f"{target_layer}/history/b_range"] = layer_cfg.b_range

    def _get_kernel_lossy_element(self, layer: BaseHailoLayer):
        elements = []
        for op in layer.atomic_ops:
            element = getattr(op.weight_lossy_elements, "kernel", None)
            if element:
                elements.append(element)
        if len(elements) != 1:
            raise AdaRoundError(f"Trying to modify layer with {element} kernels")
        return elements[0]

    @classmethod
    def get_config_by_layer(
        cls,
        algo_config,
        layers,
    ) -> Dict[str, LayerAdaRoundConfig]:
        """
        Get the per layer configurations of adaround
        """
        if not isinstance(layers, dict):
            layers = {layer: layer for layer in layers}
        config_by_layer = dict()
        for model_lname, config_lname in layers.items():
            config_by_layer[model_lname] = algo_config.layers.get(config_lname, LayerAdaRoundConfig.get_default())
        return config_by_layer

    @classmethod
    def is_trainable(cls, layer) -> bool:
        """Check if given layer is trainable or not"""
        return type(layer) in cls.TRAINABLE_LAYERS and layer.trainable  # exact type, excluding sub-types

    @classmethod
    def resolve_policy(
        cls,
        layer,
        default_policy: FeaturePolicy,
        layer_policy: LayerFeaturePolicy,
        mode: AdaRoundMode,
    ) -> FeaturePolicy:
        """
        Resolve the feature policy with per layer policy to get the actual layer policy

        An "allowed" layer policy defers to the default policy when the layer has a
        lossy element suitable for `mode`, and is disabled otherwise.
        Any other layer policy is converted to the feature policy of the same value.
        """
        if layer_policy == LayerFeaturePolicy.allowed:
            suitable = cls.has_suitable_lossy_element(layer, mode)
            return default_policy if suitable else FeaturePolicy.disabled
        return FeaturePolicy(layer_policy.value)

    def get_resolved_layer_train_config(self, layer_config, algo_config=None) -> ResolvedLayerTrainConfig:
        """
        Resolve every layer configuration key except "policy" against the algo
        configuration (the current one when `algo_config` is not given).
        """
        if algo_config is None:
            algo_config = self.get_algo_config()
        field_names = LayerAdaRoundConfig.keys() - {"policy"}
        resolved = {name: self._get_layer_cfg_field(layer_config, name, algo_config) for name in field_names}
        return ResolvedLayerTrainConfig(**resolved)

    @classmethod
    def has_suitable_lossy_element(cls, layer, adaround_mode: AdaRoundMode) -> bool:
        """
        Check if the layer has a suitable lossy element based on given adaround mode
        """
        for lossy_ele in layer.get_weight_lossy_elements():
            kernel_quant_elem = getattr(lossy_ele, "kernel", None)
            if kernel_quant_elem is None:
                continue
            kernel_bits = kernel_quant_elem.bits
            if kernel_bits is None:
                continue
            if not kernel_quant_elem.signed:
                continue
            if adaround_mode == AdaRoundMode.train_4bit and kernel_bits == 4:
                return True
            if adaround_mode == AdaRoundMode.train_all:
                return True
        return False

    def should_train(self, layer) -> bool:
        """
        Check if a layer should be trained or not:
        the layer must be trainable, and its policy, resolved from the layer
        configuration and the algo configuration, must be enabled
        """
        algo_cfg = self.get_algo_config()
        layer_cfg = algo_cfg.layers.get(layer.full_name, LayerAdaRoundConfig())
        if self.is_trainable(layer):
            resolved = self.resolve_policy(layer, algo_cfg.policy, layer_cfg.policy, algo_cfg.mode)
            return resolved == FeaturePolicy.enabled
        return False

    def _epochs_to_iterations(self, epochs):
        config = self.get_algo_config()
        return epochs * config.dataset_size // config.batch_size

    def finalize_global_cfg(self, algo_config: AdaRoundConfig):
        # Is there an better way to get dataset length?
        if not self.should_skip_algo():
            self.check_dataset_length(algo_config, "dataset_size", self._dataset, warning_if_larger=True)
            self.check_batch_size(algo_config, "dataset_size", "batch_size")

    def _get_valid_layer_cfg(self, lname, cfg):
        if not self.is_trainable(self._model.layers[lname]):
            cfg = {}
        return cfg
