from collections import OrderedDict
from dataclasses import dataclass
from typing import Callable, Dict, List, Union

from hailo_model_optimization.acceleras.utils.acceleras_definitions import LossType
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasImplementationError
from hailo_model_optimization.algorithms.finetune.losses.loss_factory import loss_factory


@dataclass
class LossInfo:
    """Pairs a teacher-student loss callable with the weight of its contribution."""

    # Called by DeepDistillLoss.compare as loss_func(teacher_tensor, student_tensor).
    loss_func: Callable
    # Multiplier applied to the layer's summed loss when accumulating the total loss.
    loss_factor: float


class DeepDistillLoss:
    """
    Handles construction of the final segment of the training graph ( in compare() )
    computing "reconstruction loss" from tensors of teacher and student,
    according to the loss configuration supplied in constructor.

    Args:
      loss_layers (list of str): Names of layers to be used for teacher-student losses.
          Names to be given in Hailo HN notation, s.a. *conv20*, *fc1*, etc.
          Required: passing None raises AccelerasImplementationError.

      loss_types (list of :class:`LossType`, str, or callable; one entry per loss layer):
          The teacher-student bivariate loss to apply on the tensors of the respective layer
          in ``loss_layers``. A callable entry is used as-is and is invoked as
          ``loss(teacher_tensor, student_tensor)``; this is a generic way to plug in any
          bivariate loss. Any other entry is converted with ``LossType(entry)`` and resolved
          through ``loss_factory``.
          Default: ``default_loss`` for each of the layers.

      loss_factors (list of float; one entry per loss layer):
          Weighting factors of the per-layer loss components, used when summing them
          into the total loss. Default: 1 for all components.

      default_loss: The loss type to use for every layer when ``loss_types`` is not given
          (default "l2rel").

      **loss_kwargs: Extra settings handed to ``loss_factory`` as a single dict.
          The key "meta_arch" must be present (and not None) when a layer uses
          ``LossType.L2REL_CHW``.

    Raises:
      AccelerasImplementationError: if ``loss_layers`` is None.
      ValueError: if ``loss_types`` or ``loss_factors`` has fewer entries than ``loss_layers``,
          or if ``LossType.L2REL_CHW`` is requested without "meta_arch".

    Note:
      Entries of ``loss_types`` / ``loss_factors`` beyond the number of loss layers are ignored.

    """

    def __init__(
        self,
        loss_layers: List[str],
        loss_types: List[Union[str, LossType, Callable]] = None,
        loss_factors: List[float] = None,
        default_loss: Union[str, LossType] = "l2rel",
        **loss_kwargs,
    ):
        if loss_layers is None:
            raise AccelerasImplementationError("loss_lnames is required for Acclereas finetune")

        # Materialize once so the length check below works for any iterable of names.
        layer_names = list(loss_layers)

        # Fail early with a clear message instead of an IndexError half-way through the loop.
        self._require_entry_per_layer("loss_types", loss_types, len(layer_names))
        self._require_entry_per_layer("loss_factors", loss_factors, len(layer_names))

        # Loop-invariant: the same kwargs apply to every layer.
        meta_arch = loss_kwargs.get("meta_arch")

        self.losses_info: Dict[str, LossInfo] = dict()
        for index, loss_layer in enumerate(layer_names):
            loss_type = default_loss if loss_types is None else loss_types[index]
            if callable(loss_type):
                # User-supplied bivariate loss; used without going through the factory.
                loss_func = loss_type
            else:
                loss_type = LossType(loss_type)
                if loss_type == LossType.L2REL_CHW and meta_arch is None:
                    raise ValueError("Can't do channelwise-weighted loss without meta_arch")

                loss_func = loss_factory(loss_type, loss_kwargs)

            loss_factor = 1 if loss_factors is None else loss_factors[index]
            self.losses_info[loss_layer] = LossInfo(loss_func, loss_factor)

    @staticmethod
    def _require_entry_per_layer(arg_name, values, num_layers):
        """Raise ValueError when a per-layer list has fewer entries than there are loss layers."""
        if values is not None and len(values) < num_layers:
            raise ValueError(
                f"{arg_name} has {len(values)} entries but {num_layers} loss layers were given; "
                "one entry per loss layer is required"
            )

    def compare(self, teacher, student):
        """
        Build and return the distillation loss from tensors of teacher and student models
        (per the original note, invoked by QftRunner).

        Args:
          teacher: object exposing ``interlayer_tensors[layer_name]`` and
              ``layers[layer_name].num_outputs`` for every configured loss layer.
          student: object exposing ``interlayer_tensors[layer_name]`` for every
              configured loss layer.

        Returns:
          tuple ``(total_loss, losses)`` where ``losses`` is an OrderedDict mapping each loss
          layer name to its unweighted loss (summed over the layer's outputs), and
          ``total_loss`` is the sum of those losses, each multiplied by its loss factor.

        TODO type hints.
        """
        losses = OrderedDict()
        total_loss = 0

        for layer_name, loss_info in self.losses_info.items():
            teacher_tens = teacher.interlayer_tensors[layer_name]
            student_tens = student.interlayer_tensors[layer_name]
            if teacher.layers[layer_name].num_outputs == 1:
                # Single-output layers store a bare tensor; wrap it so the loop below is uniform.
                teacher_tens = [teacher_tens]
                student_tens = [student_tens]

            layer_loss = 0
            for t_tensor, s_tensor in zip(teacher_tens, student_tens):
                layer_loss += loss_info.loss_func(t_tensor, s_tensor)

            losses[layer_name] = layer_loss
            total_loss += layer_loss * loss_info.loss_factor

        return total_loss, losses
