from hailo_model_optimization.acceleras.utils.acceleras_definitions import LossType
from hailo_model_optimization.algorithms.finetune.losses.loss_functions import (
    L2RelChannelwise,
    ce_loss,
    l2_loss,
    l2rel_loss,
)


def resolve_loss_type(loss_type):
    """Normalize ``loss_type`` into a ``LossType`` value.

    Args:
        loss_type: Either an existing ``LossType`` member or any raw value
            accepted by the ``LossType`` constructor (presumably the enum's
            underlying value — confirm against ``acceleras_definitions``).

    Returns:
        The result of ``LossType(loss_type)``.
    """
    resolved = LossType(loss_type)
    return resolved


def loss_factory(loss_type, loss_kwargs=None):
    """Return the loss callable that corresponds to ``loss_type``.

    Args:
        loss_type: A ``LossType`` member or any raw value accepted by the
            ``LossType`` constructor; it is normalized via ``resolve_loss_type``.
        loss_kwargs: Optional mapping of construction-time arguments. Only
            ``LossType.L2REL_CHW`` reads from it (the ``"meta_arch"`` key);
            the other loss types ignore it.

    Returns:
        ``ce_loss``, ``l2_loss`` or ``l2rel_loss`` for the corresponding loss
        types, or a new ``L2RelChannelwise`` instance for ``L2REL_CHW``.

    Raises:
        KeyError: if ``LossType.L2REL_CHW`` is requested and ``loss_kwargs``
            has no ``"meta_arch"`` entry.
        ValueError: if the resolved loss type has no implementation here.
            Conversion of an unknown raw value may also raise from the
            ``LossType`` constructor (ValueError if it is a standard Enum —
            NOTE(review): confirm ``LossType`` is an ``enum.Enum``).
    """
    if loss_kwargs is None:
        loss_kwargs = {}
    # Reuse the module's single normalization point instead of duplicating it.
    loss_type = resolve_loss_type(loss_type)

    # Losses that are plain functions and need no configuration.
    stateless_losses = {
        LossType.CROSS_ENTROPY: ce_loss,
        LossType.L2: l2_loss,
        LossType.L2REL: l2rel_loss,
    }
    if loss_type in stateless_losses:
        return stateless_losses[loss_type]

    if loss_type == LossType.L2REL_CHW:
        # The channel-wise variant must be constructed with the model's meta
        # architecture; report which loss needed it rather than a bare key.
        try:
            meta_arch = loss_kwargs["meta_arch"]
        except KeyError:
            raise KeyError(f"Loss type {loss_type} requires 'meta_arch' in loss_kwargs") from None
        return L2RelChannelwise(meta_arch=meta_arch)

    # Reachable when LossType defines members this factory does not implement.
    raise ValueError(f"Invalid loss type: {loss_type}")
