import tensorflow as tf


class AdaRoundRegularizer:
    """Rounding regularizer with a linearly annealed exponent.

    Calling the instance returns

        ``weight * sum(1 - |2 * (round_vals - 0.5)| ** b)``

    where the exponent ``b`` is produced by :meth:`liner_temp_decay` from the
    current step counter. The schedule has three phases, all expressed in
    units of the step counter:

    * ``counter < max_count * warmup``: the loss is a constant zero.
    * ``counter < start_decay``: ``b`` stays at ``b_range[0]``.
    * ``start_decay <= counter``: ``b`` moves linearly from ``b_range[0]``
      towards ``b_range[1]``, reaching it at ``max_count`` and staying there.

    ``start_decay`` is ``max_count * (warmup + (1 - warmup) * decay_start)``,
    i.e. ``decay_start`` is a fraction of the post-warmup part of the run.

    Args:
        max_count: Step count at which the exponent reaches ``b_range[1]``.
        b_range: ``(start_b, end_b)`` pair for the exponent schedule.
        decay_start: Fraction of the post-warmup steps during which the
            exponent is held at ``start_b`` before it begins to decay.
        warmup: Fraction of ``max_count`` during which the loss is zero.
        weight: Scalar multiplier applied to the summed penalty.

    NOTE(review): the class name suggests this is the AdaRound rounding loss,
    in which case ``round_vals`` would be soft rounding values in ``[0, 1]``
    -- confirm against the caller.
    """

    # Lower bound for the length of the decay window. When ``start_decay``
    # coincides with ``max_count`` the window has zero length and dividing by
    # it yields NaN; clamping makes the schedule a clean step from ``start_b``
    # to ``end_b`` instead.
    _MIN_DECAY_SPAN = 1e-12

    def __init__(
        self,
        max_count: int = 2000,
        b_range: tuple = (10, 2),
        decay_start: float = 0.0,
        warmup: float = 0.0,
        weight: float = 0.01,
    ):
        self.weight = weight
        # Step before which __call__ returns zero.
        self.loss_start = tf.constant(max_count * warmup, dtype=tf.float32)
        self.max_count = tf.constant(max_count, dtype=tf.float32)
        # decay_start is relative to the steps that remain after warmup.
        rel_start_decay = warmup + (1 - warmup) * decay_start
        self.start_decay = tf.constant(rel_start_decay * max_count, dtype=tf.float32)
        self.b_range = tf.constant(b_range, dtype=tf.float32)

    def liner_temp_decay(self, t):
        """Return the exponent ``b`` for step ``t`` as a float32 scalar.

        The method name is a misspelling of "linear" that is kept unchanged
        so existing callers continue to work.

        Args:
            t: Current step; cast to float32 before use.

        Returns:
            ``b_range[0]`` while ``t < start_decay``; afterwards a value
            interpolated linearly towards ``b_range[1]``, which is reached at
            ``max_count`` and held for all later steps.
        """
        t = tf.cast(t, dtype=tf.float32)
        start_b, end_b = self.b_range[0], self.b_range[1]
        # Guard against a zero-length decay window (see _MIN_DECAY_SPAN).
        span = tf.maximum(self.max_count - self.start_decay, self._MIN_DECAY_SPAN)

        def early():
            return start_b

        def decay():
            rel_t = (t - self.start_decay) / span
            # The clamp at 0 keeps b at end_b once t passes max_count.
            return end_b + (start_b - end_b) * tf.maximum(0.0, 1.0 - rel_t)

        return tf.cond(t < self.start_decay, early, decay)

    def __call__(self, round_vals, counter):
        """Compute the regularization loss for the current step.

        Args:
            round_vals: Tensor of values to regularize; the result has this
                tensor's dtype.
            counter: Current step count; cast to float32 for the schedule.

        Returns:
            A scalar: zero while ``counter < loss_start``, otherwise
            ``weight * sum(1 - |2 * (round_vals - 0.5)| ** b)``.
        """
        counter = tf.cast(counter, dtype=tf.float32)
        # The schedule is computed in float32; cast so tf.pow sees matching
        # dtypes and both tf.cond branches return round_vals.dtype.
        b = tf.cast(self.liner_temp_decay(counter), dtype=round_vals.dtype)

        def before_loss_start():
            return tf.constant(0.0, dtype=round_vals.dtype)

        def compute_loss():
            expr = tf.abs(round_vals - 0.5)
            expr = 1 - tf.pow(expr * 2.0, b)
            expr = tf.reduce_sum(expr)
            return self.weight * expr

        regularizer_loss = tf.cond(counter < self.loss_start, before_loss_start, compute_loss)
        return regularizer_loss
