import copy

import numpy as np

from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.hailo_nn.hn_layers.layer_with_params import LayerWithParams
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class LayerNormalizationLayer(LayerWithParams):
    """HN layer whose op is ``LayerType.layer_normalization``.

    Besides whatever ``LayerWithParams`` tracks, the layer stores the
    reduction axes, epsilon, gamma, beta, an RMS-norm flag and a group count.
    """

    _REQUIRES_NATIVE_WEIGHTS = True
    _REQUIRES_QUANTIZED_WEIGHTS = True

    def __init__(self):
        """Initialise the layer with its defaults (axes ``[3]``, one group)."""
        super().__init__()
        self._op = LayerType.layer_normalization
        # Parameters that have no default until a factory method fills them in.
        self._epsilon = None
        self._gamma = None
        self._beta = None
        # Parameters with concrete defaults.
        self._axes = [3]
        self._groups = 1
        self._rms_norm = False

    @classmethod
    def create(cls, original_name, input_vertex_order, info_dict, rms_norm=False, output_shapes=None):
        """Build a layer from an ``info_dict`` of normalization parameters.

        Args:
            original_name: Forwarded to the parent ``create``.
            input_vertex_order: Forwarded to the parent ``create``.
            info_dict: Mapping that must contain ``"epsilon"``, ``"scale"``
                and ``"B"``; ``"axes"`` (default ``[3]``) and ``"groups"``
                (default ``1``) are optional.
            rms_norm: Value stored in the layer's ``rms_norm`` flag.
            output_shapes: Forwarded to the parent ``create``.

        Returns:
            The newly created layer.

        Raises:
            KeyError: If one of the mandatory ``info_dict`` keys is missing.
        """
        new_layer = super().create(original_name, input_vertex_order, output_shapes)
        new_layer.axes = info_dict.get("axes", [3])
        # Mandatory entries: layer attribute name -> key inside ``info_dict``.
        required_fields = (("epsilon", "epsilon"), ("gamma", "scale"), ("beta", "B"))
        for attr_name, info_key in required_fields:
            setattr(new_layer, attr_name, info_dict[info_key])
        new_layer.groups = info_dict.get("groups", 1)
        new_layer.rms_norm = rms_norm
        return new_layer

    @property
    def axes(self):
        """Axes stored for this layer (serialized as ``reduce_axes``)."""
        return self._axes

    @axes.setter
    def axes(self, value):
        self._axes = value

    @property
    def epsilon(self):
        """Epsilon value stored for this layer."""
        return self._epsilon

    @epsilon.setter
    def epsilon(self, value):
        self._epsilon = value

    @property
    def gamma(self):
        """Gamma value stored for this layer."""
        return self._gamma

    @gamma.setter
    def gamma(self, value):
        self._gamma = value

    @property
    def beta(self):
        """Beta value stored for this layer."""
        return self._beta

    @beta.setter
    def beta(self, value):
        self._beta = value

    @property
    def rms_norm(self):
        """RMS-norm flag stored for this layer."""
        return self._rms_norm

    @rms_norm.setter
    def rms_norm(self, value):
        self._rms_norm = value

    @property
    def groups(self):
        """Group count stored for this layer."""
        return self._groups

    @groups.setter
    def groups(self, value):
        self._groups = value

    def get_equalization_handler_type(self, predecessor=None):
        """Return an ``unexpected``, non-source classification for equalization."""
        handler_type = LayerHandlerType.unexpected
        return EquivClassification(handler_type, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        """Return an ``unexpected``, non-source classification for params sorting."""
        handler_type = LayerHandlerType.unexpected
        return EquivClassification(handler_type, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Return an ``unexpected``, non-source classification for dead-channel removal."""
        handler_type = LayerHandlerType.unexpected
        return EquivClassification(handler_type, is_source=False)

    def ibc_supported(self):
        """Return ``LayerSupportStatus.unexpected``."""
        return LayerSupportStatus.unexpected

    @classmethod
    def from_layer(cls, old_layer):
        """Build a layer from an existing layer.

        A ``layer_normalization`` source has its normalization attributes
        copied one-to-one. An ``l2_normalization`` source is mapped to an
        RMS-norm layer: its ``reduce_axes`` and ``scale`` become ``axes`` and
        ``gamma``, while ``beta`` and ``epsilon`` are set to zero arrays.
        Any other source op leaves the defaults untouched.
        """
        new_layer = super().from_layer(old_layer)
        source_op = old_layer.op

        if source_op == LayerType.layer_normalization:
            copied_attrs = ("axes", "epsilon", "gamma", "beta", "rms_norm", "groups")
            for attr_name in copied_attrs:
                setattr(new_layer, attr_name, getattr(old_layer, attr_name))
            return new_layer

        if source_op == LayerType.l2_normalization:
            new_layer.axes = old_layer.reduce_axes
            new_layer.beta = np.array(0)
            new_layer.epsilon = np.array(0)
            new_layer.gamma = old_layer.scale
            new_layer.rms_norm = True

        return new_layer

    def to_hn(self, should_get_default_params=False):
        """Serialize the layer to an HN dict.

        The parent's result is deep-copied and its ``"params"`` entry is
        extended with ``reduce_axes``, ``rms_norm`` and ``groups``.
        """
        hn_dict = copy.deepcopy(super().to_hn(should_get_default_params))
        params = hn_dict["params"]
        params["reduce_axes"] = self.axes
        params["rms_norm"] = self.rms_norm
        params["groups"] = self.groups
        return hn_dict

    @classmethod
    def from_hn(cls, hn, validate_params_exist=True):
        """Build a layer from an HN dict and read its params into attributes.

        Missing entries in ``hn["params"]`` fall back to ``[3]`` for
        ``reduce_axes``, ``False`` for ``rms_norm`` and ``1`` for ``groups``.

        NOTE(review): ``validate_params_exist`` is accepted but not forwarded
        to the parent ``from_hn`` -- confirm this is intentional.
        """
        new_layer = super().from_hn(hn)
        params = hn["params"]
        new_layer.axes = params.get("reduce_axes", [3])
        new_layer.rms_norm = params.get("rms_norm", False)
        new_layer.groups = params.get("groups", 1)
        return new_layer

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize to a protobuf node typed ``PROTO_NETWORK_LAYERNORM``."""
        pb_node = super().to_pb(pb_wrapper, is_multi_scope)
        pb_node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_LAYERNORM
        return pb_node
