import tensorflow as tf

from hailo_model_optimization.acceleras.hailo_layers import (
    base_hailo_conv,
    hailo_batch_norm,
    hailo_conv,
    hailo_conv_add,
    hailo_depthwise,
)
from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm


class BuildScaleVariables(OptimizationAlgorithm):
    """
    Build scale variables on the model. The algorithm replaces the scale vectors with
    Keras variables so that they can be trained.
    """

    def __init__(
        self,
        model: HailoModel,
        model_config,
        logger_level,
        **kwargs,
    ):
        super().__init__(model, model_config, logger_level=logger_level, name="Build Scale Variables", **kwargs)

    def _setup(self):
        """No setup is required for this algorithm."""

    def should_skip_algo(self):
        """This algorithm is never skipped."""
        return False

    def get_algo_config(self):
        """Return ``self._model_config.finetune`` as this algorithm's config section."""
        return self._model_config.finetune

    def _run_int(self):
        """Replace the scales of every layer in the model with trainable variables.

        Each layer's tracker is unlocked while the layer is modified and locked again
        afterwards. The re-lock happens in a ``finally`` so that an exception while
        building variables does not leave the layer unlocked.
        """
        for layer in self._model.layers.values():
            layer._tracker.locked = False  # unlock the layer tracker to allow changes
            try:
                self.build_trainable_scales(layer)
            finally:
                layer._tracker.locked = True  # lock the layer tracker after changes

    def build_trainable_scales(self, layer):
        """
        EXPERIMENTAL - used for scales training.

        For each degree of freedom (DOF), replace the relevant attribute with a trainable
        variable initialized to its previous value:

        * the output scale of layers whose exact type is HailoConv, HailoDepthwise or
          HailoBatchNorm. HailoConvAdd is also included when
          ``HailoConvAdd.vector_elwa_factor_supported`` is truthy. A scale that is
          already a ``tf.Variable`` is left as is.
        * ``act_op.output_factor_by_group`` of any ``BaseHailoConv`` instance.

        Nothing is changed if any successor of ``layer`` has an equalization handler
        type of output, unexpected or unsupported.

        Args:
            layer: a layer of ``self._model`` (looked up in the flow via ``layer.full_name``).
        """
        # Exact-type membership (not isinstance) is used further down, so subclasses
        # are only handled when listed here explicitly.
        scale_creators = [hailo_conv.HailoConv, hailo_depthwise.HailoDepthwise, hailo_batch_norm.HailoBatchNorm]
        if hailo_conv_add.HailoConvAdd.vector_elwa_factor_supported:
            scale_creators.append(hailo_conv_add.HailoConvAdd)
        # TODO what about the elementwise add..

        # Skip layers where any successor doesn't support vector scales.
        # TODO: use equiv set to find if the successor can accept vector scales
        unsupported_succ_types = (
            LayerHandlerType.output,
            LayerHandlerType.unexpected,
            LayerHandlerType.unsupported,
        )
        for succ in self._model.flow.successors_sorted(layer.full_name):
            succ_handler_type = self._model.layers[succ].get_equalization_handler_type().handler_type
            if succ_handler_type in unsupported_succ_types:
                return

        if type(layer) in scale_creators and not isinstance(layer.output_scale, tf.Variable):
            # No constraint on output scale - can change as vector.
            var_scale = layer.add_weight(
                name="output_scale",
                shape=layer.output_scale.shape,
                trainable=True,
                initializer=tf.keras.initializers.Constant(layer.output_scale),
            )
            layer.set_output_scale(var_scale, 0)

        if isinstance(layer, base_hailo_conv.BaseHailoConv):
            # Make the single intra-layer DOF (the APU rescale factor) trainable.
            # NOTE: we optionally attenuate the learning rate in a #channels-dependent way
            # (inspired by LSQ). It is implemented as a reparameterization: making the
            # trainable base variable bigger makes its gradient smaller, so the actual
            # factor changes quadratically less on each step. The attenuation is
            # currently 1.0, so the reparameterization is an identity.
            learning_rate_attenuation = tf.cast(1.0, tf.float32)
            sqrt_attenuation = tf.sqrt(learning_rate_attenuation)
            normed_init = layer.act_op.output_factor_by_group * sqrt_attenuation
            normed_weight = layer.add_weight(
                name="output_factor",
                shape=layer.act_op.output_factor_by_group.shape,
                trainable=True,
                initializer=tf.keras.initializers.Constant(normed_init),
            )
            layer.act_op.output_factor_by_group = normed_weight / sqrt_attenuation

    def finalize_global_cfg(self, algo_config):
        """No global config finalization is needed for this algorithm."""
