"""
This module implements the AdaRound algorithm for an acceleras model
https://arxiv.org/abs/2004.10568
"""

import logging
from typing import Dict, List

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_conv import HailoConv
from hailo_model_optimization.acceleras.hailo_layers.hailo_conv_add import HailoConvAdd
from hailo_model_optimization.acceleras.hailo_layers.hailo_dense import HailoDense
from hailo_model_optimization.acceleras.hailo_layers.hailo_depthwise import HailoDepthwise
from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.model.hailo_model.model_flow import ModelFlow
from hailo_model_optimization.acceleras.model_optimization_config.mo_config import ModelOptimizationConfig
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerAdaRoundConfig
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_model import BlockRoundTrainingConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import FeaturePolicy, LayerFeaturePolicy
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasResourceError, AdaRoundError
from hailo_model_optimization.acceleras.utils.dataset_util import verify_dataset_size
from hailo_model_optimization.algorithms.ada_round.block_round_model import AdaRoundTrainModel
from hailo_model_optimization.algorithms.ada_round.loss_utils_v2 import AdaroundV2Callbacks, l2_loss
from hailo_model_optimization.algorithms.block_by_block.block_by_block_v2 import BlockByBlock
from hailo_model_optimization.algorithms.block_by_block.block_heuristic import block_communities, naive_blocks
from hailo_model_optimization.algorithms.dali_utils.data_feeder_tfrecord import DataFeederTFRecord


class BlockRoundTraining(BlockByBlock):
    """
    Block Round Training: adaptive rounding (AdaRound) applied to blocks of layers of an acceleras model.

    The model is partitioned into blocks: communities of layers, or single layers when the configured
    resolution is infinite. For every block that contains at least one trainable layer, the trainable
    layers are optimized with an L2 loss. The block is fed the quantized model's cached inputs, and its
    outputs are matched against the reference model's cached outputs.

    Args:
        model: Acceleras model to fit with BlockRoundTraining.
        fp_model: the reference Acceleras model (might be a native or a full precision model).
        model_config: model optimization configuration.
        dataset: TF dataset to train with.
        work_dir: path for files cache.
        logger_level: logging level, forwarded to the base class.
        logger: optional logger, forwarded to the base class.

    """

    # TODO: fix support for deconv layer
    TRAINABLE_LAYERS = {HailoConv, HailoDense, HailoConvAdd, HailoDepthwise}

    def __init__(
        self,
        model: HailoModel,
        fp_model: HailoModel,
        model_config: ModelOptimizationConfig,
        dataset: tf.data.Dataset,
        *args,
        work_dir: str = None,
        logger_level=logging.DEBUG,
        logger=None,
        **kwargs,
    ):
        super().__init__(
            model,
            fp_model,
            model_config,
            "Block Round Training",
            dataset,
            *args,
            work_dir=work_dir,
            logger_level=logger_level,
            logger=logger,
            **kwargs,
        )
        self.compiled_loss = None
        # Map each layer's full name to the name used for it in the model configuration.
        layers = {layer.full_name: self.get_layer_name_in_config(layer) for layer in self._model.layers.values()}
        self._config_by_layer = self.get_config_by_layer(self._model_config, layers)
        # Per-layer statistics; every layer starts as "not run" and is flipped to True in core_logic.
        self._layers_info = {f"{layer}/successfully_run": False for layer in self._model.flow.toposort()}
        # data_feeder is initialized in _setup, after should_skip_algo() is called
        self.data_feeder = None

    def _setup(self):
        """Validate the dataset size, create the data feeder and check storage usage."""
        cfg = self.get_algo_config()
        verify_dataset_size(self._dataset, cfg.dataset_size, warning_if_larger=True, logger=self._logger)
        # Initialize DataFeederTFRecord
        self.data_feeder = DataFeederTFRecord(
            device=cfg.device,
            device_id=cfg.device_id,
            use_dali=cfg.use_dali,
            num_threads=cfg.num_threads,
            compression_type=self.get_compression_type(),
            logger=self._logger,
        )
        self.check_storage_usage(dali_cache=self.data_feeder.use_dali)
        self._logger.info(f"Using dataset with {cfg.dataset_size} entries for {self._name}")

    def export_statistics(self):
        """Return the per-layer statistics dict (the internal dict itself, not a copy)."""
        return self._layers_info

    def get_algo_config(self):
        """Return current algorithm configuration."""
        return self._model_config.block_round_training

    def get_dataset_size(self) -> int:
        """Return the configured number of dataset entries used for training."""
        return self.get_algo_config().dataset_size

    def get_batch_size(self) -> int:
        """Return the configured training batch size."""
        return self.get_algo_config().batch_size

    def get_compression_type(self) -> str:
        """Return the cache compression type: "" for "none", otherwise the upper-cased config value."""
        compression_type_value = self.get_algo_config().compression_type.value
        compression_type = "" if compression_type_value == "none" else compression_type_value.upper()
        return compression_type

    def get_eager_mode(self) -> bool:
        """Enable or disable eager mode."""
        return self.get_algo_config().run_eagerly == FeaturePolicy.enabled

    def get_internal_encoding(self) -> bool:
        """Enable or disable internal encodings between ops and layers."""
        return self.get_algo_config().internal_encoding == FeaturePolicy.enabled

    def get_train_bias(self) -> bool:
        """Enable or disable training for the bias weights."""
        return self.get_algo_config().train_bias == FeaturePolicy.enabled

    def should_skip_algo(self):
        """
        Check whether the algorithm should be skipped.

        Return:
            True if no layer of the model should be trained, False if at least one layer should be.
        """
        return not any(self.should_train(layer) for layer in self._model.layers.values())

    def get_blocks(self) -> Dict[str, ModelFlow]:
        """
        Partition the model into training blocks.

        With a finite resolution, layers are grouped into communities (BlockRound). With an infinite
        resolution, every block holds a single layer that passes should_train (plain AdaRound).
        """
        resolution = self.get_algo_config().resolution
        if resolution < np.inf:
            # Run BlockRound
            return block_communities(self._model, resolution=resolution)
        else:
            # Run AdaRound (blocks are single layers)
            self._logger.info(f"Running AdaRound instead of BlockRound ({resolution = })")

            def filter_func(layer):
                return self.should_train(layer)

            return naive_blocks(self._model, filter_func)

    def _run_int(self):
        """Run the comparative block-by-block flow, training each block before it is quantized."""
        self._comparative_run(pre_quant_cb=self.core_logic)

    def core_logic(
        self,
        block_model: HailoModel,
        native_results: Dict[str, str],
        quant_results: Dict[str, str],
    ):
        """
        Per-block callback of the comparative run.

        Trains every layer in the block that passes should_train. A block without trainable layers
        is left untouched. After training, each trained layer is marked as successfully run, and the
        names of the block's layers are recorded in the statistics.
        """
        trained_layers = [lname for lname, layer in block_model.layers.items() if self.should_train(layer)]
        if not trained_layers:
            return

        self.train_community(block_model, trained_layers, native_results, quant_results)

        for target_layer in trained_layers:
            self._layers_info[f"{target_layer}/successfully_run"] = True
            self._layers_info[f"{target_layer}/block_layers"] = list(block_model.flow.nodes)

    def train_community(self, block_model: HailoModel, trained_layers: List[str], native_data, quant_data, verbose=1):
        """
        Train the given layers of a block (community) using the adaround algorithm.

        The block is fed the quantized model's cached results at its input nodes. It is trained with
        an L2 loss against the reference model's cached results at its output nodes.

        Args:
            block_model: the block to train.
            trained_layers: names of the layers in the block that are trained.
            native_data: cached results of the reference model (training targets).
            quant_data: cached results of the quantized model (training inputs).
            verbose: currently unused; kept for backward compatibility.

        Raises:
            AccelerasResourceError: if TF runs out of memory during training.
        """
        cfg = self.get_algo_config()

        # Initialize Dataset
        inputs_cache_list = self.data_feeder.read_cache_list(block_model.flow.input_nodes, quant_data)
        outputs_cache_list = self.data_feeder.read_cache_list(block_model.flow.output_nodes, native_data)

        total_train_iters = self._epochs_to_iterations(cfg.epochs)
        # cfg.warmup is used as a fraction of the total epochs here
        warmup_epochs = cfg.epochs * cfg.warmup
        warmup_iters = self._epochs_to_iterations(warmup_epochs)

        dataset = self.data_feeder.cache_to_dataset(
            input_nodes=block_model.flow.input_nodes,
            input_interlayer_results=inputs_cache_list,
            output_nodes=block_model.flow.output_nodes,
            output_interlayer_results=outputs_cache_list,
            batch_size=self.get_batch_size(),
        )
        dataset = dataset.cache()
        dataset = dataset.repeat(cfg.epochs)
        dataset = dataset.prefetch(tf.data.AUTOTUNE)

        # The Keras model
        train_model = AdaRoundTrainModel(block_model, trained_layers, self.get_train_bias(), self._logger)
        optimizer = tf.keras.optimizers.Adam(learning_rate=cfg.learning_rate)
        train_model.compile(
            optimizer=optimizer,
            loss=l2_loss,
            run_eagerly=self.get_eager_mode(),
        )

        # Callbacks
        callbacks = [
            AdaroundV2Callbacks(total_train_iters, warmup_iters, log_samples=cfg.log_samples),
        ]
        try:
            # All iterations run as a single Keras "epoch" (steps_per_epoch=total_train_iters).
            # AdaRoundTrainModel.fit takes extra positional schedule arguments (b_range, decay_start,
            # warmup, weight) — presumably the AdaRound regularization schedule; confirm in block_round_model.
            history = train_model.fit(
                dataset,
                total_train_iters,
                cfg.b_range,
                cfg.decay_start,
                cfg.warmup,
                cfg.weight,
                steps_per_epoch=total_train_iters,
                callbacks=callbacks,
                verbose=0,
            )
        except tf.errors.ResourceExhaustedError as err:
            self._logger.debug(err)
            raise AccelerasResourceError(
                f"GPU memory has been exhausted. Please try to use {self._name} with "
                f"lower batch size or run on CPU.",
            ) from err

        self._save_logged_samples(block_model.flow.output_nodes, history.history)
        block_model.set_lossy()

    def _save_logged_samples(self, target_layers, logged_samples):
        """Store every logged training history series under each of the given layer names."""
        for sample_name, sample_data in logged_samples.items():
            for target_layer in target_layers:
                self._layers_info[f"{target_layer}/history/{sample_name}"] = sample_data

    def _get_kernel_lossy_element(self, layer: BaseHailoLayer):
        """
        Return the single kernel lossy element of a layer.

        Raises:
            AdaRoundError: if the layer's atomic ops expose zero or more than one kernel lossy element.
        """
        elements = []
        for op in layer.atomic_ops:
            element = getattr(op.weight_lossy_elements, "kernel", None)
            if element:
                elements.append(element)
        if len(elements) != 1:
            # Report the count; the loop variable may be None or unbound when no ops exist.
            raise AdaRoundError(f"Trying to modify layer {layer.full_name} with {len(elements)} kernels")
        return elements[0]

    @classmethod
    def get_config_by_layer(
        cls,
        model_config,
        layers,
    ) -> Dict[str, LayerAdaRoundConfig]:
        """
        Get the per layer configurations of adaround.

        NOTE(review): this currently always returns an empty dict, so should_train falls back to a
        default LayerAdaRoundConfig() for every layer; ``model_config`` and ``layers`` are unused.
        """
        if not isinstance(layers, dict):
            layers = {layer: layer for layer in layers}
        config_by_layer = dict()
        return config_by_layer

    @classmethod
    def is_trainable(cls, layer) -> bool:
        """Check if given layer is trainable or not"""
        return type(layer) in cls.TRAINABLE_LAYERS and layer.trainable  # exact type, excluding sub-types

    @classmethod
    def resolve_policy(
        cls,
        default_policy: FeaturePolicy,
        layer_policy: LayerFeaturePolicy,
    ) -> FeaturePolicy:
        """Resolve the feature policy with per layer policy to get the actual layer policy"""
        if layer_policy == LayerFeaturePolicy.allowed:
            policy = default_policy
        else:
            policy = FeaturePolicy(layer_policy.value)
        return policy

    def should_train(self, layer) -> bool:
        """
        Check if a layer should be trained or not,
        based on layer support, layer configuration, and algo configuration
        """
        if not self.is_trainable(layer):
            return False

        cfg = self.get_algo_config()
        layer_cfg = self._config_by_layer.get(layer.full_name, LayerAdaRoundConfig())
        policy = self.resolve_policy(cfg.policy, layer_cfg.policy)
        return policy == FeaturePolicy.enabled

    def _epochs_to_iterations(self, epochs):
        """
        Convert a number of epochs to a number of training iterations (batches).

        The result is floor(epochs * dataset_size / batch_size); it is a float when ``epochs`` is a float.
        """
        config = self.get_algo_config()
        return epochs * config.dataset_size // config.batch_size

    def finalize_global_cfg(self, algo_config: BlockRoundTrainingConfig):
        """Validate dataset length and batch size against the config, unless the algorithm is skipped."""
        # Is there a better way to get dataset length?
        if not self.should_skip_algo():
            self.check_dataset_length(algo_config, "dataset_size", self._dataset, warning_if_larger=True)
            self.check_batch_size(algo_config, "dataset_size", "batch_size")

    def _get_valid_layer_cfg(self, lname, cfg):
        """Return ``cfg`` for trainable layers and an empty dict for all others."""
        if not self.is_trainable(self._model.layers[lname]):
            cfg = {}
        return cfg
