from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.exceptions import UnsupportedModelError
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.hailo_nn.hn_layers.layer_with_activation import LayerWithActivation
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class BboxDecoderLayer(LayerWithActivation):
    """HN layer of type ``LayerType.bbox_decoder``.

    The layer accepts one or two inputs. When two inputs are connected they must
    agree on their first four dimensions and each must have an even number of
    features. The output keeps the first three dimensions of ``input_shape`` and
    its feature dimension is the sum of the inputs' feature dimensions.
    """

    # Class-level flags; presumably consulted by the base layer / optimization flow
    # to decide which weight representations must exist -- confirm in LayerWithActivation.
    _REQUIRES_NATIVE_WEIGHTS = True
    _REQUIRES_QUANTIZED_WEIGHTS = True

    def __init__(self):
        super().__init__()
        self._op = LayerType.bbox_decoder
        # Ordered inputs of this layer; each entry is passed to pred_layer_output_shape().
        self._input_list = []
        self._number_of_inputs_supported = 2

    def append_to_input_list(self, inp):
        """Append ``inp`` to the end of this layer's ordered input list."""
        self._input_list.append(inp)

    @property
    def input_list(self):
        """Ordered list of this layer's inputs."""
        return self._input_list

    @input_list.setter
    def input_list(self, input_list):
        self._input_list = input_list

    @property
    def weights(self):
        """Return the number of weights held by this layer.

        ``output_features`` is interpreted as 4 values per anchor, and the layer
        keeps two weights (H and W) per anchor. Floor division keeps both the
        anchor count and the returned weight count integral; plain ``/`` would
        yield a float such as ``4.0``.
        """
        num_of_anchors = self.output_features // 4
        # H and W per anchor
        return 2 * num_of_anchors

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize via the base class, then mark the node as a bbox-decoder node.

        Returns:
            The node produced by the base ``to_pb`` with its ``type`` overridden.
        """
        node = super().to_pb(pb_wrapper, is_multi_scope)
        node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_BBOX_DECODER
        return node

    def _calc_output_shape(self):
        """Compute this layer's output shape from its inputs' shapes.

        Returns:
            ``[input_shape[0], input_shape[1], input_shape[2], features]``, where
            ``features`` is the sum of the last dimension of every input shape.

        Raises:
            UnsupportedModelError: if more than two inputs are connected, or if
                two inputs are connected and they differ in any of their first
                four dimensions or either has an odd number of features.
        """
        if len(self._input_list) > 2:
            raise UnsupportedModelError(
                f"Invalid input shapes for {self.full_name_msg}. Can't have more than two inputs",
            )

        # Resolve every input shape once; they are needed both for validation
        # and for the feature sum below.
        input_shapes = [self.pred_layer_output_shape(in_item) for in_item in self._input_list]

        if len(input_shapes) == 2:
            first_input, second_input = input_shapes
            if any(first_input[i] != second_input[i] for i in range(4)):
                raise UnsupportedModelError(
                    f"Invalid input shapes for {self.full_name_msg}. Both inputs needs to "
                    "be in the same dimensions",
                )
            if first_input[-1] % 2 != 0 or second_input[-1] % 2 != 0:
                raise UnsupportedModelError(
                    f"Invalid input shapes for {self.full_name_msg}. Both inputs needs to have "
                    "even number of features",
                )

        features = sum(shape[-1] for shape in input_shapes)
        return [self.input_shape[0], self.input_shape[1], self.input_shape[2], features]

    def get_equalization_handler_type(self, predecessor=None):
        """Report equalization as unsupported (non-source); ``predecessor`` is ignored."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        """Report params sorting as unsupported (non-source); ``predecessor`` is ignored."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Report dead-channels removal as unsupported (non-source); ``predecessor`` is ignored."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def ibc_supported(self):
        """Report IBC as unsupported for this layer."""
        return LayerSupportStatus.unsupported

    @property
    def finetune_supported(self):
        """Finetuning is never supported for this layer."""
        return False

    @classmethod
    def from_hn(cls, hn):
        """Build the layer from its HN description.

        Delegates to the base ``from_hn`` with ``validate_params_exist=False``,
        i.e. without requiring params to be present (presumably because this
        layer's params are not part of the HN -- confirm against the base class).
        """
        return super().from_hn(hn, validate_params_exist=False)
