import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.utils.acceleras_definitions import MetaArchType


class L2RelChannelwise:
    """Relative L2 loss with fixed per-output-channel weighting.

    Both tensors are multiplied by a per-channel factor vector (built once from
    ``meta_arch``) and the loss is ``sqrt(mean((t - s)^2) / mean(t^2))`` over the
    weighted tensors, so channels with a larger factor dominate the loss.
    """

    # Defaults used to build the channel factors. ``localization_boost`` is the
    # weight given to the boosted channels; every other channel gets 1.0.
    DEFAULT_SETTINGS = {"localization_boost": 10}

    def __init__(self, meta_arch: MetaArchType, settings=None) -> None:
        """Build the per-channel factor vector for ``meta_arch``.

        Args:
            meta_arch: Architecture type; only ``MetaArchType.yolo`` is supported.
            settings: Optional overrides for ``DEFAULT_SETTINGS``; keys that are
                not given fall back to the defaults.

        Raises:
            NotImplementedError: If ``meta_arch`` is not supported.
        """
        self._channel_factors = self._get_out_channel_factors(meta_arch, settings)

    @tf.function
    def __call__(self, teacher_tens, student_tens):
        """Return the channel-weighted relative L2 error of ``student_tens`` vs ``teacher_tens``.

        The factor vector has shape (255,), so it broadcasts against the last
        axis of the inputs. If the weighted teacher tensor is all zeros the
        relative error is undefined; as in ``l2rel_loss``, 0 is returned in that
        case instead of NaN/inf.
        """
        weighted_teacher = teacher_tens * self._channel_factors
        weighted_student = student_tens * self._channel_factors
        diffnorm = tf.reduce_mean(input_tensor=tf.square(weighted_teacher - weighted_student))
        native_norm = tf.reduce_mean(input_tensor=tf.square(weighted_teacher))
        # tf.cond (rather than tf.where) so only the taken branch is traced for
        # gradients; tf.where would still back-propagate NaN from the 0/0 branch.
        return tf.cond(
            tf.reduce_all(native_norm == 0),
            lambda: native_norm,
            lambda: tf.sqrt(diffnorm / native_norm),
        )

    def _get_out_channel_factors(self, meta_arch, settings=None):
        """Return a float32 tensor of per-output-channel weights for ``meta_arch``.

        Args:
            meta_arch: Architecture type; only ``MetaArchType.yolo`` is supported.
            settings: Optional overrides for ``DEFAULT_SETTINGS``; missing keys
                fall back to the defaults.

        Returns:
            ``tf.float32`` constant of shape (255,): 1.0 everywhere except the
            first 5 channels of each 85-channel group, which get
            ``localization_boost``.

        Raises:
            NotImplementedError: If ``meta_arch`` is not supported.
        """
        # Merge so a partial settings dict does not raise KeyError.
        merged = {**self.DEFAULT_SETTINGS, **(settings if settings is not None else {})}
        if meta_arch != MetaArchType.yolo:
            # Use the argument (there is no ``self.meta_arch``); tolerate non-enum values.
            arch_name = getattr(meta_arch, "value", meta_arch)
            raise NotImplementedError(f"Can't _get_out_channel_factors with meta_arch={arch_name}")
        loc_boost = merged["localization_boost"]
        # NOTE(review): presumably 3 anchors x 85 channels (4 box + 1 objectness
        # + 80 classes) — confirm. The first 5 channels of each group are boosted.
        factors = np.ones(255)
        for group_start in range(0, 255, 85):
            factors[group_start:group_start + 5] = loc_boost
        return tf.constant(factors, dtype=tf.float32)


@tf.function
def l2_loss(teacher_tens, student_tens):
    """Root-mean-square difference between ``teacher_tens`` and ``student_tens``.

    Squares the element-wise difference, averages over all elements and takes
    the square root, yielding a scalar.
    """
    squared_error = tf.square(teacher_tens - student_tens)
    mean_squared_error = tf.reduce_mean(squared_error)
    return tf.sqrt(mean_squared_error)


@tf.function
def l2rel_loss(teacher_tens, student_tens):
    """Relative RMS error of ``student_tens`` with respect to ``teacher_tens``.

    Returns ``sqrt(mean((t - s)^2) / mean(t^2))``. When the teacher's mean
    squared value is exactly 0 the ratio is undefined, so the (zero) teacher
    energy itself is returned instead.
    """
    teacher_energy = tf.reduce_mean(tf.square(teacher_tens))
    error_energy = tf.reduce_mean(tf.square(teacher_tens - student_tens))

    def _degenerate_branch():
        return teacher_energy

    def _relative_branch():
        return tf.sqrt(error_energy / teacher_energy)

    is_degenerate = tf.reduce_all(teacher_energy == 0)
    return tf.cond(is_degenerate, _degenerate_branch, _relative_branch)


@tf.function
def ce_loss(teacher_tens, student_tens, temperature=1.5):
    """Temperature-softened cross-entropy between teacher and student logits.

    The teacher logits are divided by ``temperature`` and converted to soft
    targets with a softmax over the last axis. The student logits are divided
    by the same temperature and passed as logits (``from_logits=True``). Keras
    reduces over the last axis and the per-sample losses are averaged into a
    scalar.

    Args:
        teacher_tens: Teacher logits.
        student_tens: Student logits; expected to match ``teacher_tens`` in shape.
        temperature: Softening temperature. Defaults to 1.5, the previously
            hard-coded value. Each distinct Python value retraces the
            ``tf.function``.

    Returns:
        Scalar mean cross-entropy.
    """
    soft_targets = tf.nn.softmax(teacher_tens / temperature)
    # Named so it no longer shadows the enclosing function ``ce_loss``.
    per_sample_ce = tf.keras.losses.categorical_crossentropy(
        soft_targets,
        student_tens / temperature,
        from_logits=True,
    )
    return tf.reduce_mean(per_sample_ce)
