"""
Adapted from Keras Distillation example:
   https://keras.io/examples/vision/knowledge_distillation/
BY:
   Author: [Kenneth Borup](https://twitter.com/Kennethborup)
"""

import keras
import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.utils.acceleras_definitions import QFTWriterMode

"""
The custom `Distiller()` class, overrides the `Model` methods `train_step`, `test_step`,
and `compile()`. In order to use the distiller, we need:
- A trained teacher model
- A student model to train
- A student loss function on the difference between student predictions and ground-truth

In the `train_step` method, we perform a forward pass of both the teacher and student,
calculate the distillation loss and mix it with the "normal" student loss (using ground-truth labels),
  using weighting of gt_loss_weight for the latter and 1-gt_loss_weight for the former.
  Only the student weights are updated,
and therefore we only calculate the gradients for the student weights.
In the `test_step` method, we evaluate the student model on the provided dataset.
"""


class Distiller(keras.Model):
    """Knowledge-distillation trainer: fits a student HailoModel to a teacher HailoModel.

    The teacher is only run forward. Gradients are computed and applied solely for the
    student variables selected in ``build`` (those not matched by ``var_freeze_cond``).
    """

    def __init__(self, student: HailoModel, teacher: HailoModel, train_scales: bool = False):
        super(Distiller, self).__init__()
        self.teacher = teacher
        self.student = student
        # When True, ``student.enforce_encoding`` is called before each student forward pass.
        self.train_scales = train_scales
        # Summary writer used by the ``_log_*`` helpers (used via ``as_default(step=...)``,
        # presumably a ``tf.summary`` writer). It must be assigned externally before training
        # with a writer mode other than ``QFTWriterMode.disabled``.
        self.writer = None

    def compile(
        self,
        optimizer,
        metrics,
        supervised_loss_fn,
        distillation_loss_fn,
        loss_layers,
        supervised_proportion=0.0,
        var_freeze_cond=lambda s: False,
        bias_boost=3,
        encoding_boost=1,
        run_eagerly=False,
        stop_graident_at_loss_layers=False,
        wraparound_factor=0.1,
        train_weights=True,
        writer_mode=QFTWriterMode.disabled,
    ):
        """
        Configure the distiller and compile the teacher and student models.

        Make sure to call this function after the student model was built, and before
        ``Distiller.build``, which reads ``var_freeze_cond`` and ``bias_boost`` set here.

        Args:
            optimizer: Keras optimizer for the student weights.
            metrics: Keras metrics, updated with (ground truth, student predictions).
            supervised_loss_fn: Loss penalizing student predictions vs. ground truth.
                A 2-arg callable ``fn(y_true, y_pred)``. It is only called when
                ``supervised_proportion > 0``.
            distillation_loss_fn: Loss penalizing student vs. teacher. A 2-arg callable
                ``fn(teacher, student)`` returning ``(total_loss, components_dict)``; it reads
                the output/intermediate tensors saved on both models.
            loss_layers: Passed as ``save_interlayer`` to both the teacher and student ``compile``.
            supervised_proportion: Weight of ``supervised_loss_fn``; the distillation loss is
                weighted by ``1 - supervised_proportion``.
            var_freeze_cond: Predicate on a variable name; matching student variables are not trained.
            bias_boost: Gradient multiplier for trained variables whose name contains ``"bias"``.
            encoding_boost: Stored on the instance; not used by this class.
            run_eagerly: Forwarded to the teacher, student and distiller ``compile``.
            stop_graident_at_loss_layers: If True, ``loss_layers`` is also passed to the student
                as ``stop_gradient_layers``.
            wraparound_factor: Weight of the wraparound loss (the sum of ``self.losses``).
            train_weights: Accepted for API compatibility; not used by this method.
            writer_mode: ``QFTWriterMode`` selecting how much is logged to ``self.writer`` per step.
        """
        stop_gradient_layers = loss_layers if stop_graident_at_loss_layers else None

        # QFT currently assumes the models are in graph mode.
        jit_compile = False  # dont allow XLA compilation for distiller
        self.teacher.compile(run_eagerly=run_eagerly, save_interlayer=loss_layers, jit_compile=jit_compile)
        self.student.compile(
            run_eagerly=run_eagerly,
            save_interlayer=loss_layers,
            stop_gradient_layers=stop_gradient_layers,
            jit_compile=jit_compile,
        )
        super(Distiller, self).compile(
            optimizer=optimizer, metrics=metrics, run_eagerly=run_eagerly, jit_compile=jit_compile
        )
        self.supervised_loss_fn = supervised_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.supervised_proportion = supervised_proportion
        self.var_freeze_cond = var_freeze_cond
        self.bias_boost = bias_boost
        self.encoding_boost = encoding_boost
        self.given_metrics = metrics
        self.wraparound_factor = wraparound_factor
        self._writer_mode = writer_mode

    def build(self, input_shape):
        """Build the teacher/student if needed and select the student variables to train.

        Requires ``compile`` to have been called first (uses ``var_freeze_cond`` and ``bias_boost``).
        """
        if not self.teacher.built:
            self.teacher.build(input_shape)
        if not self.student.built:
            self.student.build(input_shape)

        # ``_value`` presumably unwraps each Keras variable to its backend (tf) variable; these
        # are the objects watched by the GradientTape and updated by the optimizer in train_step.
        self.trainable_vars = [v._value for v in self.student.trainable_variables if not self.var_freeze_cond(v.name)]
        # Per-variable gradient multipliers: ``bias_boost`` for bias variables, 1 otherwise.
        self.factors = [(self.bias_boost if "bias" in v.name else 1) for v in self.trainable_vars]

    def train_step(self, data):
        """Run one distillation step on ``data = (x, y)``; return the losses/metrics to report."""
        x, y = data
        # Teacher forward pass in inference mode, outside the tape: its saved tensors are
        # consumed by ``distillation_loss_fn`` but never differentiated.
        _ = self.teacher(x, training=False)

        with tf.GradientTape(persistent=False, watch_accessed_variables=False) as tape:
            # Only the selected student variables are watched.
            tape.watch(self.trainable_vars)
            if self.train_scales:
                self.student.enforce_encoding(training=True)
            student_predictions = self.student(x, training=True)

            supervised_loss = self.supervised_loss_fn(y, student_predictions) if self.supervised_proportion > 0 else 0
            distillation_loss, components_dict = self.distillation_loss_fn(self.teacher, self.student)
            wraparound_loss = tf.reduce_sum(self.losses)
            loss = (
                self.supervised_proportion * supervised_loss
                + (1 - self.supervised_proportion) * distillation_loss
                + wraparound_loss * self.wraparound_factor
            )
        # Variables the loss does not reach get zero gradients (not None), so the scaling
        # and apply_gradients below always receive one gradient per variable.
        gradients = tape.gradient(loss, self.trainable_vars, unconnected_gradients=tf.UnconnectedGradients.ZERO)
        factor_gradients = [grad * factor for grad, factor in zip(gradients, self.factors)]
        self.optimizer.apply_gradients(zip(factor_gradients, self.trainable_vars))

        if self._writer_mode in [QFTWriterMode.basic, QFTWriterMode.advanced, QFTWriterMode.expert]:
            self._log_loss(components_dict, distillation_loss, wraparound_loss, supervised_loss)
            self._log_learning_rate(self.optimizer)
        if self._writer_mode in [QFTWriterMode.advanced, QFTWriterMode.expert]:
            self._log_gradients_norm(gradients)
        if self._writer_mode in [QFTWriterMode.expert]:
            self._log_variables(self.trainable_vars)

        # Update the metrics configured in `compile()`.
        if self.given_metrics is not None:
            self.compiled_metrics.update_state(y, student_predictions)

        results = {m.name: m.result() for m in self.given_metrics} if self.given_metrics else {}
        results.update({"total_distill_loss": distillation_loss})
        if self.wraparound_factor > 0:
            results.update({"wraparound_loss": wraparound_loss})
        if self.supervised_proportion > 0:
            results.update({"supervised_loss": supervised_loss})
        results.update({"_distill_loss_" + ss: ll for ss, ll in components_dict.items()})
        return results

    def _encoding_successor_name(self, var):
        """Return the name of ``var``'s first successor in the student's encoding flow graph.

        Used for ``source_*`` variables, which come from the encoding graph. The trailing
        ``:<index>`` suffix of the variable name is stripped before the graph lookup.
        """
        return next(self.student.model_encoding.flow.successors(var.name.rsplit(":", 1)[0]))

    def _log_gradients_norm(self, gradients):
        """Write the Euclidean norm of each gradient as a scalar summary named after its variable."""
        with self.writer.as_default(step=self.optimizer.iterations):
            for grad, tns in zip(gradients, self.trainable_vars):
                grad_norm = tf.norm(grad)
                name = self._encoding_successor_name(tns) if "source_" in tns.name else tns.name
                tf.summary.scalar(name, grad_norm)

    def _log_loss(self, distill_per_layer, total_distill, wraparound, supervised):
        """Write the total losses and each per-layer distillation component as scalar summaries."""
        with self.writer.as_default(step=self.optimizer.iterations):
            tf.summary.scalar("loss/distillation", total_distill)
            tf.summary.scalar("loss/wraparound", wraparound)
            tf.summary.scalar("loss/supervised", supervised)
            for lname, loss in distill_per_layer.items():
                tf.summary.scalar(f"loss/distillation_{lname}", loss)

    def _log_learning_rate(self, optimizer):
        """Write the optimizer learning rate(s); a list is logged entry by entry under its index."""
        with self.writer.as_default(step=self.optimizer.iterations):
            lr = optimizer.learning_rate
            if isinstance(lr, list):
                for index, lr_ in enumerate(lr):
                    tf.summary.scalar(f"learning_rate/{index}", lr_)
            else:
                tf.summary.scalar("learning_rate/0", lr)

    def _log_variables(self, vars):
        """Write summaries of the trained variables.

        Kernels and biases are logged as their offset from the teacher's native value, divided
        by the student's kernel / bias-add output scale. Single-element values are logged as a
        scalar norm, everything else as a histogram. Tags get a ``_var`` suffix on their first
        path segment.
        """
        with self.writer.as_default(step=self.optimizer.iterations):
            for var in vars:
                if "source_" in var.name:  # from encoding graph
                    name = self._encoding_successor_name(var)
                else:
                    name = var.name.rsplit(":", 1)[0]
                    name_segments = name.split("/")
                    var_type = name_segments[-1]
                    # Assumes segments [-4:-2] of the variable path form the layer key in
                    # ``teacher.layers`` / ``student.layers`` — TODO confirm naming scheme.
                    lname = "/".join(name_segments[-4:-2])
                    if var_type == "kernel":
                        native_var = self.teacher.layers[lname].kernel
                        kernel_scale = self.student.layers[lname].conv_op.kernel_scale
                        var = (var - native_var) / kernel_scale
                    if var_type == "bias":
                        native_var = self.teacher.layers[lname].bias
                        bias_scale = self.student.layers[lname].bias_add_op.output_scale
                        var = (var - native_var) / bias_scale
                r1 = name.split("/")
                name = "/".join([r1[0] + "_var", *r1[1:]])
                if np.prod(var.shape) == 1:
                    var_norm = tf.norm(var)
                    tf.summary.scalar(name, var_norm)
                else:
                    tf.summary.histogram(name, var)

    def test_step(self, data):
        """Evaluate the student on ``data = (x, y)`` and return the metric results."""
        x, y = data

        # NOTE: we need infer_encodings for "inter-epoch validation" context,
        #      since had no chance to eagerize after train epoch (it's all inside fit())
        # NOTE(review): train_step calls enforce_encoding(training=True) while this passes
        #      train_scales=True; confirm which keyword HailoModel.enforce_encoding expects.
        if self.train_scales:
            self.student.enforce_encoding(train_scales=True)
        y_prediction = self.student(x)

        # As in train_step, the supervised loss is only evaluated when it is weighted in, so
        # pure distillation (supervised_proportion == 0) never calls supervised_loss_fn.
        student_loss = self.supervised_loss_fn(y, y_prediction) if self.supervised_proportion > 0 else None

        # Same guard as train_step: only update compiled metrics when metrics were given.
        if self.given_metrics is not None:
            self.compiled_metrics.update_state(y, y_prediction)

        results = {m.name: m.result() for m in self.metrics}
        if self.supervised_proportion > 0:
            results.update({"student_loss": student_loss})
        return results
