from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.hn_definitions import DefuseType, LayerType, PaddingType
from hailo_sdk_common.hailo_nn.hn_layers.layer_with_params import LayerWithParams
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class ReduceL2Layer(LayerWithParams):
    """HN layer for an L2 reduction (``LayerType.reduce_l2``) over ``reduce_axes``.

    Every reduced axis collapses to 1, except the features (last) axis, which
    collapses to ``groups`` (1 by default) when it is reduced.
    """

    _REQUIRES_NATIVE_WEIGHTS = False
    _REQUIRES_QUANTIZED_WEIGHTS = True
    _IS_REAL_LAYER = True

    def __init__(self):
        super().__init__()
        self._op = LayerType.reduce_l2
        # Default: reduce over axis 3 (the last axis of a rank-4 shape).
        self._reduce_axes = [3]
        self._groups = 1

    @classmethod
    def create(cls, original_name, input_vertex_order, reduce_axes, groups=1, output_shapes=None):
        """Create a layer through the base-class factory and set its reduction parameters.

        Args:
            original_name: Forwarded to ``LayerWithParams.create``.
            input_vertex_order: Forwarded to ``LayerWithParams.create``.
            reduce_axes: Indices of the input axes to reduce (positive or negative).
            groups: Number of output channels kept when the features axis is reduced.
            output_shapes: Forwarded to ``LayerWithParams.create``.

        Returns:
            The configured ``ReduceL2Layer``.
        """
        layer = super().create(original_name, input_vertex_order, output_shapes)
        layer.reduce_axes = reduce_axes
        layer.groups = groups
        return layer

    def _is_axis_reduced(self, axis):
        """Return True if non-negative ``axis`` appears in ``reduce_axes`` in positive or negative form.

        Keeps every shape/kernel computation consistent whether callers
        specified e.g. ``3`` or ``-1`` for the last axis of a rank-4 input.
        """
        rank = len(self.input_shape)
        return axis in self.reduce_axes or axis - rank in self.reduce_axes

    def get_width_kernel(self):
        """Return the kernel extent along the width axis.

        The width axis is the second-to-last input axis. If it is reduced, the
        kernel spans the whole input size on that axis; otherwise it is 1.
        """
        width_dim = len(self.input_shape) - 2
        return self.input_shape[width_dim] if self._is_axis_reduced(width_dim) else 1

    @property
    def strides(self):
        """Strides ``[1, 1, width_kernel, 1]``: the width stride equals the width kernel."""
        return [1, 1, self.get_width_kernel(), 1]

    @property
    def kernel_shape(self):
        """Kernel shape ``[1, width_kernel, input_features / groups, output_features]``.

        ``output_features`` is ``groups`` when the features (last) axis is
        reduced, otherwise the input feature count.
        """
        features_dim = len(self.input_shape) - 1
        input_features = self.input_shape[features_dim]
        features_output_dim = self.groups if self._is_axis_reduced(features_dim) else input_features
        return [1, self.get_width_kernel(), int(input_features / self.groups), features_output_dim]

    @property
    def padding(self):
        """Padding type; always ``PaddingType.same`` for this layer."""
        return PaddingType.same

    @property
    def groups(self):
        """Number of output channels produced when the features axis is reduced."""
        return self._groups

    @groups.setter
    def groups(self, groups):
        self._groups = groups

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize to a protobuf node, adding the REDUCE_L2 type, ``reduce_axes`` and ``groups``."""
        node = super().to_pb(pb_wrapper, is_multi_scope)
        node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_REDUCE_L2
        node.reduce_axes.extend(self.reduce_axes)
        node.groups = self._groups
        return node

    def _calc_output_shape(self):
        """Compute the output shape from the input shape and reduction settings.

        - Rank-2 inputs yield ``[input_shape[0], groups]``.
        - Otherwise the width comes from ``defuse_input_width`` when it is
          present in ``defuse_params`` and non-zero, else from ``input_shape[2]``.
        - Each reduced axis is then set to 1.
        - If the features (last) axis is reduced, it is set to ``groups``
          (1 by default).
        """
        if len(self.input_shape) == 2:
            return [self.input_shape[0], self._groups]
        if "defuse_input_width" in self.defuse_params and self.defuse_input_width != 0:
            width = self.defuse_input_width
        else:
            width = self.input_shape[2]
        output_shape = [self.input_shape[0], self.input_shape[1], width, self.input_shape[3]]
        for axis in self.reduce_axes:
            output_shape[axis] = 1
        # Decide by whether the features axis was reduced, not by its resulting
        # size: a non-reduced features axis of size 1 must not become `groups`.
        if self._is_axis_reduced(len(output_shape) - 1):
            output_shape[-1] = self._groups
        return output_shape

    def set_input_shapes(self, input_shapes, validate=True):
        """Set input shapes via the base class; for rank-2 inputs force ``reduce_axes`` to ``[1]``."""
        super().set_input_shapes(input_shapes, validate)
        if len(self.input_shape) == 2:
            self._reduce_axes = [1]

    @property
    def input_width(self):
        """Input width; under ``DefuseType.spatial_w`` this is ``defuse_input_width``."""
        if self.defuse_type == DefuseType.spatial_w:
            return self.defuse_input_width
        return super().input_width

    @property
    def reduce_axes(self):
        """Indices of the input axes being reduced."""
        return self._reduce_axes

    @reduce_axes.setter
    def reduce_axes(self, reduce_axes):
        self._reduce_axes = reduce_axes

    def get_equalization_handler_type(self, predecessor=None):
        """Report this layer as unsupported (non-source) for equalization."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        """Report this layer as unsupported (non-source) for the params sorter."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Report this layer as unsupported (non-source) for dead-channel removal."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def ibc_supported(self):
        """Return ``LayerSupportStatus.unsupported`` for IBC."""
        return LayerSupportStatus.unsupported
