import tensorflow as tf


class AdaRoundRegularizer:
    """Rounding regularizer for AdaRound with a linearly decaying temperature.

    The regularization term is ``weight * sum(1 - |2 * (v - 0.5)| ** beta)``
    over the soft-rounding values ``v``. The temperature ``beta`` follows a
    schedule expressed in optimization steps:

    * steps ``< loss_start`` (``max_count * warmup``): the loss is 0;
    * steps ``< start_decay``: ``beta == b_range[0]``;
    * from ``start_decay`` to ``max_count``: ``beta`` decays linearly to
      ``b_range[1]`` and stays at ``b_range[1]`` afterwards.

    ``start_decay = floor((warmup + (1 - warmup) * decay_start) * max_count)``.
    """

    def __init__(
        self,
        max_count: int = 2000,
        b_range: tuple = (20, 2),
        decay_start: float = 0.0,
        warmup: float = 0.0,
        weight: float = 0.01,
    ):
        """
        :param max_count: number of optimization steps the schedule spans
        :param b_range: ``(start_beta, end_beta)`` of the temperature schedule
        :param decay_start: fraction of the post-warmup steps after which beta starts to decay
        :param warmup: fraction of ``max_count`` during which the regularizer returns 0
        :param weight: scale factor applied to the regularization term
        """
        self.weight = weight  # regularizer factor
        # Keep the relative schedule parameters so the absolute step thresholds
        # can be recomputed whenever ``max_count`` is changed.
        self._warmup = warmup
        self._decay_start = decay_start
        loss_start, start_decay = self._schedule_thresholds(max_count)
        self.loss_start = tf.Variable(loss_start, trainable=False, dtype=tf.float32)
        self._max_count = tf.Variable(max_count, trainable=False, dtype=tf.float32)
        self.start_decay = tf.Variable(start_decay, trainable=False, dtype=tf.float32)
        self.b_range = b_range

    def _schedule_thresholds(self, max_count):
        """Return the ``(loss_start, start_decay)`` step thresholds for ``max_count``."""
        rel_start_decay = self._warmup + (1 - self._warmup) * self._decay_start
        loss_start = max_count * self._warmup
        start_decay = tf.math.floor(rel_start_decay * max_count)
        return loss_start, start_decay

    @property
    def max_count(self):
        """Total number of steps of the schedule (a float32 ``tf.Variable``)."""
        return self._max_count

    @max_count.setter
    def max_count(self, new_max_count):
        """Update ``max_count`` and rescale the warmup/decay thresholds with it."""
        self._max_count.assign(new_max_count)
        # Without this, loss_start/start_decay would still refer to the old
        # max_count, and the decay window could become empty or negative.
        loss_start, start_decay = self._schedule_thresholds(self._max_count)
        self.loss_start.assign(loss_start)
        self.start_decay.assign(start_decay)

    def get_beta(self, t):
        """
        Return the temperature ``beta`` for step ``t``.

        Uses the linear decay schedule by default.
        """
        return self._liner_temp_decay(t)

    # NOTE: "liner" is a historical typo for "linear"; the name is kept for compatibility.
    @tf.function
    def _liner_temp_decay(self, time_step):
        """Linearly decay beta from ``b_range[0]`` to ``b_range[1]`` over ``[start_decay, max_count]``."""
        time_step = tf.cast(time_step, self.start_decay.dtype)
        start_b, end_b = self.b_range
        if time_step < self.start_decay:
            return tf.cast(start_b, tf.float32)
        else:
            decay_len = self._max_count - self.start_decay
            # An empty (or negative) decay window means the decay is already
            # finished; avoid 0/0 -> NaN when start_decay == max_count.
            rel_time_step = tf.where(
                decay_len > 0,
                tf.math.divide_no_nan(time_step - self.start_decay, decay_len),
                tf.constant(1.0, dtype=decay_len.dtype),
            )
            return end_b + (start_b - end_b) * tf.maximum(0.0, (1 - rel_time_step))

    @tf.function
    def __call__(self, round_vals, counter):
        """
        Compute the AdaRound rounding regularization term.

        The term pushes every soft-rounding value towards 0 or 1. The output
        reconstruction loss is not included here.

        :param round_vals: soft-rounding values (a float tensor of any shape)
        :param counter: current optimization step
        :return: scalar regularization loss in ``round_vals.dtype`` (0 during warmup)
        """
        dtype = round_vals.dtype
        # The schedule thresholds are float32 variables. Compare in that dtype
        # so float16/float64 rounding values do not cause a dtype mismatch.
        step = tf.cast(counter, self.loss_start.dtype)
        if step < self.loss_start:
            regularizer_loss = tf.constant(0, dtype=dtype)
        else:
            beta = tf.cast(self.get_beta(step), dtype)
            expr = tf.abs(round_vals - 0.5)
            expr = 1 - tf.pow(expr * 2, beta)
            expr = tf.reduce_sum(expr)
            regularizer_loss = self.weight * expr
        return regularizer_loss
