import copy

import numpy as np

from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.exceptions import UnsupportedModelError
from hailo_sdk_common.hailo_nn.hn_definitions import DefuseType, LayerType
from hailo_sdk_common.hailo_nn.hn_layers import ConcatLayer, LayerWithActivation, NMSLayer
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class ProposalGeneratorLayer(
    ConcatLayer,
    LayerWithActivation,
):
    """Proposal-generator layer (``LayerType.proposal_generator``).

    Takes exactly two inputs: a boxes tensor (input 0) and a classes tensor
    (input 1). It groups their proposals into fixed-size output chunks. Chunk
    geometry comes from :class:`NMSLayer`:

    - ``BBOX_PER_CHUNK`` proposals per output chunk;
    - ``BBOX_PARAMETERS`` values per proposal;
    - ``BBOX_PARAMETERS - 1`` coordinates per proposal.
    """

    _REQUIRES_NATIVE_WEIGHTS = True
    _REQUIRES_QUANTIZED_WEIGHTS = True

    def __init__(self):
        super().__init__()
        self._op = LayerType.proposal_generator
        self._number_of_inputs_supported = 2
        # Number of proposals packed into a single output chunk.
        self._proposals_per_output = NMSLayer.BBOX_PER_CHUNK
        # One fewer than the values per proposal (see _values_per_proposal).
        self._number_of_coordinates_per_proposal = NMSLayer.BBOX_PARAMETERS - 1
        # Must evenly divide input dim 1 (validated in _calc_output_shape); 1 means no division.
        self._input_division_factor = 1
        self._values_per_proposal = NMSLayer.BBOX_PARAMETERS

    @property
    def input_division_factor(self):
        """Factor that must divide input dim 1; it scales output dims 1 and 3."""
        return self._input_division_factor

    @input_division_factor.setter
    def input_division_factor(self, input_division_factor):
        self._input_division_factor = input_division_factor

    @property
    def proposals_per_output(self):
        """Number of proposals packed into one output chunk."""
        return self._proposals_per_output

    @property
    def number_of_coordinates_per_proposal(self):
        """Number of coordinate values per proposal in the boxes input."""
        return self._number_of_coordinates_per_proposal

    @property
    def values_per_proposal(self):
        """Number of values emitted per proposal in each output chunk."""
        return self._values_per_proposal

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize as a ``PROTO_NETWORK_PROPOSAL_GENERATOR`` protobuf node.

        The ``super`` lookup starts after ``ConcatLayer`` in the MRO, so
        ``ConcatLayer.to_pb`` itself is bypassed.
        NOTE(review): presumably done to avoid concat-specific serialization; confirm.
        """
        node = super(ConcatLayer, self).to_pb(pb_wrapper, is_multi_scope)
        node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_PROPOSAL_GENERATOR
        node.input_division_factor = self._input_division_factor
        return node

    def _calc_output_shape(self):
        """Compute the output shape from the boxes and classes input shapes.

        Expected input shapes:

        - Dim 3 of input 0 (boxes) holds ``anchors * number_of_coordinates_per_proposal``
          channels.
        - Dim 3 of input 1 (classes) holds ``anchors * classes`` channels.
        - Both inputs must agree on dims 1 and 2.

        The width is input dim 2, or the defused width when one is set. It is
        padded up to a multiple of ``proposals_per_output``.
        NOTE(review): the dims look like (batch, height, width, channels); confirm.

        Returns:
            ``[batch, classes * input_division_factor,
            number_of_coordinates_per_proposal,
            num_chunks * values_per_proposal // input_division_factor]``

        Raises:
            UnsupportedModelError: if the input shapes are inconsistent, or if
                ``input_division_factor`` is not a positive divisor of input dim 1.
        """
        boxes_shape = self._input_shapes[0]
        classes_shape = self._input_shapes[1]
        coords_per_proposal = self._number_of_coordinates_per_proposal
        per_output = self._proposals_per_output
        division_factor = self._input_division_factor

        # Validate the boxes channel count before deriving ``anchors`` from it.
        # A zero or non-multiple count would give zero anchors. Dividing by it
        # below would then raise ZeroDivisionError instead of reporting an
        # unsupported model.
        boxes_channels = boxes_shape[3]
        if boxes_channels <= 0 or boxes_channels % coords_per_proposal != 0:
            raise UnsupportedModelError(f"Invalid input shapes for {self.full_name_msg}.")
        anchors = boxes_channels // coords_per_proposal
        if (
            boxes_shape[1] != classes_shape[1]
            or boxes_shape[2] != classes_shape[2]
            or classes_shape[3] % anchors != 0
        ):
            raise UnsupportedModelError(f"Invalid input shapes for {self.full_name_msg}.")
        classes = classes_shape[3] // anchors

        # Reject a zero, None or negative factor explicitly. Using 0 with ``%``
        # would raise ZeroDivisionError.
        if not division_factor or division_factor < 0 or boxes_shape[1] % division_factor != 0:
            raise UnsupportedModelError(f"Invalid input_division_factor for {self.full_name_msg}.")

        if "defuse_input_width" in self.defuse_params and self.defuse_input_width != 0:
            input_width = self.defuse_input_width
        else:
            input_width = boxes_shape[2]

        # Round the width up to a whole number of output chunks.
        padded_width = int(np.ceil(input_width / per_output)) * per_output
        total_proposals = anchors * boxes_shape[1] * padded_width
        # Exact: padded_width is a multiple of per_output.
        num_chunks = total_proposals // per_output
        return [
            boxes_shape[0],
            classes * division_factor,
            coords_per_proposal,
            int(num_chunks * self._values_per_proposal // division_factor),
        ]

    @property
    def input_width(self):
        """Return the input width used by this layer.

        This is ``defuse_input_width`` when the layer is defused along W
        (``DefuseType.spatial_w``); otherwise it is the base-class value.
        """
        if self.defuse_type == DefuseType.spatial_w:
            return self.defuse_input_width
        return super().input_width

    def to_hn(self, should_get_default_params=False):
        """Serialize to an HN dict.

        ``params.input_division_factor`` is added only when the factor is truthy.
        Like ``to_pb``, this bypasses ``ConcatLayer.to_hn`` by starting the
        ``super`` lookup after ``ConcatLayer``.
        """
        result = copy.deepcopy(super(ConcatLayer, self).to_hn(should_get_default_params))
        if self._input_division_factor:
            result["params"]["input_division_factor"] = self._input_division_factor

        return result

    @classmethod
    def from_hn(cls, hn):
        """Deserialize from an HN dict.

        ``input_division_factor`` keeps its constructor default when
        ``params.input_division_factor`` is absent.
        NOTE(review): unlike to_hn, to_pb and from_pb, this goes through
        ``ConcatLayer.from_hn``; confirm this asymmetry is intended.
        """
        layer = super().from_hn(hn)
        if "params" in hn and "input_division_factor" in hn["params"]:
            layer._input_division_factor = hn["params"]["input_division_factor"]
        return layer

    @classmethod
    def from_pb(cls, pb, pb_wrapper):
        """Deserialize from a protobuf node, bypassing ``ConcatLayer.from_pb``.

        NOTE(review): if ``input_division_factor`` is a proto3 scalar, an unset
        field reads as 0. ``_calc_output_shape`` rejects 0 as invalid; confirm
        whether 0 should map to the default of 1 here.
        """
        layer = super(ConcatLayer, cls).from_pb(pb, pb_wrapper)
        layer._input_division_factor = pb.input_division_factor
        return layer

    def get_equalization_handler_type(self, predecessor=None):
        """Report equalization as unsupported; this layer is not a source."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        """Report params sorting as unsupported; this layer is not a source."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Report dead-channel removal as unsupported; this layer is not a source."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def ibc_supported(self):
        """Return ``LayerSupportStatus.unsupported``: IBC is not supported."""
        return LayerSupportStatus.unsupported

    @property
    def finetune_supported(self):
        """Finetuning is not supported for this layer."""
        return False
