import copy
from enum import Enum

import numpy as np

from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    IOUPostprocessLayerParams,
    LayerHandlerType,
    LayerSupportStatus,
    LogitsPostprocessLayerParams,
    NMSProperties,
    PostprocessType,
    ResizePostprocessLayerParams,
)
from hailo_sdk_common.hailo_nn.exceptions import UnsupportedModelError
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType, NMSMetaArchitectures, ResizeMethod
from hailo_sdk_common.hailo_nn.hn_layers.layer_with_params import LayerWithParams
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class PostprocessLayer(LayerWithParams):
    """HN layer describing a postprocess stage of a network.

    The concrete stage is selected by ``postprocess_type`` (a ``PostprocessType``
    such as NMS, IOU, LOGITS, RESIZE or BBOX_DECODER). All stage-specific settings
    are stored in the flat ``params`` dict, keyed by the ``.value`` of the relevant
    ``NMSProperties`` / ``*PostprocessLayerParams`` enum member; the properties
    below are thin accessors over that dict.
    """

    _REQUIRES_NATIVE_WEIGHTS = False
    _REQUIRES_QUANTIZED_WEIGHTS = False
    _IS_REAL_LAYER = False

    def __init__(self):
        super().__init__()
        self._op = LayerType.postprocess
        self._postprocess_type = None
        self._params = {}
        # when `number_of_inputs_supported = None` the number of inputs is unlimited
        self._number_of_inputs_supported = None

    @property
    def postprocess_type(self):
        """The ``PostprocessType`` of this layer (``None`` until set)."""
        return self._postprocess_type

    @postprocess_type.setter
    def postprocess_type(self, postprocess_type):
        self._postprocess_type = postprocess_type

    @property
    def params(self):
        """The mutable dict holding all postprocess-specific parameters."""
        return self._params

    # ---- NMS properties -------------------------------------------------------

    @property
    def meta_arch(self):
        """Meta architecture stored under ``NMSProperties.META_ARCH`` (raises KeyError if unset)."""
        return self.params[NMSProperties.META_ARCH.value]

    @meta_arch.setter
    def meta_arch(self, meta_arch):
        self._params[NMSProperties.META_ARCH.value] = meta_arch

    @property
    def max_total_output_proposals(self):
        return self.params[NMSProperties.MAX_TOTAL_OUTPUT_PROPOSALS.value]

    @max_total_output_proposals.setter
    def max_total_output_proposals(self, max_total_output_proposals):
        self.params[NMSProperties.MAX_TOTAL_OUTPUT_PROPOSALS.value] = max_total_output_proposals

    @property
    def image_dims(self):
        """Image dimensions, or ``None`` when not set."""
        return self.params.get(NMSProperties.IMAGE_DIMS.value)

    @image_dims.setter
    def image_dims(self, dims):
        self.params[NMSProperties.IMAGE_DIMS.value] = dims

    # ---- IOU properties -------------------------------------------------------

    @property
    def iou_th(self):
        return self.params[IOUPostprocessLayerParams.IOU_TH.value]

    @iou_th.setter
    def iou_th(self, iou_th):
        self.params[IOUPostprocessLayerParams.IOU_TH.value] = iou_th

    @property
    def nms_scores_th(self):
        """NMS score threshold; ``0`` when the field is absent from ``params``."""
        return self.params.get(IOUPostprocessLayerParams.NMS_SCORES_TH.value, 0)

    @nms_scores_th.setter
    def nms_scores_th(self, nms_scores_th):
        field = IOUPostprocessLayerParams.NMS_SCORES_TH.value
        # NOTE(review): this only overwrites an already-present field and is a silent
        # no-op otherwise (so the field never gets added to `params` / `to_hn` output
        # from here). Presumably intentional — confirm.
        if field in self.params:
            self.params[field] = nms_scores_th

    @property
    def classes(self):
        return self.params[IOUPostprocessLayerParams.CLASSES.value]

    @classes.setter
    def classes(self, classes):
        self.params[IOUPostprocessLayerParams.CLASSES.value] = classes

    @property
    def max_proposals_per_class(self):
        return self.params[IOUPostprocessLayerParams.MAX_PROPOSALS_PER_CLASS.value]

    @max_proposals_per_class.setter
    def max_proposals_per_class(self, max_proposals_per_class):
        self.params[IOUPostprocessLayerParams.MAX_PROPOSALS_PER_CLASS.value] = max_proposals_per_class

    # ---- LOGITS properties ----------------------------------------------------

    @property
    def type(self):
        """Logits operation type (a ``LayerType``, e.g. ``LayerType.argmax``)."""
        return self.params[LogitsPostprocessLayerParams.TYPE.value]

    @type.setter
    def type(self, logits_type):
        self.params[LogitsPostprocessLayerParams.TYPE.value] = logits_type

    @property
    def axis(self):
        return self.params[LogitsPostprocessLayerParams.AXIS.value]

    @axis.setter
    def axis(self, axis):
        self.params[LogitsPostprocessLayerParams.AXIS.value] = axis

    # ---- RESIZE properties ----------------------------------------------------

    @property
    def resize_shape(self):
        return self.params[ResizePostprocessLayerParams.RESIZE_SHAPE.value]

    @resize_shape.setter
    def resize_shape(self, resize_shape):
        self.params[ResizePostprocessLayerParams.RESIZE_SHAPE.value] = resize_shape

    @property
    def resize_method(self):
        return self.params[ResizePostprocessLayerParams.RESIZE_METHOD.value]

    @resize_method.setter
    def resize_method(self, resize_method):
        self.params[ResizePostprocessLayerParams.RESIZE_METHOD.value] = resize_method

    @property
    def pixels_mode(self):
        return self.params[ResizePostprocessLayerParams.PIXELS_MODE.value]

    @pixels_mode.setter
    def pixels_mode(self, pixels_mode):
        self.params[ResizePostprocessLayerParams.PIXELS_MODE.value] = pixels_mode

    # ---- descriptions ---------------------------------------------------------

    def _meta_arch_suffix(self):
        """Return `` +<meta_arch>`` for NMS layers and ``""`` for every other type."""
        if self._postprocess_type == PostprocessType.NMS:
            return f" +{self.meta_arch.value}"
        return ""

    def __str__(self):
        # The MRO lookup deliberately starts *after* LayerWithParams, so
        # LayerWithParams' own __str__ (if it defines one) is skipped.
        return super(LayerWithParams, self).__str__() + self._meta_arch_suffix()

    @property
    def short_description(self):
        return super().short_description + self._meta_arch_suffix()

    # ---- shape inference ------------------------------------------------------

    def _nms_output_shapes(self, kwargs):
        """Return the output shapes of an NMS/IOU layer (helper of ``update_output_shapes``)."""
        # `.get` so an NMS/IOU layer without a stored meta architecture falls through
        # to the kwargs / existing-shapes branches instead of raising KeyError.
        if self.params.get(NMSProperties.META_ARCH.value) == NMSMetaArchitectures.YOLOV5_SEG:
            image_dims = self.image_dims
            if image_dims is None:
                raise UnsupportedModelError(
                    f"Update of output shapes failed at {self.name}. "
                    "image_dims must be set for a YOLOV5_SEG postprocess.",
                )
            # the output order is [y_min, x_min, y_max, x_max, score, class, flattened mask]
            return [[-1, 1, int(np.prod(image_dims)) + 6, self.max_total_output_proposals]]
        if "classes" in kwargs and "max_per_class" in kwargs:
            # 5 is for [y_min, x_min, y_max, x_max, score]
            return [[-1, kwargs["classes"], 5, kwargs["max_per_class"]]]
        if len(self.output_shapes) != 0:
            # Keep shapes that were set in advance.
            return self.output_shapes
        raise UnsupportedModelError(
            f"Update of output shapes failed at {self.name}. Set the shapes in "
            "advance or specify the num of classes and max proposals per class.",
        )

    def _bbox_decoder_output_shapes(self):
        """Return the output shape of a BBOX_DECODER layer (helper of ``update_output_shapes``)."""
        meta_arch = self.params.get(NMSProperties.META_ARCH.value)
        if meta_arch == NMSMetaArchitectures.YOLOV5:
            features = self.classes + 5
            # The last input's channel count is (anchors * features).
            num_of_anchors = self.input_shapes[-1][-1] // features
            num_of_proposals = num_of_anchors * sum(shape[1] * shape[2] for shape in self.input_shapes)
        elif meta_arch in [NMSMetaArchitectures.YOLOV8, NMSMetaArchitectures.DAMOYOLO]:
            features = self.classes + 4
            # Only every other input (input_shapes[::2]) contributes spatial positions.
            num_of_proposals = sum(shape[1] * shape[2] for shape in self.input_shapes[::2])
        else:
            raise UnsupportedModelError(f"Unsupported meta architecture for bbox decoder: {meta_arch}")
        # NOTE(review): unlike the NMS/RESIZE branches this is a single flat shape, not a
        # list of shapes; presumably the `output_shapes` setter normalizes it — confirm.
        return [-1, 1, num_of_proposals, features]

    def update_output_shapes(self, **kwargs):
        """Recompute ``self.output_shapes`` from the postprocess type and its params.

        Args:
            **kwargs: For NMS/IOU layers, optional ``classes`` and ``max_per_class``
                used to build a ``[-1, classes, 5, max_per_class]`` output shape.

        Raises:
            UnsupportedModelError: If an NMS/IOU output shape cannot be derived, if a
                YOLOV5_SEG layer has no ``image_dims``, or if a bbox decoder has an
                unsupported (or missing) meta architecture.
        """
        if self._postprocess_type in [PostprocessType.NMS, PostprocessType.IOU]:
            output_shapes = self._nms_output_shapes(kwargs)
        elif self._postprocess_type == PostprocessType.LOGITS and self.type == LayerType.argmax:
            # Input shape with its last axis collapsed to 1.
            # NOTE(review): a single flat shape, as in the bbox-decoder branch — confirm
            # the `output_shapes` setter accepts it.
            last_axis = len(self.input_shape) - 1
            output_shapes = [1 if i == last_axis else dim for i, dim in enumerate(self.input_shape)]
        elif self._postprocess_type == PostprocessType.RESIZE:
            output_shapes = [[-1, *self.resize_shape, self.input_shape[-1]]]
        elif self._postprocess_type == PostprocessType.BBOX_DECODER:
            output_shapes = self._bbox_decoder_output_shapes()
        else:
            # Pass-through: reuse the last input's shape.
            output_shapes = [self.input_shapes[-1]]
        self.output_shapes = output_shapes

    # ---- serialization --------------------------------------------------------

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize to a protobuf node, adding the fields relevant to the postprocess type."""
        node = super().to_pb(pb_wrapper, is_multi_scope)
        node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_POSTPROCESS
        node.postprocess_type = pb_wrapper.POSTPROCESS_TYPE_TO_PB[self.postprocess_type]

        if self._postprocess_type == PostprocessType.LOGITS:
            node.logits_type = pb_wrapper.LOGITS_TYPE_TO_PB[self.type]
        elif self._postprocess_type in [PostprocessType.NMS, PostprocessType.IOU, PostprocessType.BBOX_DECODER]:
            node.iou_threshold = self.iou_th
            node.classes = self.classes
            node.max_proposals_per_class = self.max_proposals_per_class
            node.nms_scores_th = self.nms_scores_th
        elif self._postprocess_type == PostprocessType.RESIZE:
            # ResizeMethod(...) accepts either the enum member or its raw value.
            node.resize_method = pb_wrapper.RESIZE_METHOD_TYPE_TO_PB[ResizeMethod(self.resize_method)]
            node.resize_h_ratio_list.append(self.output_shape[1] / self.input_shape[1])
            node.resize_w_ratio_list.append(self.output_shape[2] / self.input_shape[2])

        return node

    def to_hn(self, should_get_default_params=False):
        """Serialize to an HN dict; Enum-valued params are written as their ``.value``."""
        result = copy.deepcopy(super().to_hn(should_get_default_params))
        result["params"]["postprocess_type"] = self.postprocess_type.value
        hn_params = {
            key: val.value if isinstance(val, Enum) else val for key, val in copy.deepcopy(self.params).items()
        }
        result["params"].update(hn_params)
        return result

    @classmethod
    def from_hn(cls, hn):
        """Build a layer from an HN dict, restoring Enum-typed params flattened by ``to_hn``."""
        layer = super().from_hn(hn)
        layer.postprocess_type = PostprocessType(hn["params"]["postprocess_type"])
        params = copy.deepcopy(hn["params"])
        meta_arch_key = NMSProperties.META_ARCH.value
        if layer.postprocess_type == PostprocessType.NMS:
            params[meta_arch_key] = NMSMetaArchitectures(hn["params"][meta_arch_key])
        elif layer.postprocess_type in [PostprocessType.IOU, PostprocessType.BBOX_DECODER]:
            # to_hn stores meta_arch as its raw value; restore the enum so the
            # comparisons in update_output_shapes keep working after a round trip.
            if params.get(meta_arch_key) is not None:
                params[meta_arch_key] = NMSMetaArchitectures(params[meta_arch_key])
        elif layer.postprocess_type == PostprocessType.LOGITS:
            params[LogitsPostprocessLayerParams.TYPE.value] = LayerType(
                hn["params"][LogitsPostprocessLayerParams.TYPE.value],
            )
        layer.params.update(params)
        return layer

    # ---- optimization-handler classification ----------------------------------

    def get_equalization_handler_type(self, predecessor=None):
        return EquivClassification(LayerHandlerType.transparent, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        return EquivClassification(LayerHandlerType.transparent, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        return EquivClassification(LayerHandlerType.transparent, is_source=False)

    def ibc_supported(self):
        return LayerSupportStatus.unsupported
