import copy

from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.hailo_nn.hn_layers.layer_with_params import LayerWithParams
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class ReduceMinLayer(LayerWithParams):
    """HN layer whose op is ``LayerType.reduce_min``.

    The layer carries two parameters, both exported under ``"params"`` by
    :meth:`to_hn`:

    * ``groups`` -- the size written into every reduced axis of the output
      shape (defaults to 1).
    * ``reduce_axes`` -- the axis indices that are reduced (defaults to
      ``[3]``). Negative indices are normalized against the rank of
      ``input_shape`` when assigned through the property setter.
    """

    _REQUIRES_NATIVE_WEIGHTS = False
    _REQUIRES_QUANTIZED_WEIGHTS = False
    _IS_REAL_LAYER = True

    def __init__(self):
        super().__init__()
        self._op = LayerType.reduce_min
        self._groups = 1
        self._reduce_axes = [3]

    @classmethod
    def create(cls, original_name, input_vertex_order, output_shapes=None, groups=1):
        """Build a layer through the parent ``create`` and set its ``groups``."""
        layer = super().create(original_name, input_vertex_order, output_shapes=output_shapes)
        layer.groups = groups

        return layer

    @classmethod
    def from_layer(cls, old_layer):
        """Build a layer from ``old_layer``, copying its ``groups`` and ``reduce_axes``."""
        layer = super().from_layer(old_layer)
        layer.groups = old_layer.groups
        # Goes through the setter, so the axes are re-normalized against the
        # new layer's input shape.
        layer.reduce_axes = old_layer.reduce_axes
        return layer

    def to_hn(self, should_get_default_params=False):
        """Return the parent's HN dict extended with ``groups`` and ``reduce_axes``."""
        result = copy.deepcopy(super().to_hn(should_get_default_params))
        result["params"]["groups"] = self._groups
        # Export a copy so that editing the returned dict cannot alter the
        # layer's own axis list.
        result["params"]["reduce_axes"] = list(self._reduce_axes)
        return result

    def _calc_output_shape(self):
        """Return the input shape with every reduced axis replaced by ``groups``."""
        return [dim if i not in self.reduce_axes else self._groups for i, dim in enumerate(self.input_shape)]

    @property
    def groups(self):
        """Size assigned to each reduced axis in the output shape."""
        return self._groups

    @groups.setter
    def groups(self, groups):
        self._groups = groups

    @property
    def reduce_axes(self):
        """List of axis indices that this layer reduces."""
        return self._reduce_axes

    @reduce_axes.setter
    def reduce_axes(self, reduce_axes):
        """Store ``reduce_axes``, converting negative indices to non-negative ones.

        Non-negative axes (including 0) are stored unchanged. The rank of
        ``input_shape`` is only looked up for negative axes, so assigning
        non-negative axes does not touch ``input_shape`` at all.
        """
        # ``>= 0`` rather than ``> 0``: axis 0 is already a valid index and
        # must not be shifted by the rank (which would push it out of range).
        self._reduce_axes = [axis if axis >= 0 else axis + len(self.input_shape) for axis in reduce_axes]

    def get_equalization_handler_type(self, predecessor=None):
        """Return an ``unsupported``, non-source classification for equalization."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        """Return an ``unsupported``, non-source classification for params sorting."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Return an ``unsupported``, non-source classification for dead-channels removal."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def ibc_supported(self):
        """Return ``LayerSupportStatus.unsupported``."""
        return LayerSupportStatus.unsupported
