from functools import partial
from typing import Optional

import tensorflow as tf

from hailo_model_optimization.acceleras.utils.acceleras_definitions import ScheduleType, WarmupStrategy


class ConstantSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """
    Learning-rate schedule that returns ``initial_learning_rate`` at every step.

    Args:
        initial_learning_rate: the learning rate returned for every step.
        final_learning_rate: accepted but ignored. Presumably it exists so this class can be
            constructed with the same arguments as ``LinearSchedule`` (the two are returned
            interchangeably by ``get_warmup_scheduler``) — confirm against callers.
    """

    def __init__(self, initial_learning_rate: float, final_learning_rate: Optional[float] = None) -> None:
        super().__init__()
        self.initial_learning_rate = initial_learning_rate
        # Stored only so get_config round-trips the constructor arguments; it never affects the output.
        self.final_learning_rate = final_learning_rate

    def __call__(self, step):
        # The step is irrelevant for a constant schedule.
        return self.initial_learning_rate

    def get_config(self):
        """Return the constructor arguments so Keras can serialize this schedule."""
        # The base LearningRateSchedule leaves get_config unimplemented, so without this
        # override serializing an optimizer that uses this schedule raises NotImplementedError.
        return {
            "initial_learning_rate": self.initial_learning_rate,
            "final_learning_rate": self.final_learning_rate,
        }


class LinearSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """
    Learning-rate schedule that moves linearly from ``initial_learning_rate`` to
    ``final_learning_rate`` over ``decay_steps`` steps, then holds ``final_learning_rate``.

    Args:
        initial_learning_rate: learning rate at step 0.
        final_learning_rate: learning rate reached at step ``decay_steps`` and kept afterwards.
        decay_steps: length of the linear ramp, in steps. A non-positive value means the ramp
            has zero length, so ``final_learning_rate`` is returned for every step.
    """

    def __init__(self, initial_learning_rate: float, final_learning_rate: float, decay_steps: int) -> None:
        super().__init__()
        self.initial_learning_rate = initial_learning_rate
        self.final_learning_rate = final_learning_rate
        self.decay_steps = decay_steps
        # A zero-length ramp (e.g. gradual warmup with zero warmup epochs) used to raise
        # ZeroDivisionError here; treat it as an already-finished ramp instead.
        if decay_steps > 0:
            self.slope = (final_learning_rate - initial_learning_rate) / decay_steps
        else:
            self.slope = 0.0

    def __call__(self, step):
        if self.decay_steps <= 0:
            return tf.cast(self.final_learning_rate, tf.float32)
        # Clamp the step so the ramp stops at final_learning_rate instead of extrapolating past it.
        # Within [0, decay_steps] the result is identical to the unclamped line.
        elapsed = tf.minimum(tf.cast(step, tf.float32), float(self.decay_steps))
        return self.initial_learning_rate + self.slope * elapsed

    def get_config(self):
        """Return the constructor arguments so Keras can serialize this schedule."""
        return {
            "initial_learning_rate": self.initial_learning_rate,
            "final_learning_rate": self.final_learning_rate,
            "decay_steps": self.decay_steps,
        }


class QFTScheduler(tf.keras.optimizers.schedules.LearningRateSchedule):
    """
    Two-stage learning-rate schedule for the QFT algorithm.

    A step whose epoch index (``step // steps_per_epoch``) is below ``warmup_epochs`` is delegated
    to ``warmup_scheduler``, called with the global step. Every later step is delegated to
    ``train_scheduler``, called with the step count measured from the end of the warmup stage,
    i.e. ``step - warmup_epochs * steps_per_epoch``.

    Args:
        warmup_epochs: number of epochs in the warmup stage.
        steps_per_epoch: number of optimizer steps per epoch. It is converted with ``int()`` and used
            to map a step to its epoch.
        train_scheduler: schedule (callable taking a step) used after the warmup stage.
        warmup_scheduler: schedule (callable taking a step) used during the warmup stage.

    Raises:
        ValueError: if ``steps_per_epoch`` is not positive after ``int()`` conversion. Such a value
            would otherwise only fail later, as an integer division by zero inside ``__call__``.
    """

    def __init__(self, warmup_epochs, steps_per_epoch, train_scheduler, warmup_scheduler):
        super().__init__()
        self.warmup_epochs = warmup_epochs
        self.steps_per_epoch = int(steps_per_epoch)
        if self.steps_per_epoch <= 0:
            raise ValueError(f"steps_per_epoch must be a positive integer, got {steps_per_epoch}")
        self.train_scheduler = train_scheduler
        self.warmup_scheduler = warmup_scheduler

    def __call__(self, step):
        curr_epoch = tf.cast(step // self.steps_per_epoch, tf.float32)
        is_warmup_epoch = curr_epoch < self.warmup_epochs
        # The train schedule restarts its own step counter at zero once warmup ends.
        train_step = tf.cast(step, tf.int64) - tf.cast(self.warmup_epochs * self.steps_per_epoch, tf.int64)
        # NOTE(review): tf.cond merges both branch outputs, so both schedules are expected to
        # return values of a compatible dtype (float32 here) — confirm for custom schedules.
        lr = tf.cond(is_warmup_epoch, lambda: self.warmup_scheduler(step), lambda: self.train_scheduler(train_step))
        return lr


def get_train_scheduler(
    lr_schedule_type: ScheduleType,
    decay_rate: float,
    steps_per_epoch: int,
    decay_epochs: int,
    dataset_size: int,
    t_mul: float,
):
    """
    Return a factory for the train-stage learning-rate schedule.

    Every schedule argument except the initial learning rate is bound here, so the caller builds the
    schedule with ``schedule_gen(initial_learning_rate)``.

    Args:
        lr_schedule_type (ScheduleType): Which schedule to build: COSINE_RESTARTS, EXPONENTIAL or CONSTANT.
        decay_rate (float): For COSINE_RESTARTS, the factor applied to the learning rate at each
            restart (``m_mul``). For EXPONENTIAL, the factor by which the learning rate decays every
            ``decay_epochs`` epochs. Unused for CONSTANT.
        steps_per_epoch (int): The number of optimizer steps per epoch.
        decay_epochs (int): Length in epochs of the first cosine period (COSINE_RESTARTS) or of one
            exponential decay period (EXPONENTIAL).
        dataset_size (int): Currently unused. Kept so the signature stays backward compatible.
        t_mul (float): For COSINE_RESTARTS, the factor by which each restart period's length grows.

    Returns:
        schedule_gen (Callable): A schedule class, or a ``functools.partial`` of one, that still needs
            ``initial_learning_rate``.

    Raises:
        ValueError: If ``lr_schedule_type`` is not one of the supported schedule types.
    """

    if lr_schedule_type == ScheduleType.COSINE_RESTARTS:
        schedule_gen = partial(
            tf.keras.optimizers.schedules.CosineDecayRestarts,
            first_decay_steps=steps_per_epoch * decay_epochs,
            t_mul=t_mul,
            m_mul=decay_rate,
            alpha=1e-7,
        )
    elif lr_schedule_type == ScheduleType.EXPONENTIAL:
        # decay_steps is a step count. It used to be `decay_rate * dataset_size`, which multiplied a
        # decay factor by a sample count. Use the same epoch-based period as the cosine branch.
        # With the default staircase=False, the lr is init * decay_rate ** (step / decay_steps).
        schedule_gen = partial(
            tf.keras.optimizers.schedules.ExponentialDecay,
            decay_rate=decay_rate,
            decay_steps=steps_per_epoch * decay_epochs,
        )
    elif lr_schedule_type == ScheduleType.CONSTANT:
        schedule_gen = ConstantSchedule
    else:
        raise ValueError(f"Unexpected schedule type value {lr_schedule_type}")
    return schedule_gen


def get_warmup_scheduler(warmup_strategy: WarmupStrategy, warmup_epochs: int, steps_per_epoch: int):
    """
    Pick the learning-rate schedule factory used during the warmup stage.

    Args:
        warmup_strategy (WarmupStrategy): Which warmup behaviour to build (CONSTANT or GRADUAL).
        warmup_epochs (int): Length of the warmup stage, in epochs. Only read for GRADUAL.
        steps_per_epoch (int): Optimizer steps that make up one epoch. Only read for GRADUAL.

    Returns:
        Callable: ``ConstantSchedule`` for CONSTANT. For GRADUAL, a ``LinearSchedule`` factory whose
        ``decay_steps`` is bound to ``warmup_epochs * steps_per_epoch``.

    Raises:
        ValueError: If ``warmup_strategy`` is neither CONSTANT nor GRADUAL.
    """
    if warmup_strategy == WarmupStrategy.CONSTANT:
        return ConstantSchedule

    if warmup_strategy == WarmupStrategy.GRADUAL:
        # The linear ramp spans the entire warmup stage, expressed in steps.
        ramp_steps = warmup_epochs * steps_per_epoch
        return partial(LinearSchedule, decay_steps=ramp_steps)

    raise ValueError(f"Unexpected warmup strategy value {warmup_strategy}")
