from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.hailo_nn.hn_layers.layer_with_params import LayerWithParams
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class BatchNormLayer(LayerWithParams):
    """Standalone (non-fused) batch-normalization layer of the HN graph.

    Holds an opaque batch-norm description in ``bn_info`` and can optionally
    track element-wise-add ("ew") connections to other layers. As noted on
    ``_REQUIRES_QUANTIZED_WEIGHTS``, the layer is replaced with
    ``FusedBatchNormLayer`` and never runs in quantized mode. Every
    optimization-handler query below therefore reports ``unexpected``.
    """

    _REQUIRES_NATIVE_WEIGHTS = True
    # This layer can never run on quantized mode, as it gets replaced with FusedBatchNormLayer
    _REQUIRES_QUANTIZED_WEIGHTS = None
    _IS_REAL_LAYER = True
    _IS_RANK3_SUPPORTED = True

    def __init__(self):
        super().__init__()
        self._op = LayerType.base_batch_norm
        # Opaque batch-norm description; filled in by ``create`` or the ``bn_info`` setter.
        self._bn_info = None
        # At most two inputs: presumably the main input plus one ew-add operand (confirm).
        self._number_of_inputs_supported = 2
        # Layers registered through ``add_ew_connection``.
        self._ew_connections = []

    @classmethod
    def create(cls, original_name, input_vertex_order, bn_info, output_shapes=None):
        """Build a layer with the base-class factory, then attach ``bn_info``.

        ``original_name``, ``input_vertex_order`` and ``output_shapes`` are
        forwarded positionally to the base class's ``create``.
        """
        new_layer = super().create(original_name, input_vertex_order, output_shapes)
        new_layer.bn_info = bn_info
        return new_layer

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize through the base class, then tag the node as a base batch-norm."""
        pb_node = super().to_pb(pb_wrapper, is_multi_scope)
        proto_defs = pb_wrapper.integrated_hw_graph_base_pb2
        pb_node.type = proto_defs.PROTO_NETWORK_BASE_BATCH_NORM
        return pb_node

    @property
    def bn_info(self):
        """Batch-norm description attached to this layer (``None`` until set)."""
        return self._bn_info

    @bn_info.setter
    def bn_info(self, bn_info):
        self._bn_info = bn_info

    def get_axes_mask(self, type_of_layer=None):
        """Return a fixed three-entry axes mask; ``type_of_layer`` is ignored.

        NOTE(review): the meaning of each axis is defined by the caller, not here.
        """
        return [True, True, False]

    @staticmethod
    def _unexpected_equiv_classification():
        # Shared answer for all optimization-handler queries: a fresh
        # non-source ``unexpected`` classification on every call.
        return EquivClassification(LayerHandlerType.unexpected, is_source=False)

    def get_equalization_handler_type(self, predecessor=None):
        """Report this layer as unexpected for equalization (``predecessor`` ignored)."""
        return self._unexpected_equiv_classification()

    def get_params_sorter_handler_type(self, predecessor=None):
        """Report this layer as unexpected for params sorting (``predecessor`` ignored)."""
        return self._unexpected_equiv_classification()

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Report this layer as unexpected for dead-channel removal (``predecessor`` ignored)."""
        return self._unexpected_equiv_classification()

    def ibc_supported(self):
        """Return ``LayerSupportStatus.unexpected`` unconditionally."""
        return LayerSupportStatus.unexpected

    @property
    def ew_add_enabled(self):
        """True when at least one ew-add connection is registered."""
        return bool(self._ew_connections)

    @property
    def ew_add_connections(self):
        """The live list of registered ew-add connections (not a copy)."""
        return self._ew_connections

    def add_ew_connection(self, other_layer):
        """Register ``other_layer`` as an ew-add connection."""
        self._ew_connections.append(other_layer)

    def clear_ew_connections(self):
        """Drop all ew-add connections.

        The attribute is rebound to a fresh list rather than cleared in place,
        so a list previously obtained from ``ew_add_connections`` keeps its contents.
        """
        self._ew_connections = list()

    def _is_ew_connection(self, other_layer):
        return other_layer in self._ew_connections

    def sort_inputs(self):
        """Return a ``cmp``-style comparator ordering regular inputs before ew-add inputs.

        The comparator yields -1, 0 or 1. Presumably the caller wraps it with
        ``functools.cmp_to_key``; confirm at the call site.
        """
        def compare(first, second):
            # int(bool) difference: 1 when only ``first`` is an ew connection,
            # -1 when only ``second`` is, 0 otherwise.
            first_is_ew = int(self._is_ew_connection(first))
            second_is_ew = int(self._is_ew_connection(second))
            return first_is_ew - second_is_ew

        return compare

    @property
    def input_shape(self):
        """Shape of the primary input.

        Without ew-add connections the base-class getter is used. With them,
        the first entry of ``_input_shapes`` is returned. That entry is the
        regular input, assuming inputs were ordered with ``sort_inputs``.
        """
        if not self.ew_add_enabled:
            return super().input_shape
        return self._input_shapes[0]

    @input_shape.setter
    def input_shape(self, input_shape):
        # Overriding the getter replaces the whole inherited property, so the
        # setter must be redeclared; it just delegates to the base-class setter.
        inherited_property = super(BatchNormLayer, self.__class__).input_shape
        inherited_property.fset(self, input_shape)
