"""
This module implements the AdaRound algorithm for an acceleras model
https://arxiv.org/abs/2004.10568
"""

from typing import List

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.hailo_layers.base_hailo_conv import BaseHailoConv
from hailo_model_optimization.acceleras.lossy_elements.quant_element import AdaRoundQuantElement
from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.algorithms.ada_round.loss_regularizer_v2 import AdaRoundRegularizer
from hailo_model_optimization.algorithms.ada_round.loss_utils import DirectMetric


class AdaRoundTrainModel(tf.keras.Model):
    """
    Keras model that wraps an acceleras block model and trains AdaRound
    rounding variables for a selected subset of its layers.

    Every layer whose name appears in ``trained_layers`` gets its kernel lossy
    element replaced by an ``AdaRoundQuantElement`` backed by a trainable
    variable owned by this model. All other layers are frozen. After ``fit``
    finishes, the learned rounding is written into the kernels and the original
    lossy elements are restored.

    Args:
        block_model: Acceleras model to train
        trained_layers: a list of layer names to train.
        train_bias: enable or disable training for the bias weights.
        logger: logger object, stored on the instance as ``self.logger``.
    """

    def __init__(
        self,
        block_model: HailoModel,
        trained_layers: List[str],
        train_bias: bool,
        logger,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.logger = logger
        # layer full_name -> AdaRoundQuantElement installed on that layer's kernel
        self.ada_round_elements = {}
        # layer full_name -> kernel lossy element that was replaced (restored on finalize)
        self._original_lossy_elements = {}
        # Created in fit(); call() relies on it being set.
        self.round_regularizer = None
        self.block = block_model

        for lname, layer in block_model.layers.items():
            if lname in trained_layers:
                self._setup_layer(layer, train_bias)
            else:
                layer.trainable = False

        self._init_custom_metrics(trained_layers)
        # NOTE(review): not read anywhere in this class; presumably consumed by
        # external code such as callbacks - TODO confirm.
        self.stop_warmup_counter = 0

    def _init_custom_metrics(self, trained_layers: List[str]):
        """
        Creates the metrics that report the rounding loss and the annealing
        temperature. ``trained_layers`` is currently unused.
        """
        self._round_metric = DirectMetric("total_round_loss")
        self._temperature_metric = DirectMetric("annealing_b")

    def fit(
        self,
        x,
        iters,
        b_range,
        decay_start,
        warmup,
        weight,
        *args,
        verbose=1,
        epochs=1,
        steps_per_epoch=None,
        callbacks=None,
        validation_split=0,
        validation_data=None,
        initial_epoch=0,
        validation_steps=None,
        validation_freq=1,
    ):
        """
        Wraps the keras' fit function and initializes a regularizer.

        Args:
            x: training data, forwarded to keras' fit.
            iters: passed to the regularizer as ``max_count``; also used as
                ``steps_per_epoch`` when that argument is None.
            b_range: forwarded to ``AdaRoundRegularizer``.
            decay_start: forwarded to ``AdaRoundRegularizer``.
            warmup: forwarded to ``AdaRoundRegularizer``.
            weight: forwarded to ``AdaRoundRegularizer``.
            *args: accepted but ignored; extra positional arguments are NOT
                forwarded to keras' fit.
            verbose, epochs, steps_per_epoch, callbacks, validation_split,
            validation_data, initial_epoch, validation_steps, validation_freq:
                forwarded to keras' fit (keyword-only).

        Returns:
            The history object returned by keras' fit.
        """
        steps_per_epoch = iters if steps_per_epoch is None else steps_per_epoch

        # Used in model's call
        self.round_regularizer = AdaRoundRegularizer(
            max_count=iters,
            b_range=b_range,
            decay_start=decay_start,
            warmup=warmup,
            weight=weight,
        )

        # Keras fit - start training
        history = super().fit(
            x,
            epochs=epochs,
            verbose=verbose,
            callbacks=callbacks,
            validation_split=validation_split,
            validation_data=validation_data,
            initial_epoch=initial_epoch,
            steps_per_epoch=steps_per_epoch,
            validation_steps=validation_steps,
            validation_freq=validation_freq,
        )

        for layer in self.block.layers.values():
            if hasattr(layer, "enforce_internal_encoding"):
                layer.enforce_internal_encoding()  # ensure kernel_scale will be defined without scope

        # Write the learned rounding into the kernels and restore the original lossy elements.
        for lname in self.ada_round_elements:
            self._finalize_layer(self.block.layers[lname])

        return history

    def call(self, inputs, training=None, mask=None):
        """
        Adds the rounding regularization loss of all trained elements, updates
        the rounding-loss and annealing-temperature metrics, and runs the
        wrapped block on the inputs.

        ``fit`` must have been called first, since it creates the regularizer
        used here, and the model must be compiled because the optimizer's
        iteration counter drives the regularizer.

        Returns:
            The block's outputs; a single tensor output is wrapped in a
            1-tuple.
        """
        iterations = self.optimizer.iterations
        # Dtype the iteration counter is cast to. Follows the dtype of the
        # rounding values; the floatx default keeps call() usable when no
        # layer was selected for training (the loop below is then skipped).
        step_dtype = tf.keras.backend.floatx()

        total_round_loss = 0.0
        for trained_element in self.ada_round_elements.values():
            round_value = trained_element.get_fractional_value()
            step_dtype = round_value.dtype
            total_round_loss += self.round_regularizer(round_value, tf.cast(iterations, dtype=step_dtype))

        self.add_loss([total_round_loss])  # * add regularization factor to the loss
        self._round_metric.update_state(total_round_loss)
        annealing_b = self.round_regularizer.get_beta(tf.cast(iterations, dtype=step_dtype))
        self._temperature_metric.update_state(annealing_b)

        if isinstance(inputs, tuple):
            inputs = list(inputs)
        outputs = self.block(inputs)
        if isinstance(outputs, tf.Tensor):
            outputs = (outputs,)

        return outputs

    def _setup_layer(self, layer: BaseHailoConv, train_bias: bool):
        """
        Prepares the layer for the AdaRound train

        Freezes all atomic ops (optionally re-enabling the bias add op),
        disables the layer's lossy elements, and swaps the kernel lossy element
        of the conv op for an ``AdaRoundQuantElement`` driven by a new
        trainable variable. The replaced element is kept so that
        ``_finalize_layer`` can restore it.
        """
        # Prepare layer
        for atomic_op in layer.atomic_ops:
            atomic_op.trainable = False
        if train_bias:
            layer.bias_add_op.trainable = True
        layer.disable_lossy()
        if layer.built:
            layer.enforce_internal_encoding()

        # Prepare conv op's bit reducer, reusing the bits/signed configuration
        # of the kernel lossy element that is being replaced.
        original_element = layer.conv_op.weight_lossy_elements.kernel
        ada_round_br = AdaRoundQuantElement(signed=original_element.signed, bits=original_element.bits)
        var_value = ada_round_br.get_initial_var_value(
            layer.conv_op.final_numeric_kernel,
        )
        var_initializer = tf.initializers.Constant(var_value)
        var = self.add_weight(
            name="adaround_var",
            trainable=True,
            initializer=var_initializer,
            shape=var_value.shape,
        )
        ada_round_br.set_var(var)
        ada_round_br.enable()
        self._original_lossy_elements[layer.full_name] = layer.conv_op.weight_lossy_elements.kernel
        layer.conv_op.weight_lossy_elements.kernel = ada_round_br
        self.ada_round_elements[layer.full_name] = ada_round_br
        if layer.built:
            layer.enforce_internal_encoding()

    def _finalize_layer(self, layer):
        """
        Writes the trained rounding into the layer's kernel and restores the
        kernel lossy element that was replaced in ``_setup_layer``.

        The new kernel is ``(floor(kernel / kernel_scale) + round_val) *
        kernel_scale``, where ``round_val`` is the value returned by the
        layer's AdaRound element.
        """
        round_val = self.ada_round_elements[layer.full_name].get_fractional_value()
        kscale = layer.conv_op.kernel_scale
        new_kernel = (np.floor(layer.conv_op.kernel / kscale) + round_val) * kscale
        layer.conv_op.kernel.assign(new_kernel)
        layer.conv_op.weight_lossy_elements.kernel = self._original_lossy_elements[layer.full_name]
