"""
This module implements the AdaRound algorithm for an acceleras model
https://arxiv.org/abs/2004.10568
"""

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.lossy_elements.quant_element import AdaRoundQuantElement
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerAdaRoundConfig,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AdaRoundError
from hailo_model_optimization.algorithms.ada_round.loss_regularizer import AdaRoundRegularizer
from hailo_model_optimization.algorithms.dali_utils import tf_unpad_input


class AdaRoundTrainModel(tf.keras.Model):
    """
    Keras model built around one acceleras layer, used to train that layer with AdaRound

    On construction the layer's kernel lossy element is swapped for an
    ``AdaRoundQuantElement`` that owns a trainable variable; once ``fit`` is done
    the learned rounding is written into the kernel and the original lossy
    element is put back.

    Args:
        layer: acceleras layer to train
        train_bias: whether the bias should be trained as well or not
        logger: logger object, kept on the instance as ``logger``

    """

    def __init__(self, layer, train_bias, logger, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.logger = logger
        self._layer = layer
        # Filled in by fit(); used for DALI slicing / padding
        self._x_shape = None
        self._y_shape = None
        # Filled in by fit()
        self.round_regularizer = None
        # Filled in by _setup_layer() right below
        self.ada_round_br = None
        self._original_br = None
        self._setup_layer(train_bias)

    def fit(
        self,
        x,
        iters,
        b_range,
        decay_start,
        warmup,
        weight,
        x_shape,
        y_shape,
        verbose=1,
        callbacks=None,
        validation_split=0,
        validation_data=None,
        initial_epoch=0,
        validation_steps=None,
        validation_freq=1,
    ):
        """
        Creates the rounding regularizer, runs keras' fit, then finalizes the layer

        Training always runs as a single epoch of ``iters`` steps.

        Args:
            x: training data, passed as is to keras' fit
            iters: number of train steps, also given to the regularizer as ``max_count``
            b_range: passed to the regularizer
            decay_start: passed to the regularizer
            warmup: passed to the regularizer
            weight: passed to the regularizer
            x_shape: shape used by ``call`` to unpad the inputs
            y_shape: stored on the model (not read anywhere in this class)
            verbose, callbacks, validation_split, validation_data, initial_epoch,
            validation_steps, validation_freq: passed as is to keras' fit

        Returns:
            whatever keras' fit returns

        """
        # The regularizer and the input shape are consumed inside call()
        self.round_regularizer = AdaRoundRegularizer(
            max_count=iters,
            b_range=b_range,
            decay_start=decay_start,
            warmup=warmup,
            weight=weight,
        )
        self._x_shape, self._y_shape = x_shape, y_shape

        keras_fit_kwargs = {
            "epochs": 1,
            "verbose": verbose,
            "callbacks": callbacks,
            "validation_split": validation_split,
            "validation_data": validation_data,
            "initial_epoch": initial_epoch,
            "steps_per_epoch": iters,
            "validation_steps": validation_steps,
            "validation_freq": validation_freq,
        }
        history = super().fit(x, **keras_fit_kwargs)

        # Write the trained rounding into the kernel and restore the original lossy element
        self._finalize_layer()

        return history

    def call(self, inputs, training=None, mask=None):
        """
        Registers the rounding regularization loss, unpads the input from the
        dali dataset and runs the trained layer on it

        Raises:
            AdaRoundError: when the wrapped layer returns more than one output

        """
        fractional = self.ada_round_br.get_fractional_value()
        step = tf.cast(self.optimizer.iterations, dtype=fractional.dtype)
        regularization_loss = self.round_regularizer(fractional, step)
        self.add_loss([regularization_loss])

        unpadded = tf_unpad_input(inputs, self._x_shape)
        result = self._layer(unpadded)
        # A bare tensor is treated as a one-element tuple of outputs
        if isinstance(result, tf.Tensor):
            result = (result,)

        if len(result) == 1:
            return tf.stack(result, axis=1)

        layer_name = self._layer.full_name
        # NOTE(review): get_command / get_feature are interpolated below without being
        # called; if they are methods the message will contain their repr - confirm.
        command = LayerAdaRoundConfig.get_command
        feature_name = LayerAdaRoundConfig.get_feature
        raise AdaRoundError(
            "Multiple outputs is not supported in adaround, "
            f"please disable adaround for {layer_name}\n"
            f"{command}({feature_name}, layers=[{layer_name}], policy=disabled)"
        )

    def _setup_layer(self, train_bias):
        """
        Prepares the layer for the AdaRound train

        Freezes every atomic op (the bias add op is re-enabled when ``train_bias``
        is set), disables the layer's lossy elements and replaces the conv op's
        kernel lossy element with a trainable AdaRound quant element.
        The replaced element is kept so it can be restored after training.
        """
        wrapped = self._layer

        # Freeze the layer, optionally leaving the bias trainable
        for op in wrapped.atomic_ops:
            op.trainable = False
        if train_bias:
            wrapped.bias_add_op.trainable = True
        wrapped.disable_lossy()
        if wrapped.built:
            wrapped.enforce_internal_encoding()

        # Build the AdaRound element with the same bit width as the current kernel element
        conv_op = wrapped.conv_op
        quant_element = AdaRoundQuantElement(
            signed=True,
            bits=conv_op.weight_lossy_elements.kernel.bits,
        )
        initial_value = quant_element.get_initial_var_value(conv_op.final_numeric_kernel)
        trainable_var = self.add_weight(
            name="adaround_var",
            trainable=True,
            initializer=tf.initializers.Constant(initial_value),
            shape=initial_value.shape,
        )
        quant_element.set_var(trainable_var)
        quant_element.enable()

        # Swap the kernel's lossy element, remembering the original one
        self._original_br = conv_op.weight_lossy_elements.kernel
        conv_op.weight_lossy_elements.kernel = quant_element
        self.ada_round_br = quant_element
        if wrapped.built:
            wrapped.enforce_internal_encoding()

    def _finalize_layer(self):
        """
        Applies the trained rounding to the conv kernel and restores the original
        kernel lossy element
        """
        fractional = self.ada_round_br.get_fractional_value()
        self._layer.enforce_internal_encoding()  # ensure kernel_scale will be defined without scope
        conv_op = self._layer.conv_op
        scale = conv_op.kernel_scale
        # Floor the kernel in units of the scale, add the fractional value, scale back
        floored = np.floor(conv_op.kernel / scale)
        conv_op.kernel.assign((floored + fractional) * scale)
        conv_op.weight_lossy_elements.kernel = self._original_br
