from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.hailo_nn.hn_layers.layer_with_params import LayerWithParams
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class InstanceNormalizationLayer(LayerWithParams):
    """HN layer whose op is ``LayerType.instance_normalization``.

    On top of what ``LayerWithParams`` provides, the layer stores five values:
    ``epsilon``, ``gamma``, ``beta``, ``groups`` and ``axes``. Each starts as
    ``None`` and is filled in by :meth:`create` (or assigned directly through
    its property).
    """

    def __init__(self):
        super().__init__()
        self._op = LayerType.instance_normalization
        # Backing storage for the public properties below; unset until
        # create() or a property setter assigns a value.
        for storage_name in ("_epsilon", "_gamma", "_beta", "_groups", "_axes"):
            setattr(self, storage_name, None)

    @classmethod
    def create(cls, original_name, input_vertex_order, info_dict, output_shapes=None):
        """Build a layer via the parent factory and copy its parameters from ``info_dict``.

        Args:
            original_name: Forwarded unchanged to the parent ``create``.
            input_vertex_order: Forwarded unchanged to the parent ``create``.
            info_dict: Mapping providing the keys ``"epsilon"``, ``"scale"``,
                ``"B"``, ``"groups"`` and ``"axes"``.
            output_shapes: Forwarded unchanged to the parent ``create``.

        Returns:
            The layer returned by the parent ``create``, with its parameters set.

        Raises:
            Whatever ``info_dict[key]`` raises for a missing key (``KeyError``
            for a plain dict). Parameters earlier in the list are already
            assigned by then.
        """
        built = super().create(original_name, input_vertex_order, output_shapes)
        # (public attribute, info_dict key), applied in this exact order.
        # NOTE(review): "scale"/"B" look like ONNX InstanceNormalization input
        # names, so the dict presumably comes from the ONNX parser; confirm.
        field_sources = (
            ("epsilon", "epsilon"),
            ("gamma", "scale"),
            ("beta", "B"),
            ("groups", "groups"),
            ("axes", "axes"),
        )
        for attribute, key in field_sources:
            setattr(built, attribute, info_dict[key])
        return built

    def _stored_attribute(storage_name, doc):  # noqa: N805 -- class-body helper, deleted below
        """Return a plain read/write property backed by ``self.<storage_name>``."""

        def _read(self):
            return getattr(self, storage_name)

        def _write(self, value):
            setattr(self, storage_name, value)

        return property(_read, _write, doc=doc)

    epsilon = _stored_attribute("_epsilon", "Value taken from ``info_dict['epsilon']`` by ``create``.")
    gamma = _stored_attribute("_gamma", "Value taken from ``info_dict['scale']`` by ``create``.")
    beta = _stored_attribute("_beta", "Value taken from ``info_dict['B']`` by ``create``.")
    groups = _stored_attribute("_groups", "Value taken from ``info_dict['groups']`` by ``create``.")
    axes = _stored_attribute("_axes", "Value taken from ``info_dict['axes']`` by ``create``.")

    # The factory is only needed while the class body executes.
    del _stored_attribute

    @staticmethod
    def _unexpected_classification(is_source):
        """Build a fresh ``unexpected`` classification with the given source flag."""
        return EquivClassification(LayerHandlerType.unexpected, is_source=is_source)

    def get_equalization_handler_type(self, predecessor=None):
        """Classify this layer as ``unexpected`` for equalization, flagged as a source.

        ``predecessor`` is not used.
        """
        # NOTE(review): the only handler in this class that passes is_source=True;
        # presumably intentional, but confirm.
        return self._unexpected_classification(True)

    def get_params_sorter_handler_type(self, predecessor=None):
        """Classify this layer as ``unexpected`` for params sorting (not a source).

        ``predecessor`` is not used.
        """
        return self._unexpected_classification(False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Classify this layer as ``unexpected`` for dead-channel removal (not a source).

        ``predecessor`` is not used.
        """
        return self._unexpected_classification(False)

    def ibc_supported(self):
        """Report IBC support status for this layer (always ``unexpected``)."""
        return LayerSupportStatus.unexpected
