import copy

from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.hn_definitions import DefuseType, LayerType
from hailo_sdk_common.hailo_nn.hn_layers.layer_with_params import LayerWithParams
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class ArgmaxLayer(LayerWithParams):
    """HN layer of type ``LayerType.argmax``.

    The output keeps every input dimension except the last one, which is
    collapsed to size 1 (see ``_calc_output_shape``).

    The only layer-specific attribute is the ``reverse_order`` flag. It is
    carried through ``create``, ``from_layer``, HN (``params`` dict) and
    protobuf serialization. Its effect on the computation is not defined in
    this class (presumably the argmax search/tie-break direction -- TODO confirm).
    """

    # Class-level flags read outside this file. False here presumably means this
    # layer needs neither native nor quantized weights -- confirm against the base class.
    _REQUIRES_NATIVE_WEIGHTS = False
    _REQUIRES_QUANTIZED_WEIGHTS = False
    _IS_REAL_LAYER = True

    def __init__(self):
        super().__init__()
        self._op = LayerType.argmax
        # Same default as the ``reverse_order`` keyword of ``create``.
        self._reverse_order = False

    @classmethod
    def create(cls, original_name, input_vertex_order, output_shapes=None, reverse_order=False):
        """Build a layer through the base-class factory, then attach ``reverse_order``."""
        new_layer = super().create(original_name, input_vertex_order, output_shapes)
        new_layer._reverse_order = reverse_order
        return new_layer

    @classmethod
    def from_pb(cls, pb, pb_wrapper):
        """Deserialize from a protobuf node, copying its ``reverse_order`` field."""
        new_layer = super().from_pb(pb, pb_wrapper)
        new_layer._reverse_order = pb.reverse_order
        return new_layer

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize to a protobuf node tagged ``PROTO_NETWORK_ARGMAX``, including ``reverse_order``."""
        pb_node = super().to_pb(pb_wrapper, is_multi_scope)
        pb_node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_ARGMAX
        pb_node.reverse_order = self.reverse_order
        return pb_node

    @classmethod
    def from_hn(cls, hn):
        """Deserialize from an HN dict.

        ``reverse_order`` is optional. When ``hn["params"]["reverse_order"]``
        is absent, the constructor default (False) is kept.
        """
        # validate_params_exist=False: presumably because this layer has no weight params -- confirm.
        new_layer = super().from_hn(hn, validate_params_exist=False)
        hn_params = hn["params"] if "params" in hn else {}
        if "reverse_order" in hn_params:
            new_layer._reverse_order = hn_params["reverse_order"]
        return new_layer

    def to_hn(self, should_get_default_params=False):
        """Serialize to an HN dict, adding ``reverse_order`` under ``params``.

        The base-class dict is deep-copied before it is modified.
        """
        hn_dict = copy.deepcopy(super().to_hn(should_get_default_params))
        hn_dict["params"]["reverse_order"] = self._reverse_order
        return hn_dict

    @classmethod
    def from_layer(cls, old_layer):
        """Clone from another layer, copying its ``reverse_order`` attribute."""
        new_layer = super().from_layer(old_layer)
        new_layer._reverse_order = old_layer.reverse_order
        return new_layer

    def _calc_output_shape(self):
        """Return the output shape: the input shape with its last dimension set to 1.

        For rank-4 inputs, the width (index 2) is replaced by
        ``defuse_input_width`` when that defuse parameter is present and nonzero.
        """
        in_shape = self.input_shape
        if len(in_shape) == 2:
            return [in_shape[0], 1]
        # A missing or zero defuse width means "not defused"; fall back to the input width.
        use_defuse_width = "defuse_input_width" in self.defuse_params and self.defuse_input_width != 0
        out_width = self.defuse_input_width if use_defuse_width else in_shape[2]
        return [in_shape[0], in_shape[1], out_width, 1]

    @property
    def reverse_order(self):
        """The ``reverse_order`` flag set at creation or deserialization time."""
        return self._reverse_order

    @property
    def input_width(self):
        """Input width, taken from ``defuse_input_width`` when defused along width (``spatial_w``)."""
        is_width_defused = self.defuse_type == DefuseType.spatial_w
        return self.defuse_input_width if is_width_defused else super().input_width

    @property
    def finetune_supported(self):
        """Always False for this layer."""
        return False

    @staticmethod
    def _unsupported_classification():
        """Build the non-source ``unsupported`` classification shared by the handler getters below."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_equalization_handler_type(self, predecessor=None):
        """Return an unsupported, non-source classification; ``predecessor`` is ignored."""
        return self._unsupported_classification()

    def get_params_sorter_handler_type(self, predecessor=None):
        """Return an unsupported, non-source classification; ``predecessor`` is ignored."""
        return self._unsupported_classification()

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Return an unsupported, non-source classification; ``predecessor`` is ignored."""
        return self._unsupported_classification()

    def ibc_supported(self):
        """Return ``LayerSupportStatus.unsupported``."""
        return LayerSupportStatus.unsupported
