import copy

import numpy as np
import tensorflow as tf
from pydantic.v1 import BaseModel

from hailo_model_optimization.acceleras.hailo_layers.base_hailo_none_nn_core_layer import BaseHailoNonNNCoreLayer
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerPrecisionConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    BBoxDecodersInfo,
    BiasMode,
    LayerType,
    LogitsPostprocessLayerParams,
    NMSOnCpuMetaArchitectures,
    NMSProperties,
    PostprocessType,
    ResizeBilinearPixelsMode,
    ResizeMethod,
    ResizePostprocessLayerParams,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import (
    AccelerasException,
    AccelerasImplementationError,
)


class PostProcessConfig(BaseModel):
    """
    Settings read by `HailoPostprocess` while running a post-process layer.

    The per-field notes describe only how the values are consumed in this module.
    """

    # when true, the background class is sliced out of the detection scores
    background_removal: bool
    # position of the background class; `remove_background_class` accepts only 0 or classes - 1
    background_removal_index: int
    # one entry per output branch, read as dicts (stride, anchor "w"/"h" lists, yolox layer names)
    bbox_decoders_info: list
    # divisor applied to the raw SSD height / width regressions before decoding
    bbox_dimensions_scale_factor: float
    # divisor applied to the raw SSD center regressions before decoding
    centers_scale_factor: float
    # number of classes per proposal; includes the background class when background removal is used
    classes: int
    # image_dims[0] divides y coordinates and image_dims[1] divides x coordinates when normalizing boxes
    image_dims: list
    # threshold passed to `process_mask` to binarize the yolov5-seg masks
    mask_threshold: float
    # forwarded to the combined NMS as `max_output_size_per_class`
    max_proposals_per_class: int
    # forwarded to the combined NMS as `max_total_size` and to the yolov5-seg NMS as `max_det`
    max_total_proposals: int
    # IoU threshold of the NMS
    nms_iou_th: float
    # score threshold of the NMS (replaced by 0.0 for the IOU post-process type)
    nms_scores_th: float
    # yolov5-seg proto description; proto_info["number"] is read as the number of masks
    proto_info: dict
    # yolov8: number of distribution bins per box side (box branches have 4 * regression_length channels)
    regression_length: int
    # yolov8: when true the box branches already hold distances and the softmax decoding is skipped
    dfl_on_nn_core: bool


class HailoPostprocess(BaseHailoNonNNCoreLayer):
    """
    Represents a `postprocess` layer in the hn.

    Args:
        name: layer name
        input_shapes: shapes of the layer inputs
        output_shapes: shapes of the layer outputs
        params: parameters of the post-process
        logger: the logger for the class

    """

    _hn_type = LayerType.POSTPROCESS
    SUPPORT_EXPORT_HW_PARAMS = False

    def __init__(self, name: str, input_shapes: list, output_shapes: list, params: dict, logger=None, **kwargs):
        """Initialize the base layer, keep the post-process params and start without a config."""
        super().__init__(
            name=name,
            input_shapes=input_shapes,
            output_shapes=output_shapes,
            logger=logger,
            **kwargs,
        )
        self._trainable = False
        self._params = params
        # assigned through the `config` property setter; no config exists at construction time
        self.config: PostProcessConfig = None

    @property
    def is_precision_transparent(self) -> bool:
        """Always ``False`` for the post-process layer."""
        return False

    @property
    def meta_arch(self):
        """Meta-architecture entry of the layer params (raises ``KeyError`` when it is absent)."""
        meta_arch_key = NMSProperties.META_ARCH.value
        return self._params[meta_arch_key]

    @property
    def postprocess_type(self):
        return self._params["postprocess_type"]

    @property
    def config(self):
        """The post-process config attached to the layer (``None`` right after construction)."""
        return self._config

    @config.setter
    def config(self, config):
        """Attach ``config`` to the layer; no validation is performed here."""
        self._config = config

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """
        Build the layer from its hn element.

        Args:
            lname: layer name
            hn_element: hn description of the layer; its "params" must hold a "postprocess_type"
            logger: optional logger forwarded to the layer

        Returns:
            the new layer, after `finalize_from_hn` was applied to it

        """
        # deep copy so that the enum conversions below never touch the caller's hn element
        layer_params = copy.deepcopy(hn_element.get("params", dict()))

        post_type = PostprocessType(layer_params["postprocess_type"])
        layer_params["postprocess_type"] = post_type
        if post_type in (PostprocessType.NMS, PostprocessType.BBOX_DECODER, PostprocessType.IOU):
            # only these post-process types carry a meta-architecture entry
            arch_key = NMSProperties.META_ARCH.value
            layer_params[arch_key] = NMSOnCpuMetaArchitectures(layer_params[arch_key])

        layer = cls(
            name=lname,
            input_shapes=hn_element.get("input_shapes", []),
            output_shapes=hn_element.get("output_shapes", []),
            params=layer_params,
            logger=logger,
        )
        layer.finalize_from_hn(hn_element)
        return layer

    def get_precision_mode(self):
        """Return the precision mode stored on the layer (``self._precision_mode``)."""
        return self._precision_mode

    def get_layer_precision_config(self):
        """
        Build the raw precision-config dict of the layer.

        Returns:
            the `raw_dict()` of a `LayerPrecisionConfig` using the layer's precision mode, or an empty
            dict when building it raises ``AttributeError`` (presumably when no precision mode has been
            set on the layer yet).

        """
        try:
            precision_config = LayerPrecisionConfig(
                precision_mode=self.get_precision_mode(),
                bias_mode=BiasMode.double_scale_decomposition,
                quantization_groups=1,
                signed_output=False,
            )
            raw_config = precision_config.raw_dict()
        except AttributeError:
            raw_config = {}
        return raw_config

    def to_hn(self, out_degree=None):
        """
        Export the layer to its hn element, including the layer's quantization params.

        The quantization params are written to the exported element and also merged into the
        "quantization_params" entry of the stored `self.hn_element`.
        """
        exported = super().to_hn(out_degree=out_degree)
        precision_config = self.get_layer_precision_config()
        exported["quantization_params"] = precision_config
        self.hn_element["quantization_params"].update(precision_config)
        return exported

    def call_core(self, inputs, training=False, **kwargs):
        """
        Run the post-process matching the layer's post-process type.

        Raises:
            AccelerasImplementationError: when the post-process type has no implementation here.

        """
        post_type = self.postprocess_type

        if post_type in [PostprocessType.NMS, PostprocessType.BBOX_DECODER]:
            # both types share the bbox decoding; only the NMS type continues into the NMS itself
            decode_only = post_type == PostprocessType.BBOX_DECODER
            return self.bbox_decoding_and_nms_call(inputs, is_bbox_decoding_only=decode_only)
        if post_type == PostprocessType.IOU:
            return self.iou_call(inputs)
        if post_type == PostprocessType.LOGITS:
            return self.logits_call(inputs)
        if post_type == PostprocessType.RESIZE:
            return self.resize_call(inputs)

        raise AccelerasImplementationError(f"Acceleras does not support layer {self.postprocess_type.value}")

    def import_acceleras(self, params):
        """No-op: nothing is imported into this layer, `params` is ignored."""
        pass

    def iou_call(self, inputs):
        """
        Run an additional NMS pass over detections that are already decoded.

        After the transpose, the first 4 entries of the last axis are used as the box and the
        remaining ones as the score.
        """
        transposed = tf.transpose(inputs, [0, 3, 1, 2])
        boxes = transposed[:, :, :, :4]
        scores = tf.squeeze(transposed[:, :, :, 4:], axis=3)
        return self._combined_non_max_suppression(boxes, scores)

    def bbox_decoding_and_nms_call(self, inputs, is_bbox_decoding_only):
        """
        Decode the boxes according to the layer's meta-architecture and, optionally, run the NMS.

        Args:
            inputs: the raw output tensors of the network branches
            is_bbox_decoding_only: when true, the decoded boxes and scores are returned concatenated,
                without running the NMS

        Raises:
            AccelerasImplementationError: when the meta-architecture is not handled here.

        """
        arch = self.meta_arch

        # yolov5-seg has its own numpy post-process that already includes the NMS
        if arch == NMSOnCpuMetaArchitectures.YOLOV5_SEG:
            return self.yolo_seg_nms_call(inputs)

        if arch in [NMSOnCpuMetaArchitectures.YOLOV5, NMSOnCpuMetaArchitectures.YOLOX]:
            # yolox delivers regression / objectness / classes separately, so they are concatenated first
            decoder_inputs = self.concat_yolox_inputs(inputs) if arch == NMSOnCpuMetaArchitectures.YOLOX else inputs
            boxes, scores = self.yolov5_and_yolox_decoding_call(
                decoder_inputs,
                META_ARCH_TO_BBOX_DECODER_FUNC[arch],
            )
        elif arch == NMSOnCpuMetaArchitectures.YOLOV8:
            boxes, scores = self.yolov8_decoding_call(inputs, offsets=[0.5, 0.5])
        elif arch == NMSOnCpuMetaArchitectures.DAMOYOLO:
            # same decoding as yolov8, but the grid carries no offsets
            boxes, scores = self.yolov8_decoding_call(inputs, offsets=[0.0, 0.0])
        elif arch == NMSOnCpuMetaArchitectures.SSD:
            boxes, scores = self.ssd_decoding_call(inputs)
        else:
            raise AccelerasImplementationError(
                f"The meta-architecture {self.meta_arch.value} " "is currently not supported by Acceleras.",
            )

        if not is_bbox_decoding_only:
            return self._combined_non_max_suppression(boxes, scores)

        # decoder-only output: drop the singleton box axis and put boxes and scores side by side
        flat_boxes = tf.squeeze(boxes, axis=2)
        return tf.expand_dims(tf.concat([flat_boxes, scores], axis=-1), axis=1)

    def yolov5_and_yolox_decoding_call(self, inputs, bbox_decoding_func):
        """
        Decode the raw outputs of all yolov5 / yolox branches into boxes and class scores.

        Args:
            inputs: one tensor per branch, reshaped here as [N, H, W, num_anchors * (5 + num_classes)]
            bbox_decoding_func: callable taking ``(centers, scales, offsets, anchors, stride)`` and
                returning the decoded ``(box_centers, box_scales)`` of a branch

        Returns:
            tuple ``(boxes, scores)``. Boxes are [N, num_detections, 1, 4], ordered
            (y_min, x_min, y_max, x_max) and divided by the image dimensions. Scores are
            [N, num_detections, num_classes]; for the bbox-decoder post-process type the objectness is
            prepended, giving [N, num_detections, 1 + num_classes].

        """
        decoders_info = self.config.bbox_decoders_info
        # per branch: [[w1, h1], [w2, h2], ...]; an entry without "w"/"h" falls back to [[1, 1]]
        anchors_per_branch = [
            [[anchor_w, anchor_h] for anchor_w, anchor_h in zip(info.get("w", [1]), info.get("h", [1]))]
            for info in decoders_info
        ]
        strides = [info[BBoxDecodersInfo.STRIDE.value] for info in decoders_info]
        num_classes = self.config.classes
        features = 5 + num_classes  # 2 centers + 2 scales + 1 objectness + 'num_classes'

        # the NMS type scores by objectness * class prediction; the decoder-only type keeps them apart
        scores_include_objectness = self.postprocess_type == PostprocessType.NMS
        is_decoder_only = self.postprocess_type == PostprocessType.BBOX_DECODER

        boxes_per_branch = []
        scores_per_branch = []
        objectness_per_branch = []
        # every input belongs to a different stride
        for stride, anchor_list, branch in zip(strides, anchors_per_branch, inputs):
            anchors_tensor = tf.convert_to_tensor(anchor_list)  # dims [num_anchors, 2]
            anchors_tensor = tf.expand_dims(tf.expand_dims(anchors_tensor, 0), 0)  # dims [1, 1, num_anchors, 2]
            num_anchors = anchors_tensor.shape[2]
            grid_h, grid_w = branch.shape[1], branch.shape[2]
            num_cells = grid_h * grid_w
            num_detections = num_cells * num_anchors

            # (x, y) index of every grid cell
            cell_x, cell_y = tf.meshgrid(
                tf.range(0, grid_w, dtype=tf.int32),
                tf.range(0, grid_h, dtype=tf.int32),
            )  # dims [H, W] each
            cell_offsets = tf.keras.layers.Concatenate(axis=-1)(
                [cell_x[:, :, tf.newaxis], cell_y[:, :, tf.newaxis]],
            )  # dims [H, W, 2]
            cell_offsets = tf.expand_dims(tf.expand_dims(cell_offsets, 0), 0)  # dims [1, 1, H, W, 2]
            cell_offsets = tf.reshape(cell_offsets, (1, -1, 1, 2))  # dims [1, H*W, 1, 2]

            # regroup the channels as one feature vector per (cell, anchor)
            per_anchor = tf.transpose(branch, perm=[0, 3, 1, 2])  # dims [N, features*num_anchors, H, W]
            per_anchor = tf.reshape(per_anchor, (-1, features * num_anchors, num_cells))
            per_anchor = tf.transpose(per_anchor, perm=[0, 2, 1])  # dims [N, H*W, features*num_anchors]
            per_anchor = tf.reshape(per_anchor, (-1, num_cells, num_anchors, features))

            raw_centers = per_anchor[:, :, :, 0:2]  # dims [N, H*W, num_anchors, 2]
            raw_scales = per_anchor[:, :, :, 2:4]  # dims [N, H*W, num_anchors, 2]
            objectness = per_anchor[:, :, :, 4:5]  # dims [N, H*W, num_anchors, 1]
            class_prediction = per_anchor[:, :, :, 5:]  # dims [N, H*W, num_anchors, num_classes]

            box_centers, box_scales = bbox_decoding_func(
                raw_centers,
                raw_scales,
                cell_offsets,
                anchors_tensor,
                stride,
            )

            if scores_include_objectness:
                class_scores = objectness * class_prediction
            else:
                class_scores = class_prediction

            # (center, size) -> (x_min, y_min, x_max, y_max), dims [N, H*W, num_anchors, 4]
            half_size = box_scales / 2.0
            corners = tf.keras.layers.Concatenate(axis=-1)([box_centers - half_size, box_centers + half_size])

            # merge the cell and anchor axes
            flat_corners = tf.reshape(corners, (-1, num_detections, 1, 4))
            flat_scores = tf.reshape(class_scores, (-1, num_detections, num_classes))

            # normalize by the image dims and reorder to (y_min, x_min, y_max, x_max)
            y_min = flat_corners[:, :, :, 1:2] / self.config.image_dims[0]
            x_min = flat_corners[:, :, :, 0:1] / self.config.image_dims[1]
            y_max = flat_corners[:, :, :, 3:] / self.config.image_dims[0]
            x_max = flat_corners[:, :, :, 2:3] / self.config.image_dims[1]
            boxes_per_branch.append(tf.keras.layers.Concatenate(axis=-1)([y_min, x_min, y_max, x_max]))
            scores_per_branch.append(flat_scores)

            if is_decoder_only:
                # dims [N, num_detections, 1]
                objectness_per_branch.append(tf.reshape(objectness, (-1, num_detections, 1)))

        all_boxes = tf.keras.layers.Concatenate(axis=1)(boxes_per_branch)
        all_scores = tf.keras.layers.Concatenate(axis=1)(scores_per_branch)
        if is_decoder_only:
            all_objectness = tf.keras.layers.Concatenate(axis=1)(objectness_per_branch)
            all_scores = tf.keras.layers.Concatenate(axis=-1)([all_objectness, all_scores])

        return all_boxes, all_scores

    def yolo_seg_nms_call(self, inputs):
        """Run the numpy yolov5-seg post-process on `inputs` through `tf.numpy_function` (float32 output)."""
        return tf.numpy_function(func=self.yolov5_seg_postprocess, inp=inputs, Tout=np.float32)

    def yolov5_seg_postprocess(self, in1, in2, in3, in4):
        """
        Numpy post-process of yolov5-seg: decoding, NMS and mask reconstruction.

        The input whose last axis equals the number of masks is taken as the protos tensor; the other
        inputs are detection branches, each matched to the bbox decoder whose stride equals
        ``image_dims[0] / branch_height``.

        Returns:
            the array produced by `_organize_yolov5_seg_output`.

        """
        cfg = self.config
        num_of_masks = cfg.proto_info["number"]

        proto_tensors = []
        detection_branches = []
        for tensor in (in1, in2, in3, in4):
            if tensor.shape[3] == num_of_masks:
                proto_tensors.append(tensor)
            else:
                detection_branches.append(tensor)
        protos = proto_tensors[0]

        decoded_branches = []
        for branch in detection_branches:
            stride = int(cfg.image_dims[0] / branch.shape[1])
            matching_decoders = [
                decoder for decoder in cfg.bbox_decoders_info if decoder[BBoxDecodersInfo.STRIDE.value] == stride
            ]
            decoder = matching_decoders[0]
            # anchors are interleaved as [w1, h1, w2, h2, ...] and expressed in units of the stride
            interleaved = []
            for anchor_w, anchor_h in zip(decoder[BBoxDecodersInfo.W.value], decoder[BBoxDecodersInfo.H.value]):
                interleaved.extend((anchor_w, anchor_h))
            anchors = np.array(interleaved) / stride
            decoded_branches.append(_yolov5_decoding(stride, anchors, cfg.classes, branch))

        proposals = np.concatenate(decoded_branches, 1)  # (BS, num_proposals, features)
        detections_per_image = non_max_suppression_yolo_seg(
            proposals,
            cfg.nms_scores_th,
            cfg.nms_iou_th,
            n_masks=num_of_masks,
            max_det=cfg.max_total_proposals,
        )
        # entry 0 holds the boxes and entry 1 the mask coefficients, which are replaced by the final masks
        for image_index, detections in enumerate(detections_per_image):
            detections[1] = process_mask(
                protos[image_index],
                detections[1],
                detections[0],
                cfg.image_dims,
                mask_threshold=cfg.mask_threshold,
            )
        return self._organize_yolov5_seg_output(detections_per_image)

    def yolov8_decoding_call(self, inputs, offsets):
        """
        Decode yolov8-style outputs (also used for damoyolo) into boxes and scores.

        Args:
            inputs: box branches (last axis of ``4 * regression_length``) and score branches (last axis of
                ``classes``); each kind is paired with the strides in the order it appears
            offsets: ``(x, y)`` offset added to the grid indices before scaling by the stride

        Returns:
            tuple ``(boxes, scores)``: boxes are [N, num_detections, 1, 4] ordered
            (y_min, x_min, y_max, x_max), clipped to the image and divided by its dimensions; scores are
            [N, num_detections, num_scores] with the background class removed when configured.

        """
        cfg = self.config
        reg_len = cfg.regression_length
        image_dims = cfg.image_dims
        strides = [info[BBoxDecodersInfo.STRIDE.value] for info in cfg.bbox_decoders_info]

        box_branches = [tensor for tensor in inputs if tensor.shape[-1] == 4 * reg_len]
        score_branches = [tensor for tensor in inputs if tensor.shape[-1] == cfg.classes]

        all_boxes = None
        all_scores = None
        for box_branch, stride, score_branch in zip(box_branches, strides, score_branches):
            grid_h, grid_w = box_branch.shape[1:3]
            num_cells = grid_h * grid_w

            # cell centers in pixels, repeated as (x, y, x, y) to match the four box sides
            offset_x, offset_y = offsets
            cols, rows = np.meshgrid(np.arange(grid_w) + offset_x, np.arange(grid_h) + offset_y)
            center_y = rows.flatten() * stride
            center_x = cols.flatten() * stride
            centers = np.stack((center_x, center_y, center_x, center_y), axis=1)

            branch_scores = tf.reshape(score_branch, (-1, num_cells, cfg.classes))

            if cfg.dfl_on_nn_core:
                # the distances were already computed on the nn-core
                distances = tf.reshape(box_branch, (-1, num_cells, 4))
            else:
                # expected value of the softmax distribution over the regression bins
                distribution = tf.nn.softmax(tf.reshape(box_branch, (-1, num_cells, 4, reg_len)), axis=-1)
                bins = tf.reshape(np.arange(reg_len, dtype=np.float32), (1, 1, 1, -1))
                distances = tf.reduce_sum(distribution * bins, axis=-1)
            distances = distances * stride

            # the first two distances point towards the min corner, the last two towards the max corner
            signed_distances = tf.concat([distances[:, :, :2] * (-1), distances[:, :, 2:]], axis=-1)
            corners = np.expand_dims(centers, axis=0) + signed_distances

            # clip to the image and normalize
            x_min = tf.maximum(0.0, corners[:, :, 0]) / image_dims[1]
            y_min = tf.maximum(0.0, corners[:, :, 1]) / image_dims[0]
            x_max = tf.minimum(tf.cast(image_dims[1], tf.float32), corners[:, :, 2]) / image_dims[1]
            y_max = tf.minimum(tf.cast(image_dims[0], tf.float32), corners[:, :, 3]) / image_dims[0]
            branch_boxes = tf.transpose([y_min, x_min, y_max, x_max], [1, 2, 0])

            if cfg.background_removal:
                branch_scores = self.remove_background_class(branch_scores)

            if all_boxes is None:
                all_boxes = branch_boxes
            else:
                all_boxes = tf.concat([all_boxes, branch_boxes], axis=1)
            if all_scores is None:
                all_scores = branch_scores
            else:
                all_scores = tf.concat([all_scores, branch_scores], axis=1)

        return tf.expand_dims(all_boxes, axis=2), all_scores

    def remove_background_class(self, detection_scores):
        """
        Slice the background class out of the last axis of `detection_scores`.

        Args:
            detection_scores: scores of shape [batch_size, num_of_proposals, num_of_classes]

        Returns:
            the scores without the background class: [batch_size, num_of_proposals, num_of_classes - 1]

        Raises:
            AccelerasException: when there is a single class, or when the background index is neither
                0 nor ``classes - 1``.

        """
        if not self.config.classes > 1:
            msg = "The number of classes must be greater than 1 for background removal"
            raise AccelerasException(msg)

        last_class = self.config.classes - 1
        if self.config.background_removal_index not in [0, last_class]:
            msg = "unsupported value of `background_removal_index`. The value must be 0 or classes-1"
            raise AccelerasException(
                msg,
            )

        if self.config.background_removal_index == 0:
            # background is the first class: keep everything from index 1 on
            return tf.slice(detection_scores, [0, 0, 1], [-1, -1, -1])
        # background is the last class: keep the first `classes - 1` entries
        return tf.slice(detection_scores, [0, 0, 0], [-1, -1, last_class])

    @staticmethod
    def yolov5_bbox_decoding(proposed_bbox_centers, proposed_bbox_scales, offsets, branch_anchors, stride):
        """
        Decode yolov5 box proposals.

        centers = (proposed_bbox_centers * 2 - 0.5 + offsets) * stride
        scales  = (proposed_bbox_scales * 2) ** 2 * branch_anchors

        Returns:
            tuple ``(box_centers, box_scales)``

        """
        two = tf.constant(2, dtype=tf.float32)
        half = tf.constant(0.5, dtype=tf.float32)
        grid_offsets = tf.cast(offsets, dtype=tf.float32)
        anchors = tf.cast(branch_anchors, dtype=tf.float32)

        box_centers = (proposed_bbox_centers * two - half + grid_offsets) * stride
        box_scales = (proposed_bbox_scales * two) ** two * anchors
        return box_centers, box_scales

    @staticmethod
    def yolox_bbox_decoding(proposed_bbox_centers, proposed_bbox_scales, offsets, branch_anchors, stride):
        """
        Decode yolox box proposals (`branch_anchors` is not used by this decoding).

        centers = (proposed_bbox_centers + offsets) * stride
        scales  = exp(proposed_bbox_scales) * stride

        Returns:
            tuple ``(box_centers, box_scales)``

        """
        grid_offsets = tf.cast(offsets, dtype=tf.float32)
        box_centers = (proposed_bbox_centers + grid_offsets) * stride
        box_scales = tf.math.exp(proposed_bbox_scales) * stride
        return box_centers, box_scales

    def ssd_decoding_call(self, inputs):
        """
        Decode raw SSD outputs into boxes and per-class scores.

        Args:
            inputs: regression and class-prediction tensors of all branches, each [batch, h, w, features].
                The anchor grid is built from ``inputs[0::2]``, so every second tensor (starting with the
                first) supplies the spatial size of one branch, in the order of the bbox decoders.

        Returns:
            tuple ``(decoded_bboxes, detection_scores)``: boxes are [batch, num_of_proposals, 1, 4] ordered
            (y_min, x_min, y_max, x_max); scores are [batch, num_of_proposals, num_of_classes], without the
            background class when background removal is enabled.

        """
        # Reorganizing the outputs per branch:
        #   regression layer:       [batch, h, w, features] -> [batch, h * w * num_of_anchors, 4]
        #   class prediction layer: [batch, h, w, features] -> [batch, h * w * num_of_anchors, num_of_classes]
        reshaped_bbox_prediction_tensors = []
        reshaped_class_prediction_tensors = []
        # branches are matched by spatial size; this relies on the bbox decoder infos being ordered from
        # the largest output grid to the smallest
        outputs_dims = sorted({tensor.shape[1:3] for tensor in inputs}, key=lambda dims: dims[0], reverse=True)
        for branch_index, bbox_decoder in enumerate(self.config.bbox_decoders_info):
            reg_layer = next(
                layer
                for layer in inputs
                if layer.shape[-1] == 4 * len(bbox_decoder["h"]) and layer.shape[1:3] == outputs_dims[branch_index]
            )
            cls_layer = next(
                layer
                for layer in inputs
                if layer.shape[-1] == self.config.classes * len(bbox_decoder["h"])
                and layer.shape[1:3] == outputs_dims[branch_index]
            )
            _, h, w, features = reg_layer.shape  # features = num_of_anchors * 4 (=beta_x, beta_y, beta_w, beta_h)
            num_of_anchors = tf.cast(features / 4, tf.int32)
            spatial_size = h * w * num_of_anchors
            reshaped_bbox_prediction_tensors.append(tf.reshape(reg_layer, [-1, spatial_size, 4]))
            reshaped_class_prediction_tensors.append(tf.reshape(cls_layer, [-1, spatial_size, self.config.classes]))
        # dims [batch, h * w * num_of_anchors, 4]
        reshaped_bbox_prediction = tf.concat(reshaped_bbox_prediction_tensors, axis=1)
        # dims [batch, h * w * num_of_anchors, num_of_classes]
        detection_scores = tf.concat(reshaped_class_prediction_tensors, axis=1)
        if self.config.background_removal:
            detection_scores = self.remove_background_class(detection_scores)

        # organizing bbox properties - centers and sizes
        batch_size = tf.shape(reshaped_bbox_prediction)[0]
        grid_sizes = [output_branch.shape[1:3] for output_branch in inputs[0::2]]  # (h, w) of each output branch
        anchors_strides = [(1 / grid_h, 1 / grid_w) for grid_h, grid_w in grid_sizes]
        anchor_offsets = [(0.5 * stride[0], 0.5 * stride[1]) for stride in anchors_strides]

        # preparing bboxes infos - center points and sizes
        bboxes_info = []  # contains all center points and sizes of the anchors. dims -> [num_of_proposals, 4]
        for grid_size, bbox_decoder, anchor_stride, anchor_offset in zip(
            grid_sizes,
            self.config.bbox_decoders_info,
            anchors_strides,
            anchor_offsets,
        ):
            anchor_heights = bbox_decoder[BBoxDecodersInfo.H.value]
            anchor_widths = bbox_decoder[BBoxDecodersInfo.W.value]
            # preparing centers locations of the anchors (considers the height and the width of the image are 1)
            y_centers = np.arange(grid_size[0]) * anchor_stride[0] + anchor_offset[0]
            x_centers = np.arange(grid_size[1]) * anchor_stride[1] + anchor_offset[1]

            # bbox info contains [y_center, x_center, h, w] per anchor of a specific center point;
            # the sizes do not depend on the grid cell, so they are filled only once per branch
            bbox_info_per_center = np.zeros((len(anchor_heights), 4))
            for anchor_index in range(len(anchor_heights)):
                bbox_info_per_center[anchor_index, 2] = anchor_heights[anchor_index]
                bbox_info_per_center[anchor_index, 3] = anchor_widths[anchor_index]
            for y_center in y_centers:
                for x_center in x_centers:
                    bbox_info_per_center[:, 0] = y_center
                    bbox_info_per_center[:, 1] = x_center
                    bboxes_info.append(bbox_info_per_center.copy())

        bboxes_info = tf.concat(bboxes_info, axis=0)  # dims [num_of_proposals, 4]

        # replicating bbox_info along the batch_size dim
        # dims [batch_size, num_of_proposals, 4]
        bboxes_info = tf.tile(tf.expand_dims(bboxes_info, axis=0), [batch_size, 1, 1])
        bboxes_info = tf.cast(bboxes_info, dtype=tf.float32)

        # decoding formula:
        #   cx = dx + dw * bx                 NN_OUTPUT = (by, bx, bh, bw)
        #   cy = dy + dh * by                 DEFAULT_BOX = (dy, dx, dh, dw)
        #   cw = dw * exp(bw)                 PREDICTION = (cy, cx, ch, cw) where (cy, cx) is the center point of the bbox
        #   ch = dh * exp(bh)
        by, bx, bh, bw = tf.unstack(tf.transpose(a=reshaped_bbox_prediction))
        dy, dx, dh, dw = tf.unstack(tf.transpose(a=bboxes_info))

        # factoring dimensions
        by /= tf.cast(self.config.centers_scale_factor, dtype=tf.float32)
        bx /= tf.cast(self.config.centers_scale_factor, dtype=tf.float32)
        bh /= tf.cast(self.config.bbox_dimensions_scale_factor, dtype=tf.float32)
        bw /= tf.cast(self.config.bbox_dimensions_scale_factor, dtype=tf.float32)

        # decoding
        cy = dy + dh * by
        cx = dx + dw * bx
        cw = dw * tf.exp(bw)
        ch = dh * tf.exp(bh)

        # formatting: (center, size) -> (min corner, max corner)
        y_min = cy - ch / 2.0
        x_min = cx - cw / 2.0
        y_max = cy + ch / 2.0
        x_max = cx + cw / 2.0

        listed_boxes = [y_min, x_min, y_max, x_max]
        decoded_bboxes = tf.transpose(a=tf.stack(listed_boxes))
        decoded_bboxes = tf.expand_dims(decoded_bboxes, axis=2)

        return decoded_bboxes, detection_scores

    def _combined_non_max_suppression(self, decoded_bboxes, detection_scores):
        """
        Run TensorFlow's combined NMS with the configured thresholds and organize its output.

        For the IOU post-process type the score threshold is 0.0 instead of the configured one.

        Returns:
            float32 tensor produced by `_organize_nms_output` through `tf.numpy_function`.

        """
        if self.postprocess_type != PostprocessType.IOU:
            score_threshold = self.config.nms_scores_th
        else:
            score_threshold = 0.0

        # the fourth output (number of valid detections) is not needed
        boxes, scores, classes, _ = tf.image.combined_non_max_suppression(
            boxes=decoded_bboxes,
            scores=detection_scores,
            max_output_size_per_class=self.config.max_proposals_per_class,
            max_total_size=self.config.max_total_proposals,
            iou_threshold=self.config.nms_iou_th,
            score_threshold=score_threshold,
        )
        return tf.numpy_function(self._organize_nms_output, [boxes, scores, classes], np.float32)

    def concat_yolox_inputs(self, inputs):
        """
        Concatenate the yolox inputs of every branch as [regression, objectness, classes].

        Inputs are matched to the layer names of each bbox decoder by the last path component of the
        names listed in ``self.hn_element["input"]``.

        Returns:
            list with one concatenated tensor per bbox decoder.

        """
        input_names = [full_name.split("/")[-1] for full_name in self.hn_element["input"]]
        tensor_by_name = dict(zip(input_names, inputs))

        branch_tensors = []
        for decoder in self.config.bbox_decoders_info:
            # the order below is the channel order of the concatenated branch
            layer_keys = [
                decoder[field].split("/")[-1] for field in ("reg_layer", "objectness_layer", "cls_layer")
            ]
            ordered_tensors = [tensor_by_name[key] for key in layer_keys]
            branch_tensors.append(tf.keras.layers.Concatenate(axis=-1)(ordered_tensors))
        return branch_tensors

    def _organize_nms_output(self, nmsed_boxes, nmsed_scores, nmsed_classes):
        # this function aligns the shape of the nmsed result to as in the nn-core -> [,num_of_classes, 5, max_proposals_per_class]
        batch_size = nmsed_boxes.shape[0]
        # preparing output tensor
        # if background removal is applied number of classes is decreased by 1
        channels = self.config.classes - 1 if self.config.background_removal else self.config.classes

        results = np.zeros((batch_size, channels, 5, self.config.max_proposals_per_class))
        for img_index in range(batch_size):
            for class_index in range(channels):
                indices = np.where(nmsed_classes[img_index, :] == class_index)
                if len(indices[0]):
                    # bboxes were classified with the current class
                    concat_class_results = np.concatenate(
                        [
                            np.squeeze(nmsed_boxes[img_index][indices, :], axis=0),
                            np.transpose(nmsed_scores[img_index, indices]),
                        ],
                        axis=-1,
                    )
                    # clears rows of zeros - helps with independency between max_proposals_per_class and
                    # max_total_output_proposals
                    concat_class_results = concat_class_results[~np.all(concat_class_results == 0, axis=1)]
                    # stores class's bbox in the output tensor
                    results[
                        img_index,
                        class_index,
                        :,
                        : concat_class_results.shape[0],
                    ] = np.expand_dims(np.transpose(concat_class_results), axis=0)
        return results.astype(np.float32)

    def _organize_yolov5_seg_output(self, outputs):
        # outputs is a list of per-image detections, where each is a list of:
        # [boxes ([y_min, x_min, y_max x_max]), classes, scores, flattened masks] with sizes:
        # [(4, max_proposals), (1, max_proposals), (1, max_proposals) (image_dims[0] * image_dims[1], max_proposals)].
        organized_outputs = []
        batch_size = len(outputs)
        width = 6 + self.config.image_dims[0] * self.config.image_dims[1]  # 6 = 4 for bbox + 1 for score + 1 for class
        organized_outputs = np.zeros(
            (batch_size, 1, width, self.config.max_total_proposals),
            dtype=np.float32,
        )  # [batch_size, 1, 6, max_proposals]
        for output_index, output in enumerate(outputs):
            # transpose outputs
            output[0] = output[0].transpose([1, 0])
            output[1] = output[1].transpose([1, 2, 0])
            output[2] = output[2][None, :]
            output[3] = output[3][None, :]

            max_proposals = self._params[NMSProperties.MAX_TOTAL_OUTPUT_PROPOSALS.value]
            num_of_proposals = output[1].shape[-1]
            num_of_proposals = min(max_proposals, num_of_proposals)
            if num_of_proposals < max_proposals:
                # padds with zeros
                for i in range(len(output)):
                    pad_width = [(0, 0)] * len(output[i].shape)
                    pad_width[-1] = (0, max_proposals - num_of_proposals)
                    output[i] = np.pad(output[i], pad_width)
            else:
                for i in range(len(output)):
                    output[i] = output[i][..., :num_of_proposals]

            flattened_mask = output[1].reshape(-1, max_proposals)
            bboxes, classes, scores = output[0], output[2], output[3]

            # normalizes bbox coordinates to be between 0-1
            bboxes[0::2, :] /= self.config.image_dims[0]
            bboxes[1::2, :] /= self.config.image_dims[1]
            flattened_output = np.concatenate([bboxes, scores, classes, flattened_mask], axis=0)
            organized_outputs[output_index] = flattened_output[None, :]

        # organized_outputs format is [N, 1, num_max_proposals, 6 + image_dims[0] * image_dims[1]]
        # where the last dimension is [y_min, x_min, y_max x_max, score, class, flattened masks]
        return organized_outputs

    def logits_call(self, inputs):
        """
        Apply the logits post-process (softmax or argmax) configured in the layer params.

        A falsy axis param (``None`` or ``0``) selects the default axis: -1 for softmax and 0 for argmax.

        Returns:
            the softmax of `inputs`, or the argmax indices as float32 with a trailing axis of size 1.

        Raises:
            AccelerasException: when the logits type is neither softmax nor argmax (the layer used to
                return ``None`` silently in that case).

        """
        logits_type = self._params[LogitsPostprocessLayerParams.TYPE.value]

        if logits_type == LayerType.SOFTMAX.value:
            axis = self._params[LogitsPostprocessLayerParams.AXIS.value] or -1
            return tf.keras.activations.softmax(inputs, axis=axis)
        if logits_type == LayerType.ARGMAX.value:
            axis = self._params[LogitsPostprocessLayerParams.AXIS.value] or 0
            return tf.cast(tf.expand_dims(tf.math.argmax(inputs, axis=axis), axis=-1), tf.float32)

        raise AccelerasException(f"The logits type {logits_type} is not supported as postprocess layer")

    def resize_call(self, inputs):
        """
        Resize `inputs` to the configured shape with the configured method and pixels mode.

        Raises:
            AccelerasException: when the resize method is neither bilinear nor nearest neighbor.

        """
        params = self._params
        resize_method = params[ResizePostprocessLayerParams.RESIZE_METHOD.value]
        pixels_mode = params[ResizePostprocessLayerParams.PIXELS_MODE.value]
        align_corners = pixels_mode == ResizeBilinearPixelsMode.ALIGN_CORNERS.value
        half_pixels = pixels_mode == ResizeBilinearPixelsMode.HALF_PIXELS.value

        if resize_method == ResizeMethod.BILINEAR.value:
            resize_function = tf.compat.v1.image.resize_bilinear
        elif resize_method == ResizeMethod.NEAREST_NEIGHBOR.value:
            resize_function = tf.compat.v1.image.resize_nearest_neighbor
        else:
            raise AccelerasException(f"The resize method {resize_method} is not supported as postprocess layer")

        target_size = params[ResizePostprocessLayerParams.RESIZE_SHAPE.value]
        return resize_function(
            images=inputs,
            size=target_size,
            align_corners=align_corners,
            half_pixel_centers=half_pixels,
        )

    def update_scale_scalar_dof(self, shift):
        """No-op for this layer; `shift` is ignored."""
        pass

    def is_jit_compile_supported(self, training=False):
        """Return ``False`` only for the NMS post-process type; `training` does not affect the answer."""
        unsupported_types = {PostprocessType.NMS}
        return self.postprocess_type not in unsupported_types


def process_mask(protos, masks_in, bboxes, shape, upsample=True, mask_threshold=0.5):
    """
    Combine mask prototypes with per-detection coefficients into binary instance masks.

    Args:
        protos: numpy array of mask prototypes with shape [mh, mw, c]
        masks_in: numpy array of per-detection mask coefficients with shape [n, c]
        bboxes: numpy array of bbox coords with shape [n, 4]; only used for cropping when `upsample` is set
        shape: target spatial size passed to `tf.image.resize` when `upsample` is set
        upsample: when True the masks are resized to `shape` and cropped to their bboxes
        mask_threshold: values strictly above this threshold are marked True in the returned masks

    Returns:
        Boolean numpy array of masks ([n, mh, mw] without upsampling, [n, *shape] with it).
        When `upsample` is set and there are no masks, an empty (non-thresholded) array of shape
        [0, *shape] is returned instead.

    """
    proto_h, proto_w, n_channels = protos.shape
    # [c, mh * mw] so every coefficient row yields one flattened mask
    flat_protos = protos.reshape((-1, n_channels)).transpose((1, 0))
    masks = _sigmoid(np.matmul(masks_in, flat_protos)).reshape((-1, proto_h, proto_w))
    boxes_for_crop = bboxes.copy()

    if not upsample:
        return masks > mask_threshold

    if not masks.shape[0]:
        return np.resize(masks, [0, *shape])

    # tf.image.resize works on channels-last images, so go through HWC and back
    resized = tf.image.resize(np.transpose(masks, axes=(1, 2, 0)), shape).numpy()
    if len(resized.shape) == 2:
        resized = resized[..., np.newaxis]
    chw_masks = np.transpose(resized, axes=(2, 0, 1))
    cropped = crop_mask(chw_masks, boxes_for_crop)
    return cropped > mask_threshold


def crop_mask(masks, boxes):
    """
    Keep the mask values that fall inside the predicted bbox and zero out everything else.

    Args:
        masks: numpy array of masks with shape [n, h, w]
        boxes: numpy array of bbox coords with shape [n, 4], ordered as (y1, x1, y2, x2)

    Returns:
        numpy array with shape [n, h, w]; a pixel survives when x1 <= x < x2 and y1 <= y < y2.

    """
    _, height, width = masks.shape
    # Each coordinate ends up with shape [n, 1, 1] so it broadcasts against the index grids below.
    y_min, x_min, y_max, x_max = np.array_split(boxes[:, :, None], 4, axis=1)
    x_index = np.arange(width)[None, None, :]
    y_index = np.arange(height)[None, :, None]
    inside_x = (x_index >= x_min) & (x_index < x_max)
    inside_y = (y_index >= y_min) & (y_index < y_max)
    return masks * (inside_x & inside_y)


def _yolov5_decoding(stride, anchors, classes, output):
    """
    Decode a raw YOLOv5 head output into per-proposal box, confidence and mask values.

    Args:
        stride: scalar applied to the decoded centers (and, through the anchor grid, to the anchors)
        anchors: flat sequence of anchor values, two per anchor
        classes: number of classes
        output: numpy array with shape [batch, H, W, num_anchors * features]

    Returns:
        float32 numpy array with shape [batch, num_anchors * H * W, features], laid out as
        (xy, wh, sigmoid of objectness + class scores, remaining untouched values).

    """
    batch, height, width = output.shape[0:3]
    n_anchors = len(anchors) // 2
    grid, anchor_grid = _make_grid(anchors, stride, batch, width, height)

    # [B, H, W, A * F] -> [B, A, F, H, W] -> [B, A, H, W, F]
    per_anchor = output.transpose((0, 3, 1, 2)).reshape((batch, n_anchors, -1, height, width))
    per_anchor = per_anchor.transpose((0, 1, 3, 4, 2))

    conf_end = 4 + classes + 1  # objectness plus one score per class
    raw_xy = per_anchor[..., :2]
    raw_wh = per_anchor[..., 2:4]
    raw_conf = per_anchor[..., 4:conf_end]
    mask_part = per_anchor[..., conf_end:]

    # decoding
    centers = (_sigmoid(raw_xy) * 2 + grid) * stride
    sizes = (_sigmoid(raw_wh) * 2) ** 2 * anchor_grid
    decoded = np.concatenate((centers, sizes, _sigmoid(raw_conf), mask_part), 4)
    return decoded.reshape((batch, n_anchors * height * width, -1)).astype(np.float32)


def _sigmoid(x):
    """Return the element-wise logistic sigmoid 1 / (1 + exp(-x)) of `x`."""
    exp_of_negated = np.exp(-x)
    return 1 / (1 + exp_of_negated)


def _make_grid(anchors, stride, bs=8, nx=20, ny=20):
    """
    Build the cell-offset grid and the scaled-anchor grid used when decoding boxes.

    Args:
        anchors: flat sequence of anchor values, two per anchor; a numpy array or a plain list/tuple
        stride: scalar every anchor value is multiplied by
        bs: batch size
        nx: number of grid cells along x
        ny: number of grid cells along y

    Returns:
        A tuple (grid, anchor_grid), both with shape [bs, na, ny, nx, 2] where na = len(anchors) // 2.
        grid[b, a, y, x] holds (x - 0.5, y - 0.5); anchor_grid[b, a, y, x] holds the a-th anchor
        pair multiplied by `stride`.

    """
    na = len(anchors) // 2
    # np.asarray makes the product an element-wise scaling even for a plain list, where
    # `anchors * stride` would repeat the list (int stride) or raise (float stride).
    scaled_anchors = np.reshape(np.asarray(anchors) * stride, (na, -1))

    yv, xv = np.meshgrid(np.arange(ny), np.arange(nx), indexing="ij")
    cell_offsets = np.stack((xv, yv), 2) - 0.5  # [ny, nx, 2]

    # np.tile prepends the missing leading axes, replicating over batch and anchors.
    grid = np.tile(cell_offsets, (bs, na, 1, 1, 1))
    anchor_grid = np.tile(scaled_anchors[None, :, None, None, :], (bs, 1, ny, nx, 1))

    return grid, anchor_grid


def non_max_suppression_yolo_seg(
    prediction,
    conf_thres=0.25,
    iou_thres=0.45,
    max_det=300,
    n_masks=32,
    multi_label=True,
):
    """
    Run per-class Non-Maximum Suppression (NMS) on segmentation proposals of a batch of images.

    Args:
        prediction: numpy.ndarray with shape (batch_size, num_proposals, 5 + n_classes + n_masks),
            each proposal laid out as (center_x, center_y, width, height, objectness, class scores, masks)
        conf_thres: confidence threshold, applied to the objectness and to the final class confidence
        iou_thres: IoU threshold for NMS
        max_det: Maximal number of detections to keep after NMS
        n_masks: Number of mask values per proposal
        multi_label: Keep every class passing conf_thres per proposal (True) or only the best class (False);
            only takes effect when there is more than one class
    Returns:
         A list of per image detections, where each is a list with the following structure:
         [np.ndarray of shape (num_detections, 4),
          np.ndarray of shape (num_detections, n_masks),
          np.ndarray of shape (num_detections, ),
          np.ndarray of shape (num_detections, )]
         Corresponding to detection_boxes, mask, detection_classes, detection_scores.
         An image with no proposal passing the objectness threshold gets a single all-zero row.
    """
    max_wh = 7680  # (pixels) maximum box width and height
    n_classes = prediction.shape[2] - n_masks - 5
    mask_start = 5 + n_classes
    passes_objectness = prediction[..., 4] > conf_thres
    use_multi_label = multi_label and n_classes > 1

    results = []
    for image_pred, image_keep in zip(prediction, passes_objectness):
        proposals = image_pred[image_keep]  # boolean indexing copies, the input stays untouched

        if not proposals.shape[0]:
            # zero detections
            results.append(
                [
                    np.zeros((1, 4), dtype=np.float32),
                    np.zeros((1, n_masks), dtype=np.float32),
                    np.zeros((1,), dtype=np.float32),
                    np.zeros((1,), dtype=np.float32),
                ]
            )
            continue

        proposals[:, 5:] *= proposals[:, 4:5]  # Confidence = Objectness x Class Score
        corners = xywh_to_yxyx(proposals[:, :4])  # (center_x, center_y, width, height) to (y1, x1, y2, x2)
        class_scores = proposals[:, 5:mask_start]
        mask_coeffs = proposals[:, mask_start:]

        if use_multi_label:
            # one row per (proposal, class) pair that passes the threshold
            rows, class_ids = (class_scores > conf_thres).nonzero()
            columns = (
                corners[rows],
                class_scores[rows, class_ids, None],
                class_ids[:, None].astype(np.float32),
                mask_coeffs[rows],
            )
            detections = np.concatenate(columns, 1)
        else:
            # best class only
            best_conf = class_scores.max(1)[:, None]
            best_class = class_scores.argmax(1)[:, None].astype(np.float32)
            confident = best_conf[:, 0] > conf_thres
            detections = np.concatenate((corners, best_conf, best_class, mask_coeffs), 1)[confident]

        by_confidence = detections[:, 4].argsort()[::-1]
        detections = detections[by_confidence]

        # per-class NMS: shifting boxes by class id * max_wh keeps different classes from overlapping
        shifted_boxes = detections[:, :4] + detections[:, 5:6] * max_wh
        nms_input = np.concatenate(
            [shifted_boxes.astype(np.float32), detections[:, 4:5].astype(np.float32)], axis=1
        )
        survivors = nms_indices(nms_input, iou_thres)
        if survivors.shape[0] > max_det:
            survivors = survivors[:max_det]
        detections = detections[survivors]

        # boxes, masks, classes, scores
        results.append([detections[:, :4], detections[:, 6:], detections[:, 5], detections[:, 4]])
    return results


def nms_indices(dets, thresh):
    """
    Greedy Non-Maximum Suppression over scored boxes.

    Boxes are visited from the highest score down; every box that was not suppressed yet suppresses
    all lower-scored boxes whose IoU with it is at least `thresh`.

    Args:
        dets: numpy array with shape [n, 5], each row being (y1, x1, y2, x2, score)
        thresh: IoU threshold at or above which the lower-scored box is suppressed

    Returns:
        1-D numpy array with the indices (into `dets`, in ascending order) of the boxes that were kept.

    """
    y1 = dets[:, 0]
    x1 = dets[:, 1]
    y2 = dets[:, 2]
    x2 = dets[:, 3]
    scores = dets[:, 4]
    # Widths and heights are measured as `max - min + 1`, for the areas and the intersections alike.
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]
    suppressed = np.zeros(dets.shape[0], dtype=bool)

    for rank, i in enumerate(order):
        if suppressed[i]:
            continue
        # Only lower-ranked boxes can be suppressed by box i; handle all of them in one vectorized step.
        rest = order[rank + 1 :]
        if not rest.size:
            break
        inter_w = np.maximum(0.0, np.minimum(x2[i], x2[rest]) - np.maximum(x1[i], x1[rest]) + 1)
        inter_h = np.maximum(0.0, np.minimum(y2[i], y2[rest]) - np.maximum(y1[i], y1[rest]) + 1)
        inter = inter_w * inter_h
        iou = inter / (areas[i] + areas[rest] - inter)
        # Re-marking an already suppressed box is harmless, so no extra filtering is needed.
        suppressed[rest[iou >= thresh]] = True

    return np.where(~suppressed)[0]


def xywh_to_yxyx(x):
    """
    Convert boxes from (center_x, center_y, width, height) to (ymin, xmin, ymax, xmax).

    Args:
        x: numpy array with shape [n, 4 or more]; only the first four columns are converted

    Returns:
        A copy of `x` (same dtype and shape) whose first four columns hold the converted corners.

    """
    center_x, center_y = x[:, 0], x[:, 1]
    half_width = x[:, 2] / 2
    half_height = x[:, 3] / 2

    corners = np.copy(x)
    corners[:, 0] = center_y - half_height  # ymin
    corners[:, 1] = center_x - half_width  # xmin
    corners[:, 2] = center_y + half_height  # ymax
    corners[:, 3] = center_x + half_width  # xmax
    return corners


# Lookup from an NMS-on-CPU meta architecture to the `HailoPostprocess` attribute associated with it
# (presumably the bbox decoding method for that architecture - confirm against the class definition).
# It is defined at module level, after the class, because its values reference `HailoPostprocess` itself.
META_ARCH_TO_BBOX_DECODER_FUNC = {
    NMSOnCpuMetaArchitectures.YOLOV5: HailoPostprocess.yolov5_bbox_decoding,
    NMSOnCpuMetaArchitectures.YOLOX: HailoPostprocess.yolox_bbox_decoding,
}
