import copy

from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.hailo_nn.hn_layers.layer_with_params import LayerWithParams
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class SoftmaxLayer(LayerWithParams):
    """HN layer describing a softmax operation (``LayerType.softmax``).

    Layer-specific state:
        groups: stored as ``params["groups"]`` in HN and ``groups`` in the protobuf node.
        axis: the softmax axis (default ``-1``); stored as ``params["logits_axis"]`` in HN
            and as the first entry of ``reduce_axes`` in the protobuf node.
        additive_mask: held on the object only; it is copied by ``from_layer`` but is not
            written or read by ``to_pb``/``from_pb``/``to_hn``/``from_hn``.
    """

    _REQUIRES_NATIVE_WEIGHTS = False
    _REQUIRES_QUANTIZED_WEIGHTS = True

    # Default softmax axis, used by __init__ and as the fallback when a serialized
    # node does not specify one.
    _DEFAULT_AXIS = -1

    def __init__(self):
        super().__init__()
        self._op = LayerType.softmax
        self._groups = 1
        self._axis = self._DEFAULT_AXIS
        self._additive_mask = None
        # NOTE(review): presumably the logits plus an optional second input
        # (e.g. the additive mask) -- confirm against the consumers of this field.
        self._number_of_inputs_supported = 2

    @classmethod
    def create(cls, original_name, input_vertex_order, groups=1, axis=-1, additive_mask=None, output_shapes=None):
        """Build a softmax layer and set its groups, axis and additive mask."""
        layer = super().create(original_name, input_vertex_order, output_shapes)
        layer.groups = groups
        layer.axis = axis
        layer.additive_mask = additive_mask
        return layer

    @classmethod
    def from_pb(cls, pb, pb_wrapper):
        """Deserialize from a protobuf node.

        The axis is taken from the first ``reduce_axes`` entry. When the node carries no
        ``reduce_axes`` the default axis is used instead of raising ``IndexError``,
        mirroring how ``from_hn`` treats a missing ``logits_axis``.
        """
        layer = super().from_pb(pb, pb_wrapper)
        layer.groups = pb.groups
        layer.axis = pb.reduce_axes[0] if len(pb.reduce_axes) > 0 else cls._DEFAULT_AXIS
        return layer

    @classmethod
    def from_layer(cls, old_layer):
        """Copy-construct from another softmax layer, including its additive mask."""
        layer = super().from_layer(old_layer)
        layer.groups = old_layer.groups
        layer.axis = old_layer.axis
        layer.additive_mask = old_layer.additive_mask
        return layer

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize to a protobuf softmax node.

        Rank-2 input shapes are written to the node as height=1, width=1,
        features=shape[1] for both the input and the output shape entries.
        """
        node = super().to_pb(pb_wrapper, is_multi_scope)
        if len(self._input_shapes[0]) == 2:
            # NOTE(review): assumes rank-2 shapes are (batch, features) -- confirm.
            node.input_shapes[0].height, node.input_shapes[0].width, node.input_shapes[0].features = (
                1,
                1,
                self.input_shapes[0][1],
            )
            node.output_shapes[0].height, node.output_shapes[0].width, node.output_shapes[0].features = (
                1,
                1,
                self.output_shapes[0][1],
            )
        node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_SOFTMAX
        node.groups = self._groups
        node.reduce_axes.append(self._axis)
        return node

    @classmethod
    def from_hn(cls, hn):
        """Deserialize from an HN dict; absent ``groups``/``logits_axis`` keep their defaults."""
        layer = super().from_hn(hn, validate_params_exist=False)
        if "params" in hn and "groups" in hn["params"]:
            layer.groups = hn["params"]["groups"]
        if "params" in hn and "logits_axis" in hn["params"]:
            layer.axis = hn["params"]["logits_axis"]
        return layer

    def to_hn(self, should_get_default_params=False):
        """Serialize to an HN dict, adding ``groups`` and ``logits_axis`` under ``params``."""
        result = copy.deepcopy(super().to_hn(should_get_default_params))
        result["params"]["groups"] = self._groups
        result["params"]["logits_axis"] = self._axis
        return result

    @property
    def input_height(self):
        """1 for rank-2 inputs, otherwise dimension 1 of the first input shape."""
        # NOTE(review): assumes higher-rank inputs are laid out as (batch, H, W, C) -- confirm.
        return 1 if len(self._input_shapes[0]) == 2 else self._input_shapes[0][1]

    @property
    def input_width(self):
        """1 for rank-2 inputs, otherwise dimension 2 of the first input shape."""
        return 1 if len(self._input_shapes[0]) == 2 else self._input_shapes[0][2]

    @property
    def input_features(self):
        """Last dimension of the input shape."""
        return self._get_shape_single_dim(self._input_shapes, -1)

    @property
    def output_features(self):
        """Last dimension of the *input* shape.

        Reads the input shapes on purpose; this assumes softmax leaves the feature
        dimension unchanged.
        """
        return self._get_shape_single_dim(self._input_shapes, -1)

    @property
    def finetune_supported(self):
        return False

    @property
    def groups(self):
        return self._groups

    @groups.setter
    def groups(self, groups):
        self._groups = groups

    @property
    def axis(self):
        return self._axis

    @axis.setter
    def axis(self, axis):
        self._axis = axis

    @property
    def additive_mask(self):
        return self._additive_mask

    @additive_mask.setter
    def additive_mask(self, additive_mask):
        self._additive_mask = additive_mask

    def get_equalization_handler_type(self, predecessor=None):
        """Equalization: unsupported for this layer (not a source)."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        """Params sorting: unsupported for this layer (not a source)."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Dead-channel removal: treated as transparent (not a source)."""
        return EquivClassification(LayerHandlerType.transparent, is_source=False)

    def ibc_supported(self):
        # Kept as a plain method (not a property) to preserve the existing interface.
        return LayerSupportStatus.unsupported
