"""
This module implements the AdaRound algorithm for an acceleras model
https://arxiv.org/abs/2004.10568
"""

from typing import List

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.lossy_elements.quant_element import AdaRoundQuantElement
from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.algorithms.ada_round.loss_regularizer import AdaRoundRegularizer
from hailo_model_optimization.algorithms.ada_round.loss_utils import DirectMetric
from hailo_model_optimization.algorithms.dali_utils import tf_pad_outputs, tf_unpad_input


class AdaRoundTrainModel(tf.keras.Model):
    """
    Model that wraps a block of acceleras layers and trains the AdaRound
    rounding variables of the selected layers.

    Args:
        block_model: acceleras model (block) that is called by this model
        trained_layers: names of the layers of block_model to train,
            every other layer of the block is set as non-trainable
        train_bias: whether the bias should be trained as well or not
        logger: logger object, stored on the model
        *args: forwarded to tf.keras.Model
        **kwargs: forwarded to tf.keras.Model

    """

    def __init__(
        self,
        block_model: HailoModel,
        trained_layers: List[str],
        train_bias: bool,
        logger,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.logger = logger
        # AdaRound quant element of each trained layer, keyed by layer.full_name
        self.ada_round_elements = {}
        # Kernel lossy elements that were replaced, restored once the training is done
        self._original_lossy_elements = {}
        # Created in fit(), used in call()
        self.round_regularizer = None
        self.block = block_model
        for lname, layer in block_model.layers.items():
            if lname in trained_layers:
                self._setup_layer(layer, train_bias)
            else:
                layer.trainable = False
        self._x_shape = None  # used for DALI slicing / padding
        self._y_shape = None  # used for DALI slicing / padding
        self._init_custom_metrics(trained_layers)

    def _init_custom_metrics(self, trained_layers):
        """Creates the global and the per-layer metrics that are reported during the train"""
        self._round_metric = DirectMetric("round_loss")
        self._per_layer_round_loss = {layer: DirectMetric(f"{layer}_round_loss") for layer in trained_layers}
        self._per_layer_round_value = {layer: DirectMetric(f"{layer}_mean_round_value") for layer in trained_layers}
        self._temperature_metric = DirectMetric("annealing_b")

    @property
    def metrics(self):
        """
        Returns the list of metrics to be used in the model
        """
        return [
            self._round_metric,
            self._temperature_metric,
            *self._per_layer_round_loss.values(),
            *self._per_layer_round_value.values(),
        ]

    def build(self, input_shape):
        """
        Simple build function to ensure the model is built, avoiding build warnings.
        Delegates to the build function of the parent class.
        """
        super().build(input_shape)

    def fit(
        self,
        x,
        iters,
        b_range,
        decay_start,
        warmup,
        weight,
        x_shape,
        y_shape,
        verbose=1,
        callbacks=None,
        validation_split=0,
        validation_data=None,
        initial_epoch=0,
        validation_steps=None,
        validation_freq=1,
    ):
        """
        Wraps the keras' fit function and initializes a regularizer

        The train always runs a single epoch of `iters` steps. Once the train is done,
        the trained rounding is applied to the kernels of the trained layers and their
        original kernel lossy elements are restored.

        Args:
            x: input data, forwarded to keras' fit
            iters: number of train steps, also used as the regularizer's max_count
            b_range: forwarded to AdaRoundRegularizer
            decay_start: forwarded to AdaRoundRegularizer
            warmup: forwarded to AdaRoundRegularizer
            weight: forwarded to AdaRoundRegularizer
            x_shape: stored on the model and used to unpad the inputs in call
            y_shape: stored on the model (used for DALI slicing / padding)
            verbose: forwarded to keras' fit
            callbacks: forwarded to keras' fit
            validation_split: forwarded to keras' fit
            validation_data: forwarded to keras' fit
            initial_epoch: forwarded to keras' fit
            validation_steps: forwarded to keras' fit
            validation_freq: forwarded to keras' fit

        Returns:
            the return value of keras' fit

        """
        # Used in model's call
        self.round_regularizer = AdaRoundRegularizer(
            max_count=iters,
            b_range=b_range,
            decay_start=decay_start,
            warmup=warmup,
            weight=weight,
        )
        self._x_shape = x_shape
        self._y_shape = y_shape

        retval = super().fit(
            x,
            epochs=1,
            verbose=verbose,
            callbacks=callbacks,
            validation_split=validation_split,
            validation_data=validation_data,
            initial_epoch=initial_epoch,
            steps_per_epoch=iters,
            validation_steps=validation_steps,
            validation_freq=validation_freq,
        )

        for layer in self.block.layers.values():
            if hasattr(layer, "enforce_internal_encoding"):
                layer.enforce_internal_encoding()  # ensure kernel_scale will be defined without scope

        # NOTE(review): ada_round_elements is keyed by layer.full_name, this lookup assumes
        # it matches the keys of block.layers - confirm
        for lname in self.ada_round_elements:
            self._finalize_layer(self.block.layers[lname])

        return retval

    def call(self, inputs, training=None, mask=None):
        """
        updates the round regularizer losses and metrics, unpads the input
        from the dali dataset, calls the wrapped block and pads its outputs

        Must be called after fit has created the round regularizer.
        """
        total_round_loss = 0
        # Fallback dtype for the annealing metric, so the call doesn't fail on an
        # unbound variable when there are no trained elements.
        step_dtype = tf.float32
        for lname, trained_element in self.ada_round_elements.items():
            round_value = trained_element.get_fractional_value()
            step_dtype = round_value.dtype
            round_loss = self.round_regularizer(
                round_value, tf.cast(self.optimizer.iterations, dtype=step_dtype)
            )
            total_round_loss += round_loss
            self._per_layer_round_loss[lname].update_state(round_loss)
            # Mean distance of the fractional values from 0.5
            mean_round_value = tf.reduce_mean(tf.abs(round_value - 0.5))
            self._per_layer_round_value[lname].update_state(mean_round_value)
            self.add_loss([round_loss])
        self._round_metric.update_state(total_round_loss)
        annealing_b = self.round_regularizer.liner_temp_decay(
            tf.cast(self.optimizer.iterations, dtype=step_dtype)
        )
        self._temperature_metric.update_state(annealing_b)
        inputs_tensors = tf_unpad_input(inputs, self._x_shape)
        outputs = self.block(inputs_tensors)
        if isinstance(outputs, tf.Tensor):
            # A single output is wrapped so the padding always gets a sequence of tensors
            outputs = (outputs,)
        outputs = tf_pad_outputs(outputs, stack_axis=1)
        return outputs

    def _setup_layer(self, layer, train_bias):
        """
        Prepares the layer for the AdaRound train

        Freezes the layer's atomic ops (except for the bias add op if train_bias is set),
        and replaces the kernel lossy element of the conv op with a trainable
        AdaRoundQuantElement. The replaced element is kept so it can be restored later.
        """
        # Prepare layer
        for atomic_op in layer.atomic_ops:
            atomic_op.trainable = False
        if train_bias:
            layer.bias_add_op.trainable = True
        layer.disable_lossy()
        if layer.built:
            layer.enforce_internal_encoding()

        # Prepare conv op's bit reducer
        original_element = layer.conv_op.weight_lossy_elements.kernel
        ada_round_br = AdaRoundQuantElement(signed=original_element.signed, bits=original_element.bits)
        var_value = ada_round_br.get_initial_var_value(
            layer.conv_op.final_numeric_kernel,
        )
        var_initializer = tf.initializers.Constant(var_value)
        var = self.add_weight(
            name="adaround_var",
            trainable=True,
            initializer=var_initializer,
            shape=var_value.shape,
        )
        ada_round_br.set_var(var)
        ada_round_br.enable()
        self._original_lossy_elements[layer.full_name] = original_element
        layer.conv_op.weight_lossy_elements.kernel = ada_round_br
        self.ada_round_elements[layer.full_name] = ada_round_br
        if layer.built:
            layer.enforce_internal_encoding()

    def _finalize_layer(self, layer):
        """
        Applies the trained rounding to the layer's kernel and restores the
        original kernel lossy element of the conv op
        """
        round_val = self.ada_round_elements[layer.full_name].get_fractional_value()
        kscale = layer.conv_op.kernel_scale
        # Floor of the kernel in scale units, plus the learned fractional value, back in kernel units
        new_kernel = (np.floor(layer.conv_op.kernel / kscale) + round_val) * kscale
        layer.conv_op.kernel.assign(new_kernel)
        layer.conv_op.weight_lossy_elements.kernel = self._original_lossy_elements[layer.full_name]