import time

import numpy as np
import tensorflow as tf
from tqdm import tqdm
from verboselogs import VERBOSE


def l2_loss(pred, tgt):
    """Squared L_2 loss between a prediction and its target.

    The element-wise squared difference is summed along the last axis and the
    resulting per-row sums are averaged over all remaining axes.

    Args:
        pred: predicted tensor
        tgt: target tensor, broadcast-compatible with ``pred``

    Returns:
        Scalar tensor holding the mean of the summed squared errors.
    """
    squared_error = tf.pow(pred - tgt, 2)
    per_sample = tf.reduce_sum(squared_error, axis=-1)
    return tf.reduce_mean(per_sample)


class AdaroundV1Callbacks(tf.keras.callbacks.Callback):
    """
    Custom callbacks to show adaround progress and log the status at selected times

    The progress bar only exists when ``verbose > 0``. The log messages are sent
    through ``self.model.logger`` independently of the progress bar.

    Args:
        iters: Max iterations
        warmup_iters: Warmup iterations
        verbose: verbosity level, 0 / 1 to toggle progress bar

    """

    # Number of optimizer steps between two refreshes of the progress bar postfix.
    UPDATE_INTERVAL = 500

    def __init__(self, iters, warmup_iters, verbose=1) -> None:
        super().__init__()
        self.iters = int(iters)
        self.warmup_iters = np.ceil(
            warmup_iters,
        )  # ceil to make sure the initial round loss will be printed
        self.verbose = verbose
        self.pbar = None
        # Wall-clock start of the training; used for timing when there is no progress bar.
        self._train_start = None

    def on_train_begin(self, logs=None):
        """Start the timer and, when verbose, initialize the progress bar"""
        self._train_start = time.time()
        if self.verbose > 0:
            self.pbar = tqdm(total=self.iters, leave=False, dynamic_ncols=True, unit="batches")

    def on_train_batch_end(self, batch, logs=None):
        """Updates the progress bar and print log messages"""
        counter = self.model.optimizer.iterations

        # The bar is None when verbose == 0, so every access to it must be guarded.
        if self.pbar is not None:
            self.pbar.update(1)
            if counter % self.UPDATE_INTERVAL == 0:
                postfix = {
                    "round_loss": f"{self._get_round_loss():.4f}",
                    "b": f"{self._get_current_b(counter):.3f}",
                    "l2_loss": f"{logs['l2_loss']:.4f}",
                }
                self.pbar.set_postfix(postfix)

        if counter == self.warmup_iters and self.model.logger.isEnabledFor(VERBOSE):
            msg = self._get_status_msg(logs)
            self.model.logger.verbose(f"Post warmup: {msg}")

        # Checked independently of the warmup branch so that the final message is
        # still emitted when warmup_iters equals iters.
        if counter == self.iters:
            msg = self._get_status_msg(logs)
            if self.pbar is not None:
                train_time = self.pbar.format_dict["elapsed"]
            else:
                train_time = time.time() - self._train_start
            time_per_step = train_time * 1000 / self.iters
            formatted_time = tqdm.format_interval(train_time)
            self.model.logger.debug(f"Train finished: {msg}, time: {formatted_time}, {time_per_step:.2f} ms/step")

        return super().on_train_batch_end(batch, logs)

    def on_train_end(self, logs=None):
        """Finalize the progress bar"""
        if self.pbar is not None:
            self.pbar.close()

    def _get_round_loss(self):
        """Get the round loss value from the train model"""
        ar_br = self.model.ada_round_br
        round_val = ar_br.get_fractional_value()
        round_loss = self.model.round_regularizer(
            round_val, tf.cast(self.model.optimizer.iterations, dtype=round_val.dtype)
        )
        return round_loss.numpy()

    def _get_current_b(self, counter):
        """Get the b value from the regularizer"""
        annealing_b = self.model.round_regularizer.liner_temp_decay(
            tf.cast(counter, tf.float32),
        )
        return annealing_b.numpy()

    def _get_status_msg(self, logs):
        """Get status message for the log messages"""
        round_loss = self._get_round_loss()
        msg = [
            f"round_loss: {round_loss:.4f}",
            f"l2_loss: {logs['l2_loss']:.4f}",
        ]
        return ", ".join(msg)


class AdaroundV2Callbacks(tf.keras.callbacks.Callback):
    """
    Custom callbacks to show adaround progress with a time-throttled progress bar

    Args:
        iters: Max iterations
        wamup_iters: Warmup iterations; rounded up and stored, not otherwise used
            inside this class
        ticks_per_sec: Maximum number of progress bar refreshes per second
        log_samples: Number of evenly spaced batch counts, the last one being
            ``iters``, at which ``_record_history`` stores the metric values

    """

    # Only these metrics are shown next to the progress bar.
    _PBAR_METRICS = frozenset({"round_loss", "l2_loss", "annealing_b"})

    def __init__(self, iters, wamup_iters, ticks_per_sec=5, log_samples=1) -> None:
        super().__init__()
        self.iters = int(iters)
        self.pbar = None
        self._last_tick = None
        self._count_batch = 0  # batches finished since the last bar refresh
        self._update_interval = 1 / ticks_per_sec  # minimal seconds between refreshes
        self._warmup_iters = np.ceil(wamup_iters)
        self._logged_batches = np.round((self.iters / log_samples) * (np.arange(log_samples) + 1))
        self.history = dict()

    def on_train_begin(self, logs=None):
        """Initialize the progress bar"""
        self.pbar = tqdm(total=self.iters, leave=False, dynamic_ncols=True, unit="batches", desc="Training")
        self._count_batch = 0  # do not carry pending batches over from a previous run
        self._last_tick = time.time()

    def on_train_batch_end(self, batch, logs=None):
        """Updates the progress bar and print log messages"""
        self._update_pbar()

    def _update_pbar(self):
        """Count one finished batch and refresh the bar if the tick interval elapsed"""
        self._count_batch += 1
        current_tick = time.time()
        if current_tick - self._last_tick <= self._update_interval:
            return

        self._last_tick = current_tick
        self.pbar.update(self._count_batch)
        self._count_batch = 0
        metrics_str = [
            f"{metric.name}: {metric.result().numpy():0.3e}"
            for metric in self.model.metrics
            if metric.name in self._PBAR_METRICS
        ]
        self.pbar.set_postfix_str(" - ".join(metrics_str))

    def _record_history(self, counter):
        """Append the current value of every model metric to ``history`` at the logged batches"""
        if counter in self._logged_batches:
            for metric in self.model.metrics:
                self.history.setdefault(metric.name, []).append(metric.result().numpy())

    def on_train_end(self, logs=None):
        """Finalize the progress bar."""
        if self.pbar is None:
            return
        # Batches counted since the last throttled refresh were never pushed to the bar.
        if self._count_batch:
            self.pbar.update(self._count_batch)
            self._count_batch = 0
        self.pbar.close()


class DirectMetric(tf.keras.metrics.Metric):
    """Metric that reports the value most recently passed to ``update_state``.

    Nothing is accumulated or averaged: each update overwrites the stored
    weight, and ``result`` returns that weight as is.
    """

    metric_value: tf.Tensor

    def __init__(self, name, **kwargs):
        """Create the metric and its backing weight, initialized with zeros.

        Args:
            name: name of the metric
            **kwargs: forwarded to ``tf.keras.metrics.Metric``
        """
        super().__init__(name=name, **kwargs)
        self.metric_value = self.add_weight(initializer="zeros", name="metric")

    def update_state(self, val):
        """Overwrite the stored value with ``val``."""
        self.metric_value.assign(val)

    def result(self):
        """Return the stored value."""
        return self.metric_value
