"""
This module implements the AdaRound algorithm for an acceleras model
https://arxiv.org/abs/2004.10568
"""

import logging
import os
import shutil
from functools import partial
from typing import Dict

import numpy as np
import tensorflow as tf
from verboselogs import VERBOSE

from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_conv import HailoConv
from hailo_model_optimization.acceleras.hailo_layers.hailo_conv_add import HailoConvAdd
from hailo_model_optimization.acceleras.hailo_layers.hailo_dense import HailoDense
from hailo_model_optimization.acceleras.hailo_layers.hailo_depthwise import HailoDepthwise
from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.model_optimization_config.mo_config import ModelOptimizationConfig
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerAdaRoundConfig
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_model import AdaRoundConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    AdaRoundMode,
    FeaturePolicy,
    LayerFeaturePolicy,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import (
    AccelerasImplementationError,
    AccelerasResourceError,
    AdaRoundError,
)
from hailo_model_optimization.algorithms.ada_round.layer_training_model import AdaRoundTrainModel
from hailo_model_optimization.algorithms.ada_round.loss_utils import AdaroundV1Callbacks, l2_loss
from hailo_model_optimization.algorithms.bias_correction.bias_correction import BiasCorrection
from hailo_model_optimization.algorithms.dali_utils import dali_train_dataset
from hailo_model_optimization.algorithms.dali_utils.mock_dali_dataset import cache_list_to_dataset
from hailo_model_optimization.algorithms.layer_by_layer import LayerByLayer


class AdaRound(LayerByLayer):
    """
    This class implements adaround algorithm for an acceleras model

    Layers are processed one by one (see ``LayerByLayer``). For every layer a
    lossless reference output is collected, and then one of the following happens:

    * Trainable layers (see ``should_train``) are trained with ``AdaRoundTrainModel``.
    * Otherwise, when ``train_bias`` is set and the layer is bias-correctable, the
      bias is corrected.
    * Otherwise, the layer is left untouched.

    Args:
        model: Acceleras model to apply the algo to
        config_params: Model config
        dataset: data for the layers' train
        work_dir: path for files cache

    """

    # TODO: fix support for deconv layer
    # Exact layer types that may be trained (sub-types are excluded, see is_trainable).
    TRAINABLE_LAYERS = {HailoConv, HailoDense, HailoConvAdd, HailoDepthwise}

    def __init__(
        self,
        model: HailoModel,
        config_params: ModelOptimizationConfig,
        dataset: tf.data.Dataset,
        *args,
        work_dir=None,
        logger_level=logging.DEBUG,
        **kwargs,
    ):
        # TODO: add temperature, maybe control gamma and zeta, learning config
        work_dir = work_dir if work_dir is not None else "adaround"
        super().__init__(
            model,
            config_params,
            "Adaround",
            dataset,
            *args,
            work_dir=work_dir,
            logger_level=logger_level,
            **kwargs,
        )
        self.compiled_loss = None
        self.optimizer = None
        # Map every model layer name to the name used for it in the config, then
        # replace the algo config's per-layer entries with one entry per model layer
        # (layers without an explicit entry get LayerAdaRoundConfig.get_default()).
        layers = {layer.full_name: self.get_layer_name_in_config(layer) for layer in self._model.layers.values()}
        algo_cfg = self.get_algo_config()
        config_by_layer = self.get_config_by_layer(algo_cfg, layers)
        algo_cfg.layers = config_by_layer
        # layer name -> output cache inferred after the layer was processed;
        # reset to an empty dict in _lbl_setup_logic.
        self._lossy_interlayer_results = None
        # Statistics returned by export_statistics: whether each layer was modified.
        self._layers_info = {f"{layer}/successfully_run": False for layer in self._model.flow.toposort()}

    def _setup(self):
        """Run the base setup and log the number of bias-correction entries."""
        super()._setup()
        config = self.get_algo_config()
        self._logger.info(f"Using dataset with {config.bias_correction_count} entries for bias correction")

    def export_statistics(self):
        """Return the per-layer ``<layer>/successfully_run`` flags."""
        return self._layers_info

    @property
    def warning_if_larger_dataset(self) -> bool:
        return True

    def get_dataset_size(self):
        """Return the dataset size configured for the algorithm."""
        cfg = self.get_algo_config()
        return cfg.dataset_size

    def train_layer(self, layer, inputs_cache_list, output_dir, verbose=1):
        """
        Train a single layer using the adaround algorithm

        Args:
            layer: acceleras layer to train (``config.train_bias`` is forwarded to
                ``AdaRoundTrainModel``).
            inputs_cache_list: cache files with the layer inputs used for training.
            output_dir: directory containing the reference output caches
                (one per layer output).
            verbose: verbosity forwarded to ``AdaroundV1Callbacks``.

        Raises:
            AccelerasResourceError: if TensorFlow runs out of (GPU) memory while fitting.
        """
        config = self.get_algo_config()
        train_model = AdaRoundTrainModel(layer, config.train_bias, self._logger)
        train_model.compile(
            optimizer="adam",
            loss=l2_loss,
            metrics=[l2_loss],
            run_eagerly=config.eager,
        )
        outputs_cache_list = [os.path.join(output_dir, self.get_cache_basename(i)) for i in range(layer.num_outputs)]
        dali_cache = self._get_cache_dir(layer.full_name)
        try:
            dataset, x_shape, y_shape = dali_train_dataset(
                inputs_cache_list,
                outputs_cache_list,
                batch_size=config.batch_size,
                count=config.dataset_size,
                base_dir=dali_cache,
                shuffle=config.shuffle,
                seed=config.seed,
            )
            total_train_iters = self._epochs_to_iterations(config.epochs)
            # config.warmup scales the number of epochs used for warmup
            warmup_epochs = config.epochs * config.warmup
            warmup_iters = self._epochs_to_iterations(warmup_epochs)
            callbacks = AdaroundV1Callbacks(total_train_iters, warmup_iters, verbose)
            try:
                train_model.fit(
                    dataset,
                    total_train_iters,
                    config.b_range,
                    config.decay_start,
                    config.warmup,
                    config.weight,
                    x_shape,
                    y_shape,
                    callbacks=callbacks,
                    verbose=0,
                )
            except tf.errors.ResourceExhaustedError as err_msg:
                self._logger.debug(err_msg)
                raise AccelerasResourceError(
                    f"GPU memory has been exhausted. Please try to use {self._name} with "
                    f"lower batch size or run on CPU.",
                ) from err_msg
        finally:
            # Always remove the temporary DALI cache, also when training failed,
            # so failed runs do not leave stale files in the work dir.
            shutil.rmtree(dali_cache, ignore_errors=True)

    def _lbl_post_layer_logic(
        self,
        acceleras_layer,
        curr_inputs_cache_list,
        curr_outputs_parent_cache,
        inferred_layers,
    ):
        """
        Trigger adaround or bias correction after a reference has been collected.

        The layer is trained or bias-corrected using the inputs produced by the
        already-processed (lossy) predecessors. It is then re-inferred and the
        result is stored in ``self._lossy_interlayer_results`` for the following
        layers. Finally the layer's lossy elements are re-enabled.
        """
        lname = acceleras_layer.full_name
        curr_inputs_cache_list_lossy = self._get_layer_inputs_cache_list(
            lname,
            self._lossy_interlayer_results,
        )
        config = self.get_algo_config()
        layer_cfg = config.layers.get(acceleras_layer.full_name)
        changed = True
        if self.should_train(acceleras_layer, config, layer_cfg):
            self._logger.debug(f"Training {acceleras_layer.full_name}")
            self.train_layer(acceleras_layer, curr_inputs_cache_list_lossy, curr_outputs_parent_cache)
        elif config.train_bias and BiasCorrection.is_correctable(acceleras_layer):
            self._logger.debug(f"Correcting bias {acceleras_layer.full_name}")
            self.correct_bias(acceleras_layer, curr_outputs_parent_cache)
        else:
            self._logger.debug(f"Skipping {acceleras_layer.full_name}")
            changed = False
        if changed:
            # Infer the modified layer with only its kernel quantization enabled
            acceleras_layer.disable_lossy()
            kernel_quant_elem = self._get_kernel_lossy_element(acceleras_layer)
            kernel_quant_elem.enable()
            acceleras_layer.enforce_internal_encoding()
            self._layers_info[f"{lname}/successfully_run"] = True
        curr_outputs_parent_cache_lossy = self._infer_layer(acceleras_layer, curr_inputs_cache_list_lossy)
        self._lossy_interlayer_results[lname] = curr_outputs_parent_cache_lossy
        if self._logger.isEnabledFor(VERBOSE):
            self.check_l2_loss_post_train(acceleras_layer, curr_outputs_parent_cache, curr_outputs_parent_cache_lossy)

        acceleras_layer.enable_lossy()
        acceleras_layer.enforce_internal_encoding()
        self._clean_results(self._lossy_interlayer_results, inferred_layers)

    def check_l2_loss_post_train(self, acceleras_layer, curr_outputs_parent_cache, curr_outputs_parent_cache_lossy):
        """Log (at VERBOSE level) the mean l2 loss between the lossy and reference outputs."""
        config = self.get_algo_config()
        dataset, _ = cache_list_to_dataset(
            [curr_outputs_parent_cache_lossy, curr_outputs_parent_cache],
            config.dataset_size,
        )
        l2_loss_dataset = dataset.map(lambda x: l2_loss(x[0], x[1]))
        l2_loss_data = np.array([i.numpy() for i in l2_loss_dataset])
        l2_loss_result = l2_loss_data.mean()
        self._logger.verbose(f"{acceleras_layer.full_name} l2_loss {l2_loss_result:.05f}")

    def correct_bias(self, acceleras_layer, output_reference_dir):
        """
        Correct the layer bias by the mean output error of the kernel-quantized layer.

        The layer is inferred with only its kernel quantization enabled on
        ``config.bias_correction_count`` entries. The error against the reference
        output is reduced over the axes given by ``BiasCorrection.get_bias_reduce_axes``
        and averaged over the entries. The result is then subtracted from the bias.

        Raises:
            AccelerasImplementationError: if the layer has more than one output.
        """
        if acceleras_layer.num_outputs != 1:
            raise AccelerasImplementationError("Correcting bias with multiple outputs is not defined")
        inputs_cache_list_lossy = self._get_layer_inputs_cache_list(
            acceleras_layer.full_name,
            self._lossy_interlayer_results,
        )
        config = self.get_algo_config()
        acceleras_layer.disable_lossy()
        kernel_quant_elem = self._get_kernel_lossy_element(acceleras_layer)
        kernel_quant_elem.enable()
        acceleras_layer.enforce_internal_encoding()
        count = config.bias_correction_count
        output_pre_correction_dir = self._infer_layer(
            acceleras_layer,
            inputs_cache_list_lossy,
            count=count,
        )
        quant_cache = os.path.join(output_pre_correction_dir, self.get_cache_basename(0))
        ref_cache = os.path.join(output_reference_dir, self.get_cache_basename(0))
        dataset, _ = cache_list_to_dataset([quant_cache, ref_cache], count)
        diff_dataset = dataset.map(lambda x: x[0] - x[1])
        axes_to_reduce = BiasCorrection.get_bias_reduce_axes(acceleras_layer, diff_dataset.element_spec.shape)
        mean_func = partial(tf.reduce_mean, axis=axes_to_reduce)
        diff_per_channel_dataset = diff_dataset.map(mean_func)
        diff_per_channel_np = np.array([i.numpy() for i in diff_per_channel_dataset])
        diff_per_channel = diff_per_channel_np.mean(axis=0)
        acceleras_layer.bias.assign(acceleras_layer.bias.numpy() - diff_per_channel)
        acceleras_layer.enforce_internal_encoding()
        self._delete_cache(output_pre_correction_dir)

    def _get_kernel_lossy_element(self, layer: BaseHailoLayer):
        """
        Return the single ``kernel`` weight lossy element of the layer.

        Raises:
            AdaRoundError: if the layer's atomic ops expose zero or more than one kernel element.
        """
        elements = [
            element
            for element in (getattr(op.weight_lossy_elements, "kernel", None) for op in layer.atomic_ops)
            if element is not None
        ]
        if len(elements) != 1:
            # Report the number of kernels found (the old message leaked the last
            # loop variable and raised a NameError when there were no atomic ops).
            raise AdaRoundError(f"Trying to modify layer {layer.full_name} with {len(elements)} kernels")
        return elements[0]

    def get_algo_config(self):
        """Return the current algorithm configuration"""
        return self._model_config.adaround

    def should_skip_algo(self):
        """
        Checks whether the algorithm can be skipped entirely.

        Return:
            True if no layer of the model should be trained by adaround,
            False if at least one layer should be trained
        """
        cfg = self.get_algo_config()
        for lname, layer in self._model.layers.items():
            layer_cfg = cfg.layers.get(lname, LayerAdaRoundConfig())
            should_train = self.should_train(layer, cfg, layer_cfg)
            if should_train:
                return False
        return True

    def _lbl_setup_logic(self):
        """Init interlayer results for lossy data"""
        self._lossy_interlayer_results = dict()

    def _lbl_pre_layer_logic(self, acceleras_layer):
        """Set the layer to be lossless before collecting the reference"""
        acceleras_layer.disable_lossy()
        if acceleras_layer.built:
            acceleras_layer.enforce_internal_encoding()

    @classmethod
    def get_config_by_layer(
        cls,
        algo_config,
        layers,
    ) -> Dict[str, LayerAdaRoundConfig]:
        """
        Get the per layer configurations of adaround

        Args:
            algo_config: algorithm config whose ``layers`` mapping is looked up.
            layers: either a dict {model layer name: config layer name} or an
                iterable of names (each name is then used for both).

        Returns:
            dict mapping model layer name to its config (default config when absent).
        """
        if not isinstance(layers, dict):
            layers = {layer: layer for layer in layers}
        config_by_layer = dict()
        for model_lname, config_lname in layers.items():
            config_by_layer[model_lname] = algo_config.layers.get(config_lname, LayerAdaRoundConfig.get_default())
        return config_by_layer

    @classmethod
    def is_trainable(cls, layer) -> bool:
        """Check if given layer is trainable or not"""
        return type(layer) in cls.TRAINABLE_LAYERS  # exact type, excluding sub-types

    @classmethod
    def resolve_policy(
        cls,
        layer,
        default_policy: FeaturePolicy,
        layer_policy: LayerFeaturePolicy,
        mode: AdaRoundMode,
    ) -> FeaturePolicy:
        """
        Resolve the feature policy with per layer policy to get the actual layer policy

        An ``allowed`` layer policy falls back to ``default_policy`` when the layer
        has a suitable lossy element for ``mode``, and is disabled otherwise. Any
        other layer policy is converted to the ``FeaturePolicy`` with the same value.
        """
        if layer_policy == LayerFeaturePolicy.allowed:
            if cls.has_suitable_lossy_element(layer, mode):
                policy = default_policy
            else:
                policy = FeaturePolicy.disabled
        else:
            policy = FeaturePolicy(layer_policy.value)
        return policy

    @classmethod
    def has_suitable_lossy_element(cls, layer, adaround_mode: AdaRoundMode) -> bool:
        """
        Check if the layer has a suitable lossy element based on given adaround mode

        A suitable element is a signed ``kernel`` lossy element with known bits.
        In ``train_4bit`` mode it must also have exactly 4 bits. In ``train_all``
        mode any bit width is accepted.
        """
        for atomic_op in layer.atomic_ops:
            kernel_quant_elem = getattr(atomic_op.weight_lossy_elements, "kernel", None)
            if kernel_quant_elem is None:
                continue
            kernel_bits = kernel_quant_elem.bits
            if kernel_bits is None:
                continue
            if not kernel_quant_elem.signed:
                continue
            if adaround_mode == AdaRoundMode.train_4bit and kernel_bits == 4:
                return True
            if adaround_mode == AdaRoundMode.train_all:
                return True
        return False

    @classmethod
    def should_train(
        cls,
        layer,
        cfg: AdaRoundConfig,
        layer_cfg: LayerAdaRoundConfig,
    ) -> bool:
        """
        Check if a layer should be trained or not,
        based on layer support, layer configuration, and algo configuration

        ``layer_cfg`` may be None, in which case a default ``LayerAdaRoundConfig`` is used.
        """
        if not cls.is_trainable(layer):
            return False

        if layer_cfg is None:
            layer_cfg = LayerAdaRoundConfig()

        policy = cls.resolve_policy(layer, cfg.policy, layer_cfg.policy, cfg.mode)
        return policy == FeaturePolicy.enabled

    def _epochs_to_iterations(self, epochs):
        """
        Convert a number of epochs to a number of training iterations (batches).

        ``epochs`` may be fractional (e.g. the warmup portion). The result is then
        a float, because floor division keeps the float type.
        """
        config = self.get_algo_config()
        return epochs * config.dataset_size // config.batch_size
