from functools import partial
from typing import List, Union

import keras
from typeguard import typechecked

from hailo_model_optimization.acceleras.utils.acceleras_definitions import Optimizer


# This class was copied from Tensorflow addons (tfa). EOL of tfa is May 2024.
class MultiOptimizer(keras.optimizers.Optimizer):
    """Multi Optimizer Wrapper for Discriminative Layer Training.

    Creates a wrapper around a set of instantiated optimizer layer pairs.
    Generally useful for transfer learning of deep networks.

    Each optimizer will optimize only the weights associated with its paired layer.
    This can be used to implement discriminative layer training by assigning
    different learning rates to each optimizer layer pair.
    `(keras.optimizers.Optimizer, List[keras.layers.Layer])` pairs are also supported.
    Please note that the layers must be instantiated before instantiating the optimizer.

    Args:
        optimizers_and_layers: a list of tuples of an optimizer and a layer or model.
            Each tuple should contain exactly 1 instantiated optimizer and 1 object that
            subclasses `keras.Model`, `keras.Sequential` or `keras.layers.Layer`.
            Nested layers and models will be automatically discovered.
            Alternatively, in place of a single layer, you can pass a list of layers.
        optimizer_specs: specialized list for serialization.
            Should be left as None for almost all cases.
            If you are loading a serialized version of this optimizer,
            please use `keras.models.load_model` after saving a model compiled with this optimizer.
        name: name forwarded to the base `keras.optimizers.Optimizer`.
        **kwargs: forwarded to the base `keras.optimizers.Optimizer`. An explicit
            `learning_rate` entry, if present, is used as the base optimizer's
            learning rate instead of the first wrapped optimizer's one.

    Exactly one of `optimizers_and_layers` and `optimizer_specs` must be given.

    Usage:

    >>> model = keras.Sequential([
    ...     keras.Input(shape=(4,)),
    ...     keras.layers.Dense(8),
    ...     keras.layers.Dense(16),
    ...     keras.layers.Dense(32),
    ... ])
    >>> optimizers = [
    ...     keras.optimizers.Adam(learning_rate=1e-4),
    ...     keras.optimizers.Adam(learning_rate=1e-2)
    ... ]
    >>> optimizers_and_layers = [(optimizers[0], model.layers[0]), (optimizers[1], model.layers[1:])]
    >>> optimizer = MultiOptimizer(optimizers_and_layers)
    >>> model.compile(optimizer=optimizer, loss="mse")

    Reference:
        - [Universal Language Model Fine-tuning for Text Classification](https://arxiv.org/abs/1801.06146)
        - [Collaborative Layer-wise Discriminative Learning in Deep Neural Networks](https://arxiv.org/abs/1607.05440)

    Note: Currently, `MultiOptimizer` does not support callbacks that modify optimizers.
        However, you can instantiate optimizer layer pairs with
        `keras.optimizers.schedules.LearningRateSchedule`
        instead of a static learning rate.

    This code should function on CPU, GPU, and TPU. Apply with `tf.distribute.Strategy().scope()` context as you
    would with any other optimizer.

    This code assumes the names of the variables are unique, and that the layer weights have been initialized.
    """

    @typechecked
    def __init__(
        self,
        optimizers_and_layers: Union[list, None] = None,
        optimizer_specs: Union[list, None] = None,
        name: str = "MultiOptimizer",
        **kwargs,
    ):
        """Build the per-optimizer specs and initialise the base optimizer.

        Raises:
            RuntimeError: if neither or both of `optimizers_and_layers` and
                `optimizer_specs` are given.
            ValueError: if the given list is empty, or (from
                `create_optimizer_spec`) if a paired layer/model is not built.
        """
        # Validate before touching either argument, so that the
        # `optimizer_specs`-only path does not crash on `None[0]`.
        if (optimizers_and_layers is None) == (optimizer_specs is None):
            raise RuntimeError("Must specify one of `optimizers_and_layers` or `optimizer_specs`.")

        # Specs are built into a local first: the base-class learning rate is
        # derived from them, and instance attributes are only assigned after
        # the base `__init__` has run.
        if optimizer_specs is None:
            specs = [
                self.create_optimizer_spec(optimizer, layers_or_model)
                for optimizer, layers_or_model in optimizers_and_layers
            ]
        else:
            specs = [self.maybe_initialize_optimizer_spec(spec) for spec in optimizer_specs]

        if not specs:
            raise ValueError("MultiOptimizer requires at least one optimizer layer pair.")

        # A config-driven re-creation may already carry `learning_rate`; take
        # it out of kwargs so it is not passed twice to the base class.
        learning_rate = kwargs.pop("learning_rate", None)
        if learning_rate is None:
            # The base class needs some learning rate; the first wrapped
            # optimizer's current value is used for that purpose.
            learning_rate = float(specs[0]["optimizer"].learning_rate)

        super(MultiOptimizer, self).__init__(learning_rate=learning_rate, name=name, **kwargs)

        self.optimizer_specs = specs
        # Not read or written anywhere else in this class.
        # NOTE(review): presumably populated by external callers - confirm.
        self.var_to_tf_var_mapping = {}

    def apply_gradients(self, grads_and_vars, **kwargs):
        """Route each (gradient, variable) pair to the optimizer owning the variable.

        A pair is handed to a spec's optimizer when the variable is the very
        same object (identity, not equality) as the `_value` of one of that
        spec's weights. Pairs matching no spec are ignored.

        Args:
            grads_and_vars: iterable of `(gradient, variable)` pairs.
            **kwargs: forwarded to every wrapped optimizer's `apply_gradients`.

        Returns:
            The result of incrementing this optimizer's `iterations` by one.
        """
        # "gv" is per-call scratch storage; reset it for every spec.
        for spec in self.optimizer_specs:
            spec["gv"] = []

        for grad, var in tuple(grads_and_vars):
            for spec in self.optimizer_specs:
                for spec_var in spec["weights"]:
                    if var is spec_var._value:
                        spec["gv"].append((grad, var))

        # Only invoke optimizers that actually received gradients.
        for spec in self.optimizer_specs:
            if spec["gv"]:
                spec["optimizer"].apply_gradients(spec["gv"], **kwargs)
        return self.iterations.assign_add(1)

    def get_config(self):
        """Return the base optimizer config extended with `optimizer_specs`.

        Each spec is copied without its transient "gv" entry. The "optimizer"
        and "weights" values are stored as they are held in the spec; they are
        not converted to another representation here.
        """
        config = super(MultiOptimizer, self).get_config()
        optimizer_specs_without_gv = [
            {
                "optimizer": optimizer_spec["optimizer"],
                "weights": optimizer_spec["weights"],
            }
            for optimizer_spec in self.optimizer_specs
        ]
        config.update({"optimizer_specs": optimizer_specs_without_gv})
        return config

    @classmethod
    def create_optimizer_spec(
        cls,
        optimizer: keras.optimizers.Optimizer,
        layers_or_model: Union[
            keras.Model,
            keras.Sequential,
            keras.layers.Layer,
            List[keras.layers.Layer],
        ],
    ):
        """Create the spec dict pairing an optimizer with the weights it updates.

        Args:
            optimizer: the optimizer that should update the collected weights.
            layers_or_model: a single layer/model, or a list of layers, whose
                `weights` are collected.

        Returns:
            A dict with keys "optimizer" (the given optimizer) and "weights"
            (a list of the weight variable objects themselves, not their names).

        Raises:
            ValueError: if the layer/model (or any layer of the list) is not built.
        """
        if isinstance(layers_or_model, list):
            is_built = all(sublayer.built for sublayer in layers_or_model)
            weights = [var for sublayer in layers_or_model for var in sublayer.weights]
        else:
            is_built = layers_or_model.built
            weights = list(layers_or_model.weights)

        if is_built is False:
            raise ValueError("Model is not built. Please build the model before creating the optimizer.")

        return {
            "optimizer": optimizer,
            "weights": weights,
        }

    @classmethod
    def maybe_initialize_optimizer_spec(cls, optimizer_spec):
        """Deserialize the spec's optimizer if it is still a config dict.

        The spec is modified in place and also returned.
        """
        if isinstance(optimizer_spec["optimizer"], dict):
            optimizer_spec["optimizer"] = keras.optimizers.deserialize(optimizer_spec["optimizer"])

        return optimizer_spec

    def __repr__(self):
        return f"Multi Optimizer with {len(self.optimizer_specs)} optimizer layer pairs"

    @property
    def learning_rate(self):
        """List of the wrapped optimizers' learning rates, in spec order."""
        return [spec["optimizer"].learning_rate for spec in self.optimizer_specs]

    @learning_rate.setter
    def learning_rate(self, value):
        # A property setter is always called with the assigned value; without
        # the parameter the assignment fails with TypeError instead of the
        # intended NotImplementedError.
        raise NotImplementedError("Can't set learning rate of MultiOptimizer.")


def get_optimizer_gen(optimzer: Optimizer):
    """Return a factory for the requested Keras optimizer with preset arguments.

    Args:
        optimzer: the `Optimizer` value selecting the optimizer kind.

    Returns:
        A `functools.partial` over `keras.optimizers.Adam` (with `epsilon=0.01`)
        or over `keras.optimizers.SGD` (with `momentum=0.9`); call it with the
        remaining constructor arguments to obtain an optimizer instance.

    Raises:
        ValueError: if `optimzer` is neither `Optimizer.adam` nor `Optimizer.sgd`.
    """
    # NOTE: this is actually the default case (via default config)
    if optimzer == Optimizer.adam:
        return partial(keras.optimizers.Adam, epsilon=0.01)

    if optimzer == Optimizer.sgd:
        return partial(keras.optimizers.SGD, momentum=0.9)

    raise ValueError(f"Unexpected optimizer value {optimzer}")
