from hailo_model_optimization.acceleras.atomic_ops.normalization_op import NormalizationOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_single_atomic import BaseHailoSingleAtomic
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    BiasMode,
    EquivClassification,
    LayerHandlerType,
    LayerType,
    OptimizationTarget,
    PrecisionMode,
)


class HailoLayerNormalization(BaseHailoSingleAtomic):
    """
    native layer normalization layer - see NormalizationOp for more details

    The layer wraps a single ``NormalizationOp`` atomic op (built in
    ``__init__``). The op's ``rms_norm``, ``epsilon``, ``groups`` and
    ``reduce_axes`` are exposed as read-only properties that forward to
    ``self.atomic_op``.
    """

    SUPPORTED_PRECISION_MODE = {
        PrecisionMode.a16_w16,
        PrecisionMode.a16_w16_a8,
        PrecisionMode.a16_w16_a16,
        PrecisionMode.a8_w8,
        PrecisionMode.a8_w8_a8,
        PrecisionMode.a8_w8_a16,
    }
    SUPPORTED_BIAS_MODE = {
        BiasMode.double_scale_initialization,
        BiasMode.double_scale_decomposition,
        BiasMode.single_scale_decomposition,
    }
    SUPPORTED_QUANTIZATION_GROUPS = False
    _hn_type = LayerType.LAYER_NORM

    # Suffix of the wrapped atomic op's name: "<layer name>/normalization_op".
    OP_NAME = "normalization_op"

    def __init__(self, name: str, rms_norm=False, reduce_axes=(3,), groups=1, logger=None, **kwargs):
        """
        Build the layer around a freshly created ``NormalizationOp``.

        Args:
            name: Layer name; the atomic op is named ``f"{name}/{OP_NAME}"``.
            rms_norm: Forwarded to ``NormalizationOp``.
            reduce_axes: Forwarded to ``NormalizationOp``. Defaults to ``(3,)``.
            groups: Forwarded to ``NormalizationOp``.
            logger: Passed to both the atomic op and the base layer.
            **kwargs: Passed through to ``BaseHailoSingleAtomic.__init__``.
        """
        layer_norm = NormalizationOp(
            f"{name}/{self.OP_NAME}", logger=logger, rms_norm=rms_norm, reduce_axes=reduce_axes, groups=groups
        )
        super().__init__(name=name, core_op=layer_norm, logger=logger, **kwargs)

    @classmethod
    def get_default_precision_mode(cls):
        """Return ``PrecisionMode.a16_w16``, the default precision mode for this layer."""
        return PrecisionMode.a16_w16

    @property
    def rms_norm(self):
        """``rms_norm`` flag of the wrapped atomic op."""
        return self.atomic_op.rms_norm

    @property
    def epsilon(self):
        """``epsilon`` of the wrapped atomic op."""
        return self.atomic_op.epsilon

    @property
    def groups(self):
        """``groups`` of the wrapped atomic op."""
        return self.atomic_op.groups

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """
        Create a layer from an HN element dict and finalize it from that element.

        Reads ``groups`` (default 1), ``rms_norm`` (default False) and
        ``reduce_axes`` (default ``(3,)``) from ``hn_element["params"]``. A
        missing ``params`` key is treated as an empty dict.

        Args:
            lname: Name for the new layer.
            hn_element: HN description of the layer.
            logger: Optional logger forwarded to the constructor.

        Returns:
            The constructed layer, after ``finalize_from_hn(hn_element)``.
        """
        params = hn_element.get("params", {})
        groups = params.get("groups", 1)
        rms_norm = params.get("rms_norm", False)
        # Must be the one-element tuple ``(3,)`` so it matches the constructor
        # default. A bare ``(3)`` is the int 3.
        reduce_axes = params.get("reduce_axes", (3,))
        layer = cls(name=lname, rms_norm=rms_norm, reduce_axes=reduce_axes, groups=groups, logger=logger)
        layer.finalize_from_hn(hn_element)
        return layer

    def get_equalization_handler_type(self, predecessor_index=None):
        """Equalization is not supported through this layer (never a source)."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_quarot_handler_type(self, predecessor_index=None):
        """
        Return the QuaRot handler classification for this layer.

        The layer is ``transparent`` in RMS-norm mode and ``unsupported``
        otherwise. It is never a source.
        """
        # NOTE(review): presumably only RMS norm is transparent because it does
        # not subtract the mean; confirm against the QuaRot algorithm.
        handler = LayerHandlerType.transparent if self.rms_norm else LayerHandlerType.unsupported
        return EquivClassification(handler, is_source=False)

    @property
    def homogeneous(self):
        """Always ``False`` for this layer."""
        return False

    def _get_precision_mode_supported_in_hw(self, arch):
        """
        Return the precision modes supported on ``arch``.

        For the listed hardware targets and for emulation, this returns
        ``SUPPORTED_PRECISION_MODE``. Hardware support defaults to the same set
        as emulator support. Any other target gets an empty set.
        """
        if arch in {
            OptimizationTarget.MERCURY,
            OptimizationTarget.SAGE,
            OptimizationTarget.PLUTO,
            OptimizationTarget.MARS,
            OptimizationTarget.EMULATION,
        }:
            return self.SUPPORTED_PRECISION_MODE
        return set()

    @property
    def reduce_axes(self):
        """``reduce_axes`` of the wrapped atomic op."""
        return self.atomic_op.reduce_axes
