#!/usr/bin/env python
import copy
import os
from collections import OrderedDict

import jsonref
import jsonschema
import numpy as np

import hailo_sdk_client
from hailo_model_optimization.acceleras.hailo_layers.hailo_postprocess import PostProcessConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    BBoxDecodersInfo,
    NMSProperties,
    PaddingType,
    PostprocessTarget,
    PostprocessType,
    ProtoInfo,
)
from hailo_sdk_client.tools.fuser.fuser_helper import FuserHelper
from hailo_sdk_common.hailo_nn.hn_definitions import (
    ActivationType,
    FeatureMultiplierType,
    LayerType,
    NMSMetaArchitectures,
)
from hailo_sdk_common.hailo_nn.hn_layers import (
    BboxDecoderLayer,
    ConcatLayer,
    FeatureMultiplierLayer,
    FusedBboxDecoderLayer,
    FusedConv2DLayer,
    FusedSliceLayer,
    FusedStandaloneActivationLayer,
    NMSLayer,
    OutputLayer,
    PostprocessLayer,
    ProposalGeneratorLayer,
    SoftmaxLayer,
)
from hailo_sdk_common.hailo_nn.nms_postprocess_defaults import (
    DEFAULT_BACKGROUND_REMOVAL,
    DEFAULT_BACKGROUND_REMOVAL_INDEX,
    DEFAULT_BBOX_DIMENSIONS_SCALE_FACTOR,
    DEFAULT_CENTERS_SCALE_FACTOR,
    DEFAULT_IMAGE_DIMS,
    DEFAULT_INPUT_DIVISION_FACTOR,
    DEFAULT_IOU_TH,
    DEFAULT_MAX_PROPOSALS_PER_CLASS,
    DEFAULT_MAX_TOTAL_PROPOSALS,
    DEFAULT_NMS_OUTPUT_ORIGINAL_NAME,
    DEFAULT_REGRESSION_PREDICT_ORDER,
    DEFAULT_SCORES_TH,
    DEFAULT_SSD_CLASSES,
    DEFAULT_YOLO_REG_LENGTH,
    DEFAULT_YOLO_SEG_MASK_THRESH,
)
from hailo_sdk_common.logger.logger import default_logger
from hailo_sdk_common.model_params.model_params import InnerParams, ModelParams
from hailo_sdk_common.numeric_utils.numeric_utils import (
    get_bbox_centers_for_centernet,
    get_bbox_centers_for_ssd,
    get_bbox_centers_for_yolo,
)

# Module-level logger shared by every class in this file.
logger = default_logger()

# Batch-size limits used by NMSPostProcess._create_concats: layers are grouped
# MAX_CONCAT_LAYERS at a time, and when the layer count leaves a remainder of
# one the last MIN_CONCAT_LAYERS layers are split off into their own batch.
MIN_CONCAT_LAYERS = 3
MAX_CONCAT_LAYERS = 5

# Name template for concat layers; formatted with the name of the first
# concatenated layer (or of the preceding layer).
CONCAT_LAYER_NAME_FORMAT = "{}_concat"
NEW_LAYER_NAME_FORMAT = "{}_{}{}"

# NOTE(review): presumably paired index-wise with
# MULTIPLE_ACTIVATION_FEATURES_SLICES (one activation per feature slice) - confirm.
MULTIPLE_ACTIVATION_OPS = [ActivationType.linear, ActivationType.exp]
# Feature slices selected by index in NMSPostProcess._adjust_weights.
MULTIPLE_ACTIVATION_FEATURES_SLICES = [slice(None, 2, None), slice(2, None, None)]
MULTIPLE_ACTIVATION_FEATURES_RESHAPE_FACTOR = 4
YOLOV5_NN_CORE_SUPPORTED_BRANCHES = 3


class SSDPostprocessException(Exception):
    """Raised when the NMS post-process cannot be built.

    Used, for example, when a bbox decoder layer is missing from the HN or
    when no bbox decoder parameters are available.
    """


class YOLOv5PostprocessException(Exception):
    """Exception type for YOLOv5 post-process errors."""


class YOLOv6PostprocessException(Exception):
    """Raised for YOLOv6 post-process errors, e.g. a decoder layer missing from the HN."""


class NMSConfigPostprocessException(Exception):
    """Exception type for NMS configuration errors."""


class UnsupportedMetaArchError(Exception):
    """Exception type for meta-architectures that are not supported."""


class BBoxDecoderInputLayerNotFoundException(SSDPostprocessException):
    """Specialisation of SSDPostprocessException for a missing bbox decoder input layer."""


class NMSMetaData:
    """Read-only bundle of the NMS configuration and the context it was created in."""

    def __init__(self, nms_config, meta_arch, engine, config_file):
        """Store the four values as-is; they are exposed through read-only properties."""
        self._nms_config, self._meta_arch = nms_config, meta_arch
        self._engine, self._config_file = engine, config_file

    @property
    def nms_config(self):
        """The NMS configuration object given at construction."""
        return self._nms_config

    @property
    def meta_arch(self):
        """The meta-architecture given at construction."""
        return self._meta_arch

    @property
    def engine(self):
        """The post-process engine given at construction."""
        return self._engine

    @property
    def config_file(self):
        """The configuration file value given at construction."""
        return self._config_file

    def to_pb(self, pb_wrapper):
        """Serialise by delegating to ``nms_config.to_pb(pb_wrapper)``."""
        config = self.nms_config
        return config.to_pb(pb_wrapper)


class NMSConfig:
    """NMS post-process configuration.

    Instances start with the module defaults and are normally built with
    :meth:`from_json`, which validates a configuration dictionary against a
    JSON schema and copies its values onto the instance.
    """

    config: dict

    def __init__(self):
        # Raw configuration dictionary; replaced by ``from_json``. Initialised
        # here so the attribute always exists on a fresh instance.
        self.config = {}
        # Backing field of the ``name`` property; initialised so that reading
        # ``name`` before it was set returns None instead of raising
        # AttributeError.
        self._name = None
        self._meta_arch = None
        self._bbox_decoders_info = []
        self._anchors = OrderedDict()
        self._anchors_stride = OrderedDict()
        self._nms_scores_th = DEFAULT_SCORES_TH
        self._nms_iou_th = DEFAULT_IOU_TH
        self._max_proposals_per_class = DEFAULT_MAX_PROPOSALS_PER_CLASS
        self._centers_scale_factor = DEFAULT_CENTERS_SCALE_FACTOR
        self._bbox_dimensions_scale_factor = DEFAULT_BBOX_DIMENSIONS_SCALE_FACTOR
        self._classes = DEFAULT_SSD_CLASSES
        self._background_removal = DEFAULT_BACKGROUND_REMOVAL
        self._background_removal_index = DEFAULT_BACKGROUND_REMOVAL_INDEX
        self._input_division_factor = DEFAULT_INPUT_DIVISION_FACTOR
        self._regression_prediction_order = DEFAULT_REGRESSION_PREDICT_ORDER
        self._image_dims = DEFAULT_IMAGE_DIMS
        self._schema_filename = ""
        self._cls_activation = ActivationType.sigmoid
        self._proto_info = OrderedDict()
        self._mask_threshold = DEFAULT_YOLO_SEG_MASK_THRESH
        self._regression_length = DEFAULT_YOLO_REG_LENGTH
        self._max_total_proposals = DEFAULT_MAX_TOTAL_PROPOSALS
        self._dfl_on_nn_core = False

    @classmethod
    def from_json(
        cls,
        config: dict,
        meta_arch: NMSMetaArchitectures = NMSMetaArchitectures.SSD,
        nms_iou_th=DEFAULT_IOU_TH,
        schema_filename="",
    ):
        """Build a configuration from an already-parsed config dictionary.

        Args:
            config: Configuration dictionary. It must contain ``"bbox_decoders"``;
                every other key falls back to the module default. The
                bbox-decoder dictionaries inside it are modified in place
                (a decoder name entry is added to each).
            meta_arch: Meta-architecture; selects how each bbox decoder is
                named and, when ``schema_filename`` is empty, the schema file.
            nms_iou_th: IoU threshold used when ``config`` has no ``"nms_iou_th"``.
            schema_filename: Schema file name; when empty it defaults to
                ``./nms_<meta_arch.value>_config.schema.json``.

        Returns:
            The populated ``NMSConfig`` instance.

        Raises:
            KeyError: If ``"bbox_decoders"`` or a layer key required by the
                meta-architecture is missing.
            Exception: Whatever ``jsonschema.validate`` raises when ``config``
                does not match the schema.
        """
        nms_config = cls()
        nms_config.config = config
        nms_config.schema_filename = (
            schema_filename if schema_filename else f"./nms_{meta_arch.value}_config.schema.json"
        )
        # Validate before reading any value out of the dictionary.
        nms_config.validate_config_schema(config)

        nms_config.meta_arch = meta_arch
        nms_config.nms_scores_th = config.get("nms_scores_th", DEFAULT_SCORES_TH)
        nms_config.nms_iou_th = config.get("nms_iou_th", nms_iou_th)
        nms_config.max_proposals_per_class = config.get("max_proposals_per_class", DEFAULT_MAX_PROPOSALS_PER_CLASS)
        nms_config.centers_scale_factor = config.get("centers_scale_factor", DEFAULT_CENTERS_SCALE_FACTOR)
        nms_config.classes = config.get("classes", DEFAULT_SSD_CLASSES)
        nms_config.background_removal = config.get("background_removal", DEFAULT_BACKGROUND_REMOVAL)
        nms_config.background_removal_index = config.get("background_removal_index", DEFAULT_BACKGROUND_REMOVAL_INDEX)
        nms_config.bbox_decoders_info = config["bbox_decoders"]
        nms_config.input_division_factor = config.get("input_division_factor", DEFAULT_INPUT_DIVISION_FACTOR)
        nms_config.regression_prediction_order = config.get(
            "regression_prediction_order",
            DEFAULT_REGRESSION_PREDICT_ORDER,
        )
        nms_config.image_dims = config.get("image_dims", DEFAULT_IMAGE_DIMS)
        # Only the first entry of the "proto" list is used.
        nms_config.proto_info = dict(config["proto"][0]) if "proto" in config else {}
        nms_config.bbox_dimensions_scale_factor = config.get(
            "bbox_dimensions_scale_factor",
            DEFAULT_BBOX_DIMENSIONS_SCALE_FACTOR,
        )
        nms_config.mask_threshold = config.get("mask_threshold", DEFAULT_YOLO_SEG_MASK_THRESH)
        nms_config.regression_length = config.get("regression_length", DEFAULT_YOLO_REG_LENGTH)
        nms_config.max_total_proposals = config.get("max_total_proposals", DEFAULT_MAX_TOTAL_PROPOSALS)

        name_key = BBoxDecodersInfo.NAME.value
        for bbox_decoder in nms_config.bbox_decoders_info:
            # Derive the decoder name from the layer name(s) it consumes. For
            # any other meta-architecture the name must already be present in
            # the decoder dictionary.
            if meta_arch in [
                NMSMetaArchitectures.SSD,
                NMSMetaArchitectures.YOLOX,
                NMSMetaArchitectures.YOLOV8,
                NMSMetaArchitectures.DAMOYOLO,
            ]:
                key = (
                    BBoxDecodersInfo.REG_LAYER.value
                    if BBoxDecodersInfo.REG_LAYER.value in bbox_decoder
                    else BBoxDecodersInfo.COMBINED_LAYER.value
                )
                bbox_decoder[name_key] = bbox_decoder[key].replace("conv", "bbox_decoder")
            elif meta_arch == NMSMetaArchitectures.CENTERNET:
                reg_layer_h = bbox_decoder[BBoxDecodersInfo.REG_LAYER_H.value]
                reg_layer_w = bbox_decoder[BBoxDecodersInfo.REG_LAYER_W.value]
                bbox_decoder_name = f'bbox_decoder_{reg_layer_h.split("/")[-1]}{reg_layer_w.split("/")[-1]}'
                bbox_decoder[name_key] = bbox_decoder_name
            elif meta_arch == NMSMetaArchitectures.YOLOV5:
                bbox_decoder_name = bbox_decoder[BBoxDecodersInfo.ENCODED_LAYER.value].replace("conv", "bbox_decoder")
                bbox_decoder[name_key] = bbox_decoder_name

            # Anchors are (h, w) pairs; decoders without anchor heights get a
            # single (1, 1) anchor.
            nms_config.anchors[bbox_decoder[name_key]] = (
                list(zip(bbox_decoder[BBoxDecodersInfo.H.value], bbox_decoder[BBoxDecodersInfo.W.value]))
                if BBoxDecodersInfo.H.value in bbox_decoder
                else [(1, 1)]
            )
            if BBoxDecodersInfo.STRIDE.value in bbox_decoder:
                nms_config.anchors_stride[bbox_decoder[name_key]] = bbox_decoder[BBoxDecodersInfo.STRIDE.value]

        return nms_config

    def to_post_config(self) -> PostProcessConfig:
        """
        Converts the existing configuration into a structured Config object.

        Returns:
            Config: The structured configuration object populated with the current settings.
        """
        return PostProcessConfig(
            background_removal=self.background_removal,
            background_removal_index=self.background_removal_index,
            bbox_decoders_info=self.bbox_decoders_info,
            bbox_dimensions_scale_factor=self.bbox_dimensions_scale_factor,
            centers_scale_factor=self.centers_scale_factor,
            classes=self.classes,
            image_dims=self.image_dims,
            mask_threshold=self.mask_threshold,
            max_proposals_per_class=self.max_proposals_per_class,
            max_total_proposals=self.max_total_proposals,
            nms_iou_th=self.nms_iou_th,
            nms_scores_th=self.nms_scores_th,
            proto_info=self.proto_info,
            regression_length=self.regression_length,
            dfl_on_nn_core=self.dfl_on_nn_core,
        )

    def validate_config_schema(self, config_json):
        """Validate ``config_json`` against the schema named by ``schema_filename``.

        The schema is looked up under the ``core_postprocess`` directory of the
        ``hailo_sdk_client.tools`` package.

        Raises:
            OSError: If the schema file cannot be opened.
            Exception: Whatever ``jsonschema.validate`` raises on a mismatch.
        """
        work_dir = "core_postprocess"
        schema_full_path = os.path.join(hailo_sdk_client.tools.__path__[0], work_dir, self.schema_filename)
        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(schema_full_path, encoding="utf-8") as schema:
            config_schema = jsonref.load(schema, base_uri=f"file:{schema_full_path}", jsonschema=True)
            jsonschema.validate(config_json, config_schema)

    def to_pb(self, pb_wrapper):
        """Serialise this configuration into a ``ProtoNMSMetaData`` message.

        Args:
            pb_wrapper: Object exposing ``integrated_hw_graph_base_pb2`` and the
                ``META_ARCH_TYPE_TO_PB`` / ``ACTIVATION_TYPE_TO_PB`` lookup tables.

        Returns:
            The populated protobuf message.
        """
        nms_metadata_message = pb_wrapper.integrated_hw_graph_base_pb2.ProtoNMSMetaData()
        nms_metadata_message.meta_arch = pb_wrapper.META_ARCH_TYPE_TO_PB[self.meta_arch]
        nms_metadata_message.image_dims.h = self.image_dims[0]
        nms_metadata_message.image_dims.w = self.image_dims[1]

        # copies bbox decoders info; keys missing from a decoder are written
        # as an empty string / empty list / 0
        for bbox_decoder_info in self.bbox_decoders_info:
            bbox_decoder_info_message = pb_wrapper.integrated_hw_graph_base_pb2.ProtoBBoxDecoderInfo()
            bbox_decoder_info_message.cls_layer = bbox_decoder_info.get(BBoxDecodersInfo.CLS_LAYER.value, "")
            bbox_decoder_info_message.reg_layer_h = bbox_decoder_info.get(BBoxDecodersInfo.REG_LAYER_H.value, "")
            bbox_decoder_info_message.reg_layer_w = bbox_decoder_info.get(BBoxDecodersInfo.REG_LAYER_W.value, "")
            bbox_decoder_info_message.h.extend(bbox_decoder_info.get(BBoxDecodersInfo.H.value, []))
            bbox_decoder_info_message.w.extend(bbox_decoder_info.get(BBoxDecodersInfo.W.value, []))
            bbox_decoder_info_message.stride = bbox_decoder_info.get(BBoxDecodersInfo.STRIDE.value, 0)
            bbox_decoder_info_message.reg_layer = bbox_decoder_info.get(BBoxDecodersInfo.REG_LAYER.value, "")
            bbox_decoder_info_message.encoded_layer = bbox_decoder_info.get(BBoxDecodersInfo.ENCODED_LAYER.value, "")
            bbox_decoder_info_message.objectness_layer = bbox_decoder_info.get(BBoxDecodersInfo.OBJ_LAYER.value, "")
            nms_metadata_message.bbox_decoders_info.append(bbox_decoder_info_message)

        nms_metadata_message.nms_scores_th = self.nms_scores_th
        nms_metadata_message.centers_scale_factor = self.centers_scale_factor
        nms_metadata_message.bbox_dimensions_scale_factor = self.bbox_dimensions_scale_factor
        nms_metadata_message.background_removal = self.background_removal
        nms_metadata_message.background_removal_index = self.background_removal_index
        nms_metadata_message.input_division_factor = self.input_division_factor
        nms_metadata_message.schema_filename = self.schema_filename
        nms_metadata_message.cls_activation = pb_wrapper.ACTIVATION_TYPE_TO_PB[self.cls_activation]
        nms_metadata_message.proto_info.number = self.proto_info.get(ProtoInfo.NUMBER.value, 0)
        nms_metadata_message.proto_info.stride = self.proto_info.get(ProtoInfo.STRIDE.value, 0)
        nms_metadata_message.proto_info.proto_layer = self.proto_info.get(ProtoInfo.PROTO_LAYER.value, "")
        nms_metadata_message.mask_threshold = self.mask_threshold
        nms_metadata_message.regression_length = self.regression_length
        nms_metadata_message.max_total_proposals = self.max_total_proposals
        nms_metadata_message.dfl_on_nn_core = self.dfl_on_nn_core

        return nms_metadata_message

    @property
    def name(self):
        """Optional name of this configuration (None until set)."""
        return self._name

    @name.setter
    def name(self, name):
        self._name = name

    @property
    def bbox_decoders_info(self):
        """List of bbox-decoder dictionaries (the ``bbox_decoders`` config key)."""
        return self._bbox_decoders_info

    @bbox_decoders_info.setter
    def bbox_decoders_info(self, bbox_decoders_info):
        self._bbox_decoders_info = bbox_decoders_info

    @property
    def anchors(self):
        """Mapping of bbox decoder name to its list of (h, w) anchor pairs."""
        return self._anchors

    @anchors.setter
    def anchors(self, anchors):
        self._anchors = anchors

    @property
    def anchors_stride(self):
        """Mapping of bbox decoder name to its stride, for decoders that define one."""
        return self._anchors_stride

    @anchors_stride.setter
    def anchors_stride(self, anchors_stride):
        self._anchors_stride = anchors_stride

    @property
    def nms_scores_th(self):
        """NMS score threshold (the ``nms_scores_th`` config key)."""
        return self._nms_scores_th

    @nms_scores_th.setter
    def nms_scores_th(self, nms_scores_th):
        self._nms_scores_th = nms_scores_th

    @property
    def nms_iou_th(self):
        """NMS IoU threshold (the ``nms_iou_th`` config key)."""
        return self._nms_iou_th

    @nms_iou_th.setter
    def nms_iou_th(self, nms_iou_th):
        self._nms_iou_th = nms_iou_th

    @property
    def max_proposals_per_class(self):
        """The ``max_proposals_per_class`` config value."""
        return self._max_proposals_per_class

    @max_proposals_per_class.setter
    def max_proposals_per_class(self, max_proposals_per_class):
        self._max_proposals_per_class = max_proposals_per_class

    @property
    def centers_scale_factor(self):
        """The ``centers_scale_factor`` config value."""
        return self._centers_scale_factor

    @centers_scale_factor.setter
    def centers_scale_factor(self, centers_scale_factor):
        self._centers_scale_factor = centers_scale_factor

    @property
    def bbox_dimensions_scale_factor(self):
        """The ``bbox_dimensions_scale_factor`` config value."""
        return self._bbox_dimensions_scale_factor

    @bbox_dimensions_scale_factor.setter
    def bbox_dimensions_scale_factor(self, bbox_dimensions_scale_factor):
        self._bbox_dimensions_scale_factor = bbox_dimensions_scale_factor

    @property
    def classes(self):
        """The ``classes`` config value."""
        return self._classes

    @classes.setter
    def classes(self, classes):
        self._classes = classes

    @property
    def background_removal(self):
        """The ``background_removal`` config value."""
        return self._background_removal

    @background_removal.setter
    def background_removal(self, background_removal):
        self._background_removal = background_removal

    @property
    def background_removal_index(self):
        """The ``background_removal_index`` config value."""
        return self._background_removal_index

    @background_removal_index.setter
    def background_removal_index(self, background_removal_index):
        self._background_removal_index = background_removal_index

    @property
    def input_division_factor(self):
        """The ``input_division_factor`` config value."""
        return self._input_division_factor

    @input_division_factor.setter
    def input_division_factor(self, input_division_factor):
        self._input_division_factor = input_division_factor

    @property
    def regression_prediction_order(self):
        """The ``regression_prediction_order`` config value."""
        return self._regression_prediction_order

    @regression_prediction_order.setter
    def regression_prediction_order(self, regression_prediction_order):
        self._regression_prediction_order = regression_prediction_order

    @property
    def meta_arch(self):
        """Meta-architecture this configuration was built for (None until set)."""
        return self._meta_arch

    @meta_arch.setter
    def meta_arch(self, meta_arch):
        self._meta_arch = meta_arch

    @property
    def image_dims(self):
        """Image dimensions; index 0 is serialised as ``h`` and index 1 as ``w``."""
        return self._image_dims

    @image_dims.setter
    def image_dims(self, image_dims):
        self._image_dims = image_dims

    @property
    def schema_filename(self):
        """File name of the JSON schema used by ``validate_config_schema``."""
        return self._schema_filename

    @schema_filename.setter
    def schema_filename(self, schema_filename):
        self._schema_filename = schema_filename

    @property
    def cls_activation(self):
        """Classification activation type (sigmoid by default)."""
        return self._cls_activation

    @cls_activation.setter
    def cls_activation(self, cls_activation):
        self._cls_activation = cls_activation

    @property
    def proto_info(self):
        """Mapping taken from the first entry of the ``proto`` config key (empty when absent)."""
        return self._proto_info

    @proto_info.setter
    def proto_info(self, proto_info):
        self._proto_info = proto_info

    @property
    def mask_threshold(self):
        """The ``mask_threshold`` config value."""
        return self._mask_threshold

    @mask_threshold.setter
    def mask_threshold(self, thresh):
        self._mask_threshold = thresh

    @property
    def regression_length(self):
        """The ``regression_length`` config value."""
        return self._regression_length

    @regression_length.setter
    def regression_length(self, regression_length):
        self._regression_length = regression_length

    @property
    def max_total_proposals(self):
        """The ``max_total_proposals`` config value."""
        return self._max_total_proposals

    @max_total_proposals.setter
    def max_total_proposals(self, max_total_proposals):
        self._max_total_proposals = max_total_proposals

    @property
    def dfl_on_nn_core(self):
        """Flag serialised as ``dfl_on_nn_core`` (False by default)."""
        return self._dfl_on_nn_core

    @dfl_on_nn_core.setter
    def dfl_on_nn_core(self, dfl_on_nn_core):
        self._dfl_on_nn_core = dfl_on_nn_core


class BBoxDecoderParams:
    """Base holder for a bbox decoder layer name and its anchor heights and widths."""

    def __init__(self, layer_name, anchors_heights, anchors_widths):
        """Store the layer name and the anchor height/width values unchanged."""
        self.layer_name = layer_name
        self.anchors_heights, self.anchors_widths = anchors_heights, anchors_widths


class BBoxDecoderParamsSSD(BBoxDecoderParams):
    """Bbox decoder parameters for the SSD meta-architecture."""

    def get_params(self, hn):
        """Return the decoder weights as a dict keyed by ``<layer name>/<param>:0``.

        Args:
            hn: Model object providing ``get_layer_by_name``.

        Raises:
            SSDPostprocessException: If the layer is not found in ``hn``.
        """
        layer = hn.get_layer_by_name(self.layer_name)
        if layer is None:
            raise SSDPostprocessException(f"Layer {self.layer_name} is not in the HN.")

        # The number of anchors is half the input feature count.
        _, height, width, features = layer.input_shapes[0]
        y_centers, x_centers = get_bbox_centers_for_ssd(height, width, features // 2)

        prefix = layer.name
        entries = (
            ("anchors_heights", self.anchors_heights),
            ("anchors_widths", self.anchors_widths),
            ("anchors_heights_div_2", self.anchors_heights / 2),
            ("anchors_widths_div_2", self.anchors_widths / 2),
            ("anchors_heights_minus_div_2", -self.anchors_heights / 2),
            ("anchors_widths_minus_div_2", -self.anchors_widths / 2),
            ("y_centers", y_centers),
            ("x_centers", x_centers),
        )
        return {f"{prefix}/{suffix}:0": value for suffix, value in entries}


class BBoxDecoderParamsCenternet(BBoxDecoderParams):
    """Bbox decoder parameters for the CenterNet meta-architecture."""

    def get_params(self, hn):
        """Return the decoder weights as a dict keyed by ``<layer name>/<param>:0``.

        Anchor sizes are divided by the layer's input height / width.

        Args:
            hn: Model object providing ``get_layer_by_name``.

        Raises:
            SSDPostprocessException: If the layer is not found in ``hn``.
        """
        layer = hn.get_layer_by_name(self.layer_name)
        if layer is None:
            raise SSDPostprocessException(f"Layer {self.layer_name} is not in the HN.")

        # The number of anchors is half the input feature count.
        _, height, width, features = layer.input_shapes[0]
        y_centers, x_centers = get_bbox_centers_for_centernet(height, width, features // 2)

        prefix = layer.name
        entries = (
            ("anchors_heights", self.anchors_heights / height),
            ("anchors_widths", self.anchors_widths / width),
            ("anchors_heights_div_2", self.anchors_heights / 2 / height),
            ("anchors_widths_div_2", self.anchors_widths / 2 / width),
            ("anchors_heights_minus_div_2", -self.anchors_heights / 2 / height),
            ("anchors_widths_minus_div_2", -self.anchors_widths / 2 / width),
            ("y_centers", y_centers),
            ("x_centers", x_centers),
        )
        return {f"{prefix}/{suffix}:0": value for suffix, value in entries}


class BBoxDecoderParamsYOLOv5(BBoxDecoderParams):
    """Bbox decoder parameters for the YOLOv5 meta-architecture."""

    def __init__(self, layer_name, anchors_heights, anchors_widths, anchors_stride):
        """Store the base anchor values plus the stride of this decoder."""
        super().__init__(layer_name, anchors_heights, anchors_widths)
        self.anchors_stride = anchors_stride

    def get_params(self, hn):
        """Return the decoder weights as a dict keyed by ``<layer name>/<param>:0``.

        Args:
            hn: Model object providing ``get_layer_by_name``.

        Raises:
            SSDPostprocessException: If the layer is not found in ``hn``.
        """
        layer = hn.get_layer_by_name(self.layer_name)
        if layer is None:
            raise SSDPostprocessException(f"Layer {self.layer_name} is not in the HN.")

        # The number of anchors is a quarter of the input feature count.
        _, height, width, features = layer.input_shapes[0]
        num_of_anchors = features // 4
        y_centers, x_centers = get_bbox_centers_for_yolo(NMSMetaArchitectures.YOLOV5, height, width, num_of_anchors)

        prefix = layer.name
        stride = self.anchors_stride
        entries = (
            # The plain anchor entries hold a constant 2 per anchor, divided by the grid size.
            ("anchors_heights", np.array([2] * num_of_anchors) / height),
            ("anchors_widths", np.array([2] * num_of_anchors) / width),
            ("anchors_heights_div_2", self.anchors_heights * 2 / height / stride),
            ("anchors_widths_div_2", self.anchors_widths * 2 / width / stride),
            ("anchors_heights_minus_div_2", -self.anchors_heights * 2 / height / stride),
            ("anchors_widths_minus_div_2", -self.anchors_widths * 2 / width / stride),
            ("y_centers", y_centers),
            ("x_centers", x_centers),
        )
        return {f"{prefix}/{suffix}:0": value for suffix, value in entries}


class BBoxDecoderParamsYOLOv6:
    """Bbox decoder parameters for the YOLOv6 meta-architecture."""

    def __init__(self, layer_name, anchors_stride):
        """Store the layer name and stride; the classification activation is sigmoid."""
        self.layer_name, self.anchors_stride = layer_name, anchors_stride
        self.cls_activation = ActivationType.sigmoid

    def get_params(self, hn):
        """Return the decoder weights as a dict keyed by ``<layer name>/<param>:0``.

        Args:
            hn: Model object providing ``get_layer_by_name``.

        Raises:
            YOLOv6PostprocessException: If the layer is not found in ``hn``.
        """
        layer = hn.get_layer_by_name(self.layer_name)
        if layer is None:
            raise YOLOv6PostprocessException(f"Layer {self.layer_name} is not in the HN.")

        # Height and width are entries 1 and 2 of the first input shape.
        height, width = layer.input_shapes[0][1:3]
        y_centers, x_centers = get_bbox_centers_for_yolo(NMSMetaArchitectures.YOLOV6, height, width, num_of_anchors=1)

        prefix = layer.name
        entries = (
            ("height_scale_factor", 1 / height),
            ("width_scale_factor", 1 / width),
            ("y_centers", y_centers),
            ("x_centers", x_centers),
        )
        return {f"{prefix}/{suffix}:0": value for suffix, value in entries}


class NMSPostProcess:
    def __init__(self, hn, weights, config_file, engine, enforce_iou_threshold=True, bbox_decoding_only=False):
        """Set up the post-process builder.

        Args:
            hn: Model graph that the post-process layers are added to.
            weights: Weights container for the model; a falsy value (``None``
                or empty) is replaced by an empty dict.
            config_file: NMS configuration, handed to ``NMSConfig.from_json``
                as its ``config`` dictionary by ``_load_json``.
            engine: Post-process target engine.
            enforce_iou_threshold: When False, ``_load_json`` forces the IoU
                threshold to 1.0 if the input division factor is not 1.
            bbox_decoding_only: Stored as-is for later use.
        """
        self._hn = hn
        self._weights = weights if weights else {}
        self._config_file = config_file
        # Filled in by _load_json(); initialised so the ``config`` property
        # returns None instead of raising AttributeError before loading.
        self._config = None
        self._engine = engine
        self._enforce_iou_threshold = enforce_iou_threshold
        self._bbox_decoding_only = bbox_decoding_only
        self._sigmoid_layers = []  # layers that their activation is being changed to sigmoid
        self._meta_arch = None
        self._inputs_keys = []
        self._cls_activation = ActivationType.sigmoid
        # Initialised so the ``dfl_on_nn_core`` property can be read before
        # its setter was ever called.
        self._dfl_on_nn_core = False
        self._fuser_helper = FuserHelper(self._hn)
        self._output_original_name = DEFAULT_NMS_OUTPUT_ORIGINAL_NAME

    def _load_json(self, meta_arch, nms_iou_th=DEFAULT_IOU_TH):
        """Build ``self._config`` from the stored configuration.

        When IoU enforcement is disabled and the input division factor is not
        1, the IoU threshold of the loaded configuration is overridden with 1.0.
        """
        self._config = NMSConfig.from_json(self._config_file, meta_arch, nms_iou_th)
        if self._enforce_iou_threshold or self._config.input_division_factor == 1:
            return

        logger.debug("IOU threshold set to 1 in nms postprocess")
        self._config.nms_iou_th = 1.0

    @property
    def hn(self):
        """Return the model graph object given to the constructor."""
        return self._hn

    @property
    def weights(self):
        """Return the weights container (an empty dict when the constructor got a falsy value)."""
        return self._weights

    @property
    def config(self):
        """Return the NMSConfig created by ``_load_json``."""
        return self._config

    @property
    def sigmoid_layers(self):
        """Return the list of layers whose activation is being changed to sigmoid."""
        return self._sigmoid_layers

    @property
    def cls_activation(self):
        """Return the activation type assigned to classification layers (sigmoid at construction)."""
        return self._cls_activation

    @property
    def meta_arch(self):
        """Return the meta-architecture (None until assigned)."""
        return self._meta_arch

    @meta_arch.setter
    def meta_arch(self, arch):
        """Set the meta-architecture."""
        self._meta_arch = arch

    @property
    def inputs_keys(self):
        """Return the input keys (an empty list at construction)."""
        return self._inputs_keys

    @inputs_keys.setter
    def inputs_keys(self, keys):
        """Replace the input keys."""
        self._inputs_keys = keys

    @property
    def dfl_on_nn_core(self):
        """Return the stored ``dfl_on_nn_core`` flag."""
        return self._dfl_on_nn_core

    @dfl_on_nn_core.setter
    def dfl_on_nn_core(self, dfl_on_nn_core):
        """Store the ``dfl_on_nn_core`` flag."""
        self._dfl_on_nn_core = dfl_on_nn_core

    @property
    def output_original_name(self):
        """Return the original name used for the NMS output layer."""
        return self._output_original_name

    @output_original_name.setter
    def output_original_name(self, output_original_name):
        """Set the original name used for the NMS output layer."""
        self._output_original_name = output_original_name

    def _add_bbox_decoder_weights(self):
        """Merge the parameters of every bbox decoder into the weights.

        Raises:
            SSDPostprocessException: If ``_get_bbox_decoders_list`` returns
                nothing.
        """
        decoders = self._get_bbox_decoders_list()
        if not decoders:
            # msg: These errors indicated a problem with the configuration given to the post-processing API.
            raise SSDPostprocessException("Model params must be provided.")

        for decoder in decoders:
            self._weights.update(decoder.get_params(self._hn))
            logger.debug(f"Added params for layer: {decoder.layer_name}")

    def _calculate_new_dims(self, slice_str, shape, activation_feature_reshape):
        np_shape = np.empty(shape)
        reshaped = self._apply_activation_feature_reshape(slice_str, activation_feature_reshape, np_shape)
        return list(reshaped.shape)

    def _apply_activation_feature_reshape(self, activation_feature_slice, reshape_dimensions, tensor):
        shape = list(tensor.shape[:-1])

        shape.extend(reshape_dimensions)
        reshaped = np.reshape(tensor, shape)

        # reorder bbox decoder inputs relative to default order [ty, tx, th ,tw]
        reshaped = reshaped[..., self._config.regression_prediction_order]

        indices = tuple([slice(None, None, None)] * (len(reshaped.shape) - 1) + [activation_feature_slice])
        split_shape = reshaped[indices]
        new_shape_dims = [*list(tensor.shape[:-1]), -1]
        return np.reshape(split_shape, new_shape_dims)

    def _create_new_layer(self, activation, feature_slice, feature_reshape, layer, new_layer_name):
        new_layer = copy.deepcopy(layer)
        new_layer.kernel_shape = self._calculate_new_dims(feature_slice, new_layer.kernel_shape, feature_reshape)

        new_output_shapes = []
        for shape in new_layer.output_shapes:
            new_output_shapes.append([-1, *self._calculate_new_dims(feature_slice, shape[1:], feature_reshape)])

        new_layer.output_shapes = new_output_shapes
        new_layer.add_original_name(layer.name)
        new_layer.activation = activation
        new_layer.name = new_layer_name
        new_layer.index = self._hn.get_next_index()
        return new_layer

    def _adjust_weights(self, index, activation_feature_reshape, weights):
        """Return ``InnerParams`` holding ``weights`` cut down to one feature slice.

        ``index`` picks the slice from ``MULTIPLE_ACTIVATION_FEATURES_SLICES``.
        The ``"padding_const_value:0"`` entry is copied through untouched;
        every other entry is reshaped and sliced.
        """
        feature_slice = MULTIPLE_ACTIVATION_FEATURES_SLICES[index]
        adjusted_weights = {}
        for key, value in weights.items():
            if key == "padding_const_value:0":
                adjusted_weights[key] = value
            else:
                adjusted = self._apply_activation_feature_reshape(feature_slice, activation_feature_reshape, value)
                adjusted_weights[key] = adjusted
                logger.debug(
                    f"Adjust weight: {key}. Previous shape: {value.shape} -> new shape: {adjusted.shape}",
                )

        return InnerParams(adjusted_weights)

    def _create_new_concat_layer(self, hn_layer):
        """Push a new concat layer, named after ``hn_layer``, right after ``hn_layer``."""
        concat = ConcatLayer()
        concat.name = CONCAT_LAYER_NAME_FORMAT.format(hn_layer.name)
        predecessors = [hn_layer]
        self._hn.push_layer(concat, preds=predecessors)

    def _create_nms_layer(self, concat_layer_name):
        """Append the NMS layer and its output layer after the given concat layer.

        The NMS layer becomes the only entry of the network's output layer
        order, and the layers that followed the new output layer are removed
        from the graph.
        """
        cfg = self._config
        concat_layer = self._hn.get_layer_by_name(concat_layer_name)
        input_shapes = concat_layer.output_shapes

        nms_layer = NMSLayer()
        nms_layer.name = "nms1"

        # One class fewer when the background class is removed.
        num_classes = cfg.classes
        if cfg.background_removal:
            num_classes = cfg.classes - 1

        iou_th = cfg.nms_iou_th
        if iou_th < 1.0 and cfg.input_division_factor != 1:
            if self._engine != PostprocessTarget.NN_CORE:
                # auto flow, nms on chip has to be with iou_th=1.0
                # TODO - Apply IOU op only when division factor is bigger than 1 - SDK-36464
                iou_th = 1.0
            else:
                logger.warning(
                    "Too many proposals in network, IOU threshold different than 1 isn't supported"
                    " and may lead to missing proposals. It's recommended to use IOU threshold = 1.",
                )

        nms_layer.set_nms_params(
            cfg.nms_scores_th,
            iou_th,
            cfg.max_proposals_per_class,
            num_classes,
            cfg.input_division_factor,
        )
        nms_layer.input_shapes = input_shapes

        output_name = self.output_original_name
        nms_output_layer = OutputLayer()
        nms_output_layer.name = f"{concat_layer.scope}/{output_name}"
        nms_output_layer.input_shapes = input_shapes
        nms_output_layer.original_names = [output_name]

        self._hn.push_layer(nms_layer, preds=[concat_layer], calc_shapes=False)
        self._hn.push_layer(nms_output_layer, preds=[nms_layer], calc_shapes=False)
        self._hn.net_params.output_layers_order = [nms_layer.name]

        # Detach and drop whatever now follows the new output layer.
        for successor_name in nms_output_layer.outputs:
            successor = self._hn.get_layer_by_name(successor_name)
            self._hn.remove_edge(nms_output_layer, successor)
            self._hn.remove_node(successor)

        nms_output_layer.outputs = []

    def _add_proposal_generator_layers(self):
        """Add a proposal generator per bbox decoder, concatenate them and append the NMS layer.

        For every (bbox decoder, classification conv) pair in the configuration
        a proposal generator layer is pushed after both layers. When background
        removal is enabled, the background entries are first deleted from the
        classification conv's bias/kernel and its shapes are shrunk to match.
        The proposal generators are then concatenated (repeatedly, until a
        single concat remains) and the NMS layer is created on top.
        """
        # (bbox decoder name, classification layer name) for every configured decoder.
        bbox_conv_pairs = [
            (bbox[BBoxDecodersInfo.NAME.value], bbox[BBoxDecodersInfo.CLS_LAYER.value])
            for bbox in self._config.bbox_decoders_info
        ]
        proposal_generators = []

        for i, (bbox_decoder_layer_name, conv_layer_name) in enumerate(bbox_conv_pairs):
            conv_layer = self._hn.get_layer_by_name(conv_layer_name)
            bbox_decoder_layer = self._hn.get_layer_by_name(bbox_decoder_layer_name)
            # Anchor count comes from the decoder's input features; the class
            # count is the conv's output features divided by the anchor count.
            f_bbox = bbox_decoder_layer.input_shapes[0][-1]
            anchors = f_bbox // NMSLayer.BBOX_PER_CHUNK
            classes = conv_layer.output_shape[-1] // anchors

            if self._config.background_removal:
                # Background entries sit at background_removal_index and then
                # every `classes` entries; drop them from bias and kernel.
                len_params = self._weights[conv_layer.name]["bias:0"].shape[0]
                background_indices = range(self._config.background_removal_index, len_params, self._config.classes)
                new_bias = np.delete(self._weights[conv_layer.name]["bias:0"], background_indices)
                new_kernel = np.delete(self._weights[conv_layer.name]["kernel:0"], background_indices, axis=-1)
                new_weights = {"bias:0": new_bias, "kernel:0": new_kernel}
                # NOTE(review): remove()/add() are not dict methods, so this
                # branch needs a weights container that provides them; the
                # empty-dict fallback used when no weights are given does not.
                self._weights.remove(conv_layer.name)
                self._weights.add(conv_layer.name, new_weights)

                # NOTE(review): if kernel_shape / output_shape return the
                # layer's own list, the `-=` below already changes the layer
                # in place before the assignments - confirm.
                new_kernel_shape = conv_layer.kernel_shape
                new_kernel_shape[-1] -= len(background_indices)
                new_output_shape = conv_layer.output_shape
                new_output_shape[-1] -= len(background_indices)
                conv_layer.kernel_shape = new_kernel_shape
                conv_layer.output_shapes = [new_output_shape]
                classes -= 1

            conv_layer.activation = self.cls_activation
            proposal_generator_layer_name = f"proposal_generator{i}"
            proposal_generator_layer = ProposalGeneratorLayer()
            proposal_generator_layer.name = proposal_generator_layer_name

            # Input shapes: the output of the bbox decoder's first input layer,
            # followed by the classification conv's output.
            bbox_decoder_input = self._hn.get_layer_by_name(bbox_decoder_layer.inputs[0])
            proposal_generator_layer.input_shapes = [bbox_decoder_input.output_shapes[0], conv_layer.output_shapes[0]]
            self._hn.push_layer(proposal_generator_layer, preds=[bbox_decoder_layer, conv_layer], calc_shapes=False)
            proposal_generator_layer.input_division_factor = self._config.input_division_factor
            proposal_generator_layer.append_to_input_list(bbox_decoder_layer)
            proposal_generator_layer.append_to_input_list(conv_layer)
            proposal_generator_layer.update_output_shapes()

            # The proposal generator gets an all-ones kernel, one value per output feature.
            output_features = proposal_generator_layer.output_shapes[0][-1]
            proposal_generator_weights = {"kernel:0": np.ones(output_features)}
            self._weights.add(proposal_generator_layer.name, proposal_generator_weights)

            logger.debug(f"Created layer {proposal_generator_layer_name}")
            proposal_generators.append(proposal_generator_layer_name)

        # Concatenate in rounds until a single layer is left to feed the NMS layer.
        new_concat_layers = self._create_concats(proposal_generators)
        while len(new_concat_layers) != 1:
            new_concat_layers = self._create_concats(new_concat_layers)

        self._create_nms_layer(new_concat_layers[0])

    def _create_concats(self, layers):
        new_concat_layers = []
        if len(layers) == 1:
            return layers

        if len(layers) % MAX_CONCAT_LAYERS == 1:
            last_pair = layers[-MIN_CONCAT_LAYERS:]
            layers = layers[:-MIN_CONCAT_LAYERS]
            batches = [
                layers[MAX_CONCAT_LAYERS * i : MAX_CONCAT_LAYERS * (i + 1)]
                for i in range(len(layers) // MAX_CONCAT_LAYERS + 1)
            ]
            batches.append(last_pair)
        else:
            batches = [
                layers[MAX_CONCAT_LAYERS * i : MAX_CONCAT_LAYERS * (i + 1)]
                for i in range(len(layers) // MAX_CONCAT_LAYERS + 1)
            ]

        for batch in batches:
            if batch:
                new_concat_name = CONCAT_LAYER_NAME_FORMAT.format(batch[0])
                input_shapes = []
                f_out = 0
                for layer in batch:
                    hn_layer = self._hn.get_layer_by_name(layer)
                    input_shapes.extend(hn_layer.output_shapes)
                    f_out += hn_layer.output_shapes[0][-1]

                _, h, w, _ = input_shapes[0]
                output_shape = [-1, h, w, f_out]

                new_concat_layer = ConcatLayer()
                new_concat_layer.name = new_concat_name
                new_concat_layer.input_shapes = input_shapes
                new_concat_layer.output_shapes = [output_shape]

                preds = [self._hn.get_layer_by_name(x) for x in batch]
                self._hn.push_layer(new_concat_layer, preds=preds, calc_shapes=False)
                new_concat_layers.append(new_concat_name)
                logger.debug(f"Created layer {new_concat_name}")

        return new_concat_layers

    def _load_and_validate_config_data(self, **kwargs):
        """
        Loads the NMS config of this meta architecture and validates the layers it refers to.

        Args:
            **kwargs: may hold ``output_original_name``, which is kept on the instance
                (DEFAULT_NMS_OUTPUT_ORIGINAL_NAME is used when it is missing).
        """
        self.output_original_name = kwargs.get("output_original_name", DEFAULT_NMS_OUTPUT_ORIGINAL_NAME)

        # centernet is loaded with an IoU threshold of 1.0, any other meta arch with the default threshold
        if isinstance(self, NMSCenternetPostProcess):
            nms_iou_th = 1.0
        else:
            nms_iou_th = DEFAULT_IOU_TH
        self._load_json(self.meta_arch, nms_iou_th)

        input_key_names = [input_key.value for input_key in self.inputs_keys]
        self._validate_config_data_keys(input_key_names)

    def prepare_hn_and_weights(self, hw_consts, engine=PostprocessTarget.NN_CORE, **kwargs):
        """
        Loads the NMS config and, for the NN core engine, adds the NMS layers and their weights to the model.

        Args:
            hw_consts: mapping that holds the "NMS_ON_NN_CORE_MAX_SIZE" value.
            engine: the postprocess target, only the config is loaded when it isn't NN_CORE.
            **kwargs: forwarded to the config loading.
        """
        self._load_and_validate_config_data(**kwargs)
        if engine != PostprocessTarget.NN_CORE:
            return

        self._add_bbox_decoder_layers()
        self._add_proposal_generator_layers()
        self._split_multiple_activation_layers()
        self._hn.calculate_shapes()
        self._add_bbox_decoder_weights()
        self._verify_input_division_factor_value(hw_consts["NMS_ON_NN_CORE_MAX_SIZE"])

        if len(self._hn.net_params.net_scopes) <= 1:
            return

        # with more than one scope the weights are re-wrapped using the scoped names
        names_mapping = self._hn.add_scopes()
        self._weights = ModelParams(self._weights, names_mapping=names_mapping)

    def _verify_input_division_factor_value(self, nms_max_memory_size):
        """
        Logs a warning when the input_division_factor seems too low for the given NMS memory size.

        The input shapes of the proposal generators and of the concat layer depend on the
        input_division_factor. The constraints on the shapes are as follows:
            proposal generator:
                ((pred0.height / input_division_factor * pred0.width * pred0.features) +
                    2 * (pred1.height / input_division_factor * pred1.width * pred1.features)) < nms_max_memory_size
            concat:
                2 * sum(pred[i].width * pred[i].features) < nms_max_memory_size

        Args:
            nms_max_memory_size: the size both constraints are compared against.
        """

        def is_generator_too_big(generator):
            first_input_size = np.prod(generator.input_shapes[0][1:])
            second_input_size = np.prod(generator.input_shapes[1][1:])
            required_size = (first_input_size + 2 * second_input_size) / generator.input_division_factor
            return required_size > nms_max_memory_size

        proposal_generators = self._hn.get_layers_by_type(LayerType.proposal_generator)
        generators_cond = any(is_generator_too_big(generator) for generator in proposal_generators)

        # the concat is the first predecessor of the first nms layer
        nms_inputs = (next(iter(self._hn.predecessors(nms))) for nms in self._hn.get_layers_by_type(LayerType.nms))
        concat = next(nms_inputs)
        concat_sum = 2 * sum(in_shape[2] * in_shape[3] for in_shape in concat.input_shapes)

        if concat_sum > nms_max_memory_size or generators_cond:
            logger.warning(
                "The input_division_factor value might be too low for the current network. "
                "Please increase the value in the nms json config",
            )

    def _validate_config_data_keys(self, nms_layer_input_keys):
        for layer_param in nms_layer_input_keys:
            bbox_input_layers = {
                bbox[BBoxDecodersInfo.NAME.value]: bbox[layer_param] for bbox in self._config.bbox_decoders_info
            }
            for bbox_name, layer_name in bbox_input_layers.items():
                if not self._hn.get_layer_by_name(layer_name):
                    raise BBoxDecoderInputLayerNotFoundException(
                        f"BBox decoder layer {bbox_name} was configured with layer "
                        f"{layer_name}, which couldn't be found in the input HN "
                        "model.",
                    )

    def add_iou_postprocess_layer_to_hn(self):
        """
        Adds an IoU postprocess layer, followed by a new output layer, after the NMS layer.

        The NMS layer is the last real output layer whose op is nms. The postprocess layer takes its place
        in the output layers order, and the NMS layer's last successor is removed from the graph.
        """
        nms_output_candidates = [layer for layer in self.hn.get_real_output_layers() if layer.op == LayerType.nms]
        nms_layer = nms_output_candidates[-1]
        nms_output_layer = list(self.hn.successors(nms_layer))[-1]

        postprocess_layer = PostprocessLayer()
        if self.config.background_removal:
            output_classes = self.config.classes - 1
        else:
            output_classes = self.config.classes
        # 5 is for [y_min, x_min, y_max x_max, score]
        postprocess_layer.output_shapes = [[-1, output_classes, 5, self.config.max_proposals_per_class]]
        self.update_new_layer_data(postprocess_layer, nms_layer, "iou_postprocess")

        postprocess_layer.meta_arch = PP_CLASS_TO_META_ARCH[type(self)]
        postprocess_layer.postprocess_type = PostprocessType.IOU
        postprocess_layer.iou_th = self.config.nms_iou_th
        postprocess_layer.classes = self.config.classes
        postprocess_layer.max_proposals_per_class = self.config.max_proposals_per_class
        postprocess_layer.nms_scores_th = self.config.nms_scores_th
        postprocess_layer.op = LayerType.postprocess

        # the output layer of the postprocess has the same shape as the postprocess layer
        postprocess_output_layer = OutputLayer()
        postprocess_output_layer.original_names = [self.output_original_name]
        postprocess_output_layer.output_shapes = [[-1, output_classes, 5, self.config.max_proposals_per_class]]
        self.update_new_layer_data(postprocess_output_layer, postprocess_layer, self.output_original_name)

        # the postprocess layer replaces the nms layer as a network output
        output_layers_order = self._hn.net_params.output_layers_order
        output_layers_order.remove(nms_layer.name)
        output_layers_order.append(postprocess_layer.name)
        self._hn.remove_edge(nms_layer, nms_output_layer)
        self._hn.remove_node(nms_output_layer)

    def update_new_layer_data(self, new_layer, input_layer, layer_name):
        """
        Names a new CPU layer, connects it after ``input_layer`` and adds it to the HN graph.

        Args:
            new_layer: the layer to add, its name, engine, index and inputs are set here.
            input_layer: the existing layer that feeds the new layer.
            layer_name: the name of the new layer, without the scope (the scope of ``input_layer`` is used).
        """
        new_layer.name = f"{input_layer.scope}/{layer_name}"
        new_layer.engine = PostprocessTarget.CPU
        new_layer.index = self._hn.get_next_index()
        new_layer.input_shapes = input_layer.output_shapes
        new_layer.inputs.append(input_layer.name)
        new_layer.input_indices.extend([input_layer.index])
        # the new layer is an output of input_layer (the input layer isn't an output of the new layer),
        # so the link is registered on the input layer's side
        input_layer.outputs.append(new_layer.name)
        self._hn.add_node(new_layer)
        self._hn.add_edge(input_layer, new_layer)

    def _split_multiple_activation_layers(self):
        pass


class CPUPostProcess(NMSPostProcess):
    CHANGE_TO_SIGMOID_ARCHS = [
        NMSMetaArchitectures.SSD,
        NMSMetaArchitectures.YOLOV5,
        NMSMetaArchitectures.YOLOX,
        NMSMetaArchitectures.YOLOV6,
        NMSMetaArchitectures.YOLOV8,
        NMSMetaArchitectures.DAMOYOLO,
    ]

    def __init__(
        self,
        hn,
        weights,
        config_file,
        engine=PostprocessTarget.CPU,
        enforce_iou_threshold=True,
        bbox_decoding_only=False,
    ):
        """Forwards all the arguments to the parent class, the default postprocess engine here is the CPU."""
        parent_args = (hn, weights, config_file, engine, enforce_iou_threshold, bbox_decoding_only)
        super().__init__(*parent_args)

    def add_postprocess_layer_to_hn(self):
        """
        Adds one postprocess layer, fed by the layers configured in the NMS config, and its output layer.

        Each layer that feeds the postprocess layer loses its own output layer, and is replaced by the
        postprocess layer in the output layers order.

        Raises:
            NMSConfigPostprocessException: when the meta architecture isn't handled here, or when a layer that
                feeds the postprocess doesn't have exactly one output layer.
        """
        output_layers = self.hn.get_output_layers()
        postprocess_layer = PostprocessLayer()
        postprocess_layer.meta_arch = self.meta_arch
        postprocess_layer.engine = self._engine
        # bbox decoding only skips the NMS and leaves the postprocess as a bbox decoder
        postprocess_layer.postprocess_type = (
            PostprocessType.NMS if not self._bbox_decoding_only else PostprocessType.BBOX_DECODER
        )
        postprocess_layer.max_total_output_proposals = self.config.max_total_proposals
        postprocess_layer.image_dims = self.config.image_dims
        # with background removal the output holds one class less than the configured amount
        output_classes = self.config.classes - 1 if self.config.background_removal else self.config.classes
        # the postprocess layer is placed in the scope of the last output layer
        postprocess_layer.name = f"{output_layers[-1].scope}/{postprocess_layer.meta_arch.value}_nms_postprocess"
        postprocess_layer.index = self._hn.get_next_index()
        postprocess_layer.op = LayerType.postprocess
        postprocess_layer.iou_th = self.config.nms_iou_th
        postprocess_layer.classes = self.config.classes
        postprocess_layer.max_proposals_per_class = self.config.max_proposals_per_class
        postprocess_layer.nms_scores_th = self.config.nms_scores_th
        self.config.name = postprocess_layer.name

        # extracting the conv layers that should be the inputs of the postprocess from the nms config json
        # and connecting them to the postprocess layer
        if self.meta_arch == NMSMetaArchitectures.YOLOV5:
            conv_layers_for_nms = [
                self._hn.get_layer_by_name(bbox_decoders_info[BBoxDecodersInfo.ENCODED_LAYER.value])
                for bbox_decoders_info in self.config.bbox_decoders_info
            ]
        elif self.meta_arch in [
            NMSMetaArchitectures.SSD,
            NMSMetaArchitectures.YOLOX,
            NMSMetaArchitectures.YOLOV8,
            NMSMetaArchitectures.DAMOYOLO,
        ]:
            regression_layers = [
                self._hn.get_layer_by_name(bbox_decoders_info[BBoxDecodersInfo.REG_LAYER.value])
                for bbox_decoders_info in self.config.bbox_decoders_info
            ]
            classes_layers = [
                self._hn.get_layer_by_name(bbox_decoders_info[BBoxDecodersInfo.CLS_LAYER.value])
                for bbox_decoders_info in self.config.bbox_decoders_info
            ]
            concatenated_layers = [regression_layers, classes_layers]
            if self.meta_arch == NMSMetaArchitectures.YOLOX:
                # yolox has an objectness layer per branch as well
                obj_layers = [
                    self._hn.get_layer_by_name(bbox_decoders_info[BBoxDecodersInfo.OBJ_LAYER.value])
                    for bbox_decoders_info in self.config.bbox_decoders_info
                ]
                concatenated_layers.append(obj_layers)

            # interleaves the layers per branch: reg, cls (and obj for yolox) of branch 0, then of branch 1, ...
            conv_layers_for_nms = [
                layer[branch_index] for branch_index in range(len(classes_layers)) for layer in concatenated_layers
            ]
        elif self.meta_arch == NMSMetaArchitectures.YOLOV5_SEG:
            conv_layers_for_nms = self._hn.get_real_output_layers()
        else:
            raise NMSConfigPostprocessException(
                f"Currently {self.meta_arch.value} is not fully supported by the CPU",
            )

        for encoded_layer in conv_layers_for_nms:
            # connects the layer to the postprocess layer, in both the graph and the layers' own fields
            encoded_layer.outputs.append(postprocess_layer.name)
            self._hn.add_edge(encoded_layer, postprocess_layer)
            postprocess_layer.input_shapes.extend(encoded_layer.output_shapes)
            postprocess_layer.inputs.append(encoded_layer.name)
            postprocess_layer.append_input_index(encoded_layer.index)
            # the postprocess layer becomes the only output index of the layer
            encoded_layer.output_indices = [postprocess_layer.index]

            # removing nn_core output layers (convs) for inserting the postprocess output layer as the only one
            successors = list(self._hn.successors(encoded_layer))
            output_layer = [successor for successor in successors if successor.op == LayerType.output_layer]
            if len(output_layer) != 1:
                raise NMSConfigPostprocessException(f"The layer {encoded_layer.name} doesn't have one output layer")
            output_layer = output_layer[0]

            encoded_layer.outputs.remove(output_layer.name)
            self._hn.net_params.output_layers_order.remove(encoded_layer.name)
            self._hn.remove_edge(encoded_layer, output_layer)
            self._hn.remove_node(output_layer)

        postprocess_layer.update_output_shapes(
            classes=output_classes,
            max_per_class=self.config.max_proposals_per_class,
        )
        # creates postprocess output layers
        postprocess_output_layer = OutputLayer()
        postprocess_output_layer.name = f"{output_layers[-1].scope}/{self.output_original_name}"
        postprocess_output_layer.engine = self._engine
        postprocess_output_layer.index = self._hn.get_next_index()
        postprocess_output_layer.input_shapes = postprocess_layer.output_shapes
        postprocess_output_layer.output_shapes = postprocess_layer.output_shapes
        postprocess_output_layer.original_names = [self.output_original_name]

        # connects output layer to postprocess layer
        self._hn.add_node(postprocess_output_layer)
        self._hn.add_edge(postprocess_layer, postprocess_output_layer)
        postprocess_output_layer.inputs.extend([postprocess_layer.name])
        postprocess_output_layer.input_indices.extend([postprocess_layer.index])
        postprocess_layer.outputs.append(postprocess_output_layer.name)
        # the postprocess layer is registered as a network output
        self._hn.net_params.output_layers_order.append(postprocess_layer.name)

    def _update_output_activations_to_sigmoid(self, layers):
        """
        Sets a sigmoid activation on the given conv layers that don't have one yet.

        Layers whose op isn't conv, or that already have a sigmoid activation, are left untouched. The name
        of every changed layer is appended to ``self.sigmoid_layers``.

        Args:
            layers: the HN layers to check.

        Raises:
            NMSConfigPostprocessException: when the meta architecture isn't in CHANGE_TO_SIGMOID_ARCHS.
        """
        if self.meta_arch not in self.CHANGE_TO_SIGMOID_ARCHS:
            # the two parts of the message used to be joined without a space between them
            raise NMSConfigPostprocessException(
                f"Shouldn't change activations to sigmoid in {self.meta_arch.value} meta arch.",
            )
        for layer in layers:
            if layer.op == LayerType.conv and layer.activation != ActivationType.sigmoid:
                if layer.activation:
                    info_msg = f"The activation function of layer {layer.name} was replaced by a Sigmoid"
                else:
                    info_msg = f"A Sigmoid activation was added to layer {layer.name}"
                logger.info(info_msg)
                layer.activation = ActivationType.sigmoid
                self.sigmoid_layers.append(layer.name)

    def prepare_hn_and_weights(self, hw_consts, engine=PostprocessTarget.CPU, **kwargs):
        """
        Prepares the model for a postprocess on the CPU, any other engine is handled by the parent class.

        Args:
            hw_consts: forwarded to the parent class when the engine isn't the CPU.
            engine: the postprocess target.
            **kwargs: forwarded to the config loading, may hold ``layers_to_sigmoid`` - layers whose
                activation should be changed to sigmoid.
        """
        if engine != PostprocessTarget.CPU:
            super().prepare_hn_and_weights(hw_consts, engine, **kwargs)
            return

        # not-neural-core flow
        self._load_and_validate_config_data(**kwargs)
        layers_to_sigmoid = kwargs.get("layers_to_sigmoid")
        if layers_to_sigmoid:
            self._update_output_activations_to_sigmoid(layers_to_sigmoid)
        self.add_postprocess_layer_to_hn()


class NMSYOLOV8PostProcess(CPUPostProcess):
    def __init__(
        self,
        hn,
        weights,
        config_file,
        engine=PostprocessTarget.CPU,
        enforce_iou_threshold=True,
        bbox_decoding_only=False,
    ):
        """Initializes the parent class and marks the postprocess as yolov8 with cls and reg input layers."""
        super().__init__(
            hn,
            weights,
            config_file,
            engine=engine,
            enforce_iou_threshold=enforce_iou_threshold,
            bbox_decoding_only=bbox_decoding_only,
        )
        self.meta_arch = NMSMetaArchitectures.YOLOV8
        self.inputs_keys = [BBoxDecodersInfo.CLS_LAYER, BBoxDecodersInfo.REG_LAYER]

    def prepare_hn_and_weights(self, hw_consts, engine=PostprocessTarget.CPU, **kwargs):
        """
        Adjusts the network to the yolov8 postprocess and continues with the parent's flow.

        When every bbox decoder in the config holds a combined layer, the combined layers are sliced to cls
        and reg branches and no layer is requested to change its activation to sigmoid. Otherwise the
        configured cls layers are passed on as ``layers_to_sigmoid``.

        Args:
            hw_consts: forwarded to the parent class.
            engine: the postprocess target, forwarded to the parent class.
            **kwargs: may hold ``dfl_on_nn_core``, the rest is forwarded to the parent class.
        """
        bbox_decoders = self._config_file["bbox_decoders"]
        if all(BBoxDecodersInfo.COMBINED_LAYER.value in bbox_decoder for bbox_decoder in bbox_decoders):
            # nanodet arch, slices have to be added to the network
            self._add_slices_to_combined_layers()
            layers_to_sigmoid = []
        else:
            layers_to_sigmoid = [
                self._hn.get_layer_by_name(bbox_decoder[BBoxDecodersInfo.CLS_LAYER.value])
                for bbox_decoder in bbox_decoders
            ]
        kwargs["layers_to_sigmoid"] = layers_to_sigmoid

        if kwargs.get("dfl_on_nn_core", False):
            # performs DFL on the nn_core (softmax expectation approximation for the boxes)
            self._add_boxes_dfl_layers_to_hn()

        super().prepare_hn_and_weights(hw_consts, engine, **kwargs)

    def _load_and_validate_config_data(self, **kwargs):
        """Loads and validates the config as the parent class does, then keeps ``dfl_on_nn_core`` on it."""
        super()._load_and_validate_config_data(**kwargs)
        dfl_on_nn_core = kwargs.get("dfl_on_nn_core", False)
        self.config.dfl_on_nn_core = dfl_on_nn_core

    def _add_slices_to_combined_layers(self):
        """
        Slices the combined output convs of nanodet (based on the yolov8 postprocess) to two branches each:
        one of class predictions and one of regression, and applies a sigmoid to the cls branch.

        Every combined layer loses its output layer and gets two slice layers instead, each branch ends with
        an output layer of its own. The bbox decoders in the config file are updated accordingly - the
        combined layer entry is removed, and the cls and reg entries are set to the new layers.

        Raises:
            NMSConfigPostprocessException: when a combined layer doesn't have exactly one output layer.
        """
        output_combined_layers = [
            self._hn.get_layer_by_name(bbox_decoder[BBoxDecodersInfo.COMBINED_LAYER.value])
            for bbox_decoder in self._config_file["bbox_decoders"]
        ]

        for i, combined_layer in enumerate(output_combined_layers):
            output_layer = [
                output_layer
                for output_layer in self._hn.successors(combined_layer)
                if output_layer.op == LayerType.output_layer
            ]
            if len(output_layer) != 1:
                raise NMSConfigPostprocessException(f"Combined layer {combined_layer.name} must have an output layer.")

            output_layer = output_layer[0]
            classes = self._config_file[NMSProperties.CLASSES.value]
            regression_length = self._config_file[NMSProperties.REGRESSION_LENGTH.value]

            # creates a slice layer for each new branch - classes and regression.
            # the first #classes features are for the classes branch and the rest are for the regression
            slice_sizes = [(0, classes), (classes, classes + regression_length * 4)]
            for j, silce_size in enumerate(slice_sizes):
                slice_layer = FusedSliceLayer()
                slice_layer.name = f"{combined_layer.scope}/slice_{i}_{j}_{combined_layer.name_without_scope}"
                slice_layer.output_shapes = [[*combined_layer.output_shape[:-1], slice_sizes[j][1] - slice_sizes[j][0]]]
                slice_layer.index = self._hn.get_next_index()
                slice_layer.original_names = combined_layer.original_names
                # height and width are taken in full, only the features are sliced
                slice_layer.height_slice = [0, combined_layer.output_shape[1], 1]
                slice_layer.width_slice = [0, combined_layer.output_shape[2], 1]
                slice_layer.features_slice = [silce_size[0], silce_size[1], 1]

                slice_layer.inputs = [combined_layer.name]
                slice_layer.input_indices = [combined_layer.index]
                # NOTE(review): the input shapes are taken from the combined layer's input_shapes and not from
                # its output_shapes - confirm that this is intended
                slice_layer.input_shapes = combined_layer.input_shapes
                self._hn.add_node(slice_layer)
                self._hn.add_edge(combined_layer, slice_layer)

                if j:
                    # regression branch, applies only slice layer
                    del self._config_file["bbox_decoders"][i][BBoxDecodersInfo.COMBINED_LAYER.value]
                    self._config_file["bbox_decoders"][i][BBoxDecodersInfo.REG_LAYER.value] = slice_layer.name

                    # the slice is added as one more output of the combined layer
                    combined_layer.outputs.append(slice_layer.name)
                    combined_layer.output_shapes.append(combined_layer.output_shape)
                    combined_layer.output_indices.append(slice_layer.index)
                    output_pred_layer = slice_layer
                else:
                    # cls branch, applies sigmoid
                    # the slice takes the place of the original output layer in the combined layer's outputs
                    combined_layer.replace_output_layer(output_layer.name, slice_layer.name)
                    combined_layer.replace_output_index(output_layer.index, slice_layer.index)

                    sigmoid_layer = FusedStandaloneActivationLayer()
                    sigmoid_layer.index = self._hn.get_next_index()
                    sigmoid_layer.name = f"{output_layer.scope}/slice_{i}_{j}_sigmoid_layer"
                    sigmoid_layer.activation = ActivationType.sigmoid
                    self._hn.push_layer(sigmoid_layer, [slice_layer])
                    output_pred_layer = sigmoid_layer
                    self._config_file["bbox_decoders"][i][BBoxDecodersInfo.CLS_LAYER.value] = output_pred_layer.name

                # applies output layer on the slice/sigmoid
                slice_output_layer = OutputLayer()
                slice_output_layer_name = f"{self.output_original_name}_slice_{i}_{j}"
                slice_output_layer.name = f"{output_layer.scope}/{slice_output_layer_name}"
                slice_output_layer.index = self._hn.get_next_index()
                slice_output_layer.output_shapes = output_pred_layer.output_shapes
                slice_output_layer.original_names = [slice_output_layer_name]

                # connects output layer
                self._hn.add_node(slice_output_layer)
                self._hn.add_edge(output_pred_layer, slice_output_layer)
                slice_output_layer.inputs.extend([output_pred_layer.name])
                slice_output_layer.input_indices.extend([output_pred_layer.index])
                output_pred_layer.outputs.append(slice_output_layer.name)
                self._hn.net_params.output_layers_order.append(output_pred_layer.name)

            # the combined layer isn't a network output anymore, its original output layer is removed
            self._hn.net_params.output_layers_order.remove(combined_layer.name)
            self._hn.remove_edge(combined_layer, output_layer)
            self._hn.remove_node(output_layer)

    def _add_boxes_dfl_layers_to_hn(self):
        """
        Adds the boxes DFL layers to the HN network.

        Those layers perform a softmax expectation approximation for the boxes, by applying the following:
        y_min / x_min / y_max / x_max = E(softmax(X)), where X is the matching vector (Y_MIN / X_MIN / Y_MAX /
        X_MAX) in the length of the regression length.

        A softmax layer and a conv layer are inserted between each configured reg layer and its successor.
        The reg layers in the config file are replaced by the new conv layers, and the regression length in
        the config file is set to 1.
        """

        reg_layers = [
            self.hn.get_layer_by_name(bbox_decoder[BBoxDecodersInfo.REG_LAYER.value])
            for bbox_decoder in self._config_file["bbox_decoders"]
        ]
        for i, reg_layer in enumerate(reg_layers):
            # the softmax gets the next free index and the conv the one after it
            base_index = self._hn.get_next_index()

            # creates softmax layer
            softmax = SoftmaxLayer()
            softmax.name = f"{reg_layer.name}_dfl_softmax{base_index}"
            groups = 4  # 4 groups for y_min, x_min, y_max, x_max
            softmax.groups = groups
            softmax.index = base_index
            softmax.original_names = reg_layer.original_names.copy()
            # NOTE(review): both shapes are copied from the reg layer's input_shapes and not from its
            # output_shapes - confirm that this is intended
            softmax.input_shapes = reg_layer.input_shapes.copy()
            softmax.output_shapes = reg_layer.input_shapes.copy()
            base_index += 1

            # creates reduce mean operation
            group_conv = FusedConv2DLayer()
            group_conv.name = f"{reg_layer.name}_dfl_reduce_mean{base_index}"
            group_conv.index = base_index
            group_conv.op = LayerType.conv
            group_conv.kernel_shape = [1, 1, reg_layer.output_shapes[0][-1], 4]
            # the kernel holds the values 0..regression_length-1 once per group, its shape is
            # (1, 1, regression_length, groups). The inner concatenate runs over a single array, so it
            # leaves that array as is
            kernel = np.concatenate(
                [
                    np.concatenate(
                        [
                            np.reshape(np.arange(self._config_file["regression_length"]), (1, 1, -1, 1))
                            for _ in range(1)
                        ],
                        axis=-2,
                    )
                    for _ in range(groups)
                ],
                axis=-1,
            )
            group_conv.kernel = kernel
            self.weights[f"{group_conv.name}/kernel:0"] = kernel
            # NOTE(review): the bias is created with the shape (1, 1, 4) - confirm against the expected bias shape
            self.weights[f"{group_conv.name}/bias:0"] = np.zeros((1, 1, 4))
            group_conv.dilations = [1] * 4
            group_conv.bn_enabled = False
            group_conv.strides = [1] * 4
            group_conv.padding = PaddingType.SAME
            group_conv.groups = groups
            group_conv.original_names = reg_layer.original_names.copy()
            group_conv.input_shapes = reg_layer.input_shapes.copy()
            # one output feature per group
            group_conv.output_shapes = [[-1, *reg_layer.input_shapes[0][1:3], groups]]

            # connects the layers: reg_layer -> softmax -> group_conv -> succ
            succ = next(iter(self._hn.successors(reg_layer)))  # has only one output
            self._fuser_helper.replace_succ(reg_layer, succ, softmax)
            self._fuser_helper.add_succs(softmax, [group_conv], update_output_shapes=False)
            self._fuser_helper.add_succs(group_conv, [succ], update_output_shapes=False)

            self._fuser_helper.add_preds(softmax, [reg_layer], update_input_shapes=False)
            self._fuser_helper.add_preds(group_conv, [softmax], update_input_shapes=False)
            self._fuser_helper.replace_pred(succ, reg_layer, group_conv)
            # the conv takes the place of the reg layer in the output layers order
            self._hn.net_params.output_layers_order[self._hn.net_params.output_layers_order.index(reg_layer.name)] = (
                group_conv.name
            )

            self._config_file["bbox_decoders"][i][BBoxDecodersInfo.REG_LAYER.value] = group_conv.name
        # set only after the loop, since the kernels above are built using the original regression length
        self._config_file["regression_length"] = 1


class NMSDAMOYOLOPostProcess(NMSYOLOV8PostProcess):
    def __init__(
        self,
        hn,
        weights,
        config_file,
        engine=PostprocessTarget.CPU,
        enforce_iou_threshold=True,
        bbox_decoding_only=False,
    ):
        """Initializes the postprocess as the yolov8 one, with damoyolo as the meta architecture."""
        super().__init__(
            hn,
            weights,
            config_file,
            engine=engine,
            enforce_iou_threshold=enforce_iou_threshold,
            bbox_decoding_only=bbox_decoding_only,
        )
        self.meta_arch = NMSMetaArchitectures.DAMOYOLO

    def _load_json(self, meta_arch, nms_iou_th=DEFAULT_IOU_TH):
        """
        Loads the NMS config from the config file, validating it with the yolov8 schema.

        When the IoU threshold isn't enforced and the input division factor isn't 1, the IoU threshold of
        the loaded config is set to 1.0.

        Args:
            meta_arch: the meta architecture the config is loaded for.
            nms_iou_th: the IoU threshold passed to the config loading.
        """
        # damoyolo uses the same schema as yolov8
        yolov8_schema = f"./nms_{NMSMetaArchitectures.YOLOV8.value}_config.schema.json"
        self._config = NMSConfig.from_json(self._config_file, meta_arch, nms_iou_th, schema_filename=yolov8_schema)

        if self._enforce_iou_threshold or self._config.input_division_factor == 1:
            return
        logger.debug("IOU threshold set to 1 in nms postprocess")
        self._config.nms_iou_th = 1.0


class NMSSSDPostProcess(CPUPostProcess):
    def __init__(
        self,
        hn,
        weights,
        config_file,
        engine=PostprocessTarget.NN_CORE,
        enforce_iou_threshold=True,
        bbox_decoding_only=False,
    ):
        """Initializes the SSD postprocess, the NN core is the default engine and the inputs are reg and cls."""
        super().__init__(
            hn,
            weights,
            config_file,
            engine=engine,
            enforce_iou_threshold=enforce_iou_threshold,
            bbox_decoding_only=bbox_decoding_only,
        )
        # filled with the regression convs when the bbox decoder layers are added
        self._multiple_activation_layers = []
        self.meta_arch = NMSMetaArchitectures.SSD
        self.inputs_keys = [BBoxDecodersInfo.REG_LAYER, BBoxDecodersInfo.CLS_LAYER]

    def _divide_weights(self, activation, weights):
        """
        Returns a copy of the given weights, divided by the scale factor that matches the activation.

        The kernel and the bias are divided by the centers scale factor for a linear activation, and by the
        bbox dimensions scale factor for an exp activation. Any other activation leaves the values as they are.
        The division creates new arrays, so the arrays held by ``weights`` are not modified.

        Args:
            activation: the activation of the layer the weights belong to.
            weights: mapping that holds the "kernel:0" and "bias:0" arrays.

        Returns:
            InnerParams that holds the divided weights.
        """
        if activation == ActivationType.linear:
            scale_factor = self._config.centers_scale_factor
        elif activation == ActivationType.exp:
            scale_factor = self._config.bbox_dimensions_scale_factor
        else:
            scale_factor = None

        # dict() is a shallow copy only, an in-place division would change the arrays of the given weights too
        new_weights = dict(weights)
        if scale_factor is not None:
            for weights_key in ("kernel:0", "bias:0"):
                new_weights[weights_key] = new_weights[weights_key] / scale_factor
        return InnerParams(new_weights)

    def _split_multiple_activation_layers(self):
        """
        Go over the layers collected in ``self._multiple_activation_layers`` and split each of them to one
        layer per activation in MULTIPLE_ACTIVATION_OPS.

        The result of the function is a modified HN with the new singular-activation layers instead of the
        multiple activation layer, and the new reshaped weights (if weights are given). If a multiple
        activation layer is followed by a layer that is not a bbox decoder layer or a concat layer, a new
        concat layer is added.
        """
        for hn_layer in self._multiple_activation_layers:
            for output_layer_name in hn_layer.outputs:
                output_layer = self._hn.get_layer_by_name(output_layer_name)
                if output_layer.op not in [LayerType.bbox_decoder, LayerType.concat]:
                    logger.debug(f"Output layer of {hn_layer.name} is not bbox_decoder/concat. Adding concat_layer")
                    self._create_new_concat_layer(hn_layer)

            # the weights of the original layer are read once, before the layer is split
            if self._weights:
                weights_layer = self._weights[hn_layer.name]

            features = hn_layer.output_shape[-1] // MULTIPLE_ACTIVATION_FEATURES_RESHAPE_FACTOR
            feature_reshape = [features, MULTIPLE_ACTIVATION_FEATURES_RESHAPE_FACTOR]

            new_layers = []
            layers_to_remove = set()
            for index, activation in enumerate(MULTIPLE_ACTIVATION_OPS):
                new_layer_name = NEW_LAYER_NAME_FORMAT.format(hn_layer.name, activation.value, index)

                # a set is used, so the original layer is listed for removal once
                layers_to_remove.add(hn_layer)
                new_layer = self._create_new_layer(
                    activation,
                    MULTIPLE_ACTIVATION_FEATURES_SLICES[index],
                    feature_reshape,
                    hn_layer,
                    new_layer_name,
                )

                self._hn.add_node(new_layer)
                self._hn.adjust_new_layer_input_output(hn_layer, new_layer)
                new_layers.append(new_layer)

                if self._weights:
                    new_weights = self._adjust_weights(index, feature_reshape, weights_layer)
                    new_weights = self._divide_weights(activation, new_weights)
                    self._weights.add(new_layer.name, new_weights)
                    # NOTE(review): the removal runs once per activation, so from the second activation on
                    # the original layer's weights were already removed - confirm that remove() tolerates it
                    self._weights.remove(hn_layer.name)

                logger.debug(f"Created new layer: {new_layer.name}")

            # the original layer is removed from the HN only after all of its replacements were added
            for layer in layers_to_remove:
                self._hn.remove_layer(layer)

    def _add_bbox_decoder_layers(self):
        """
        Adds a bbox decoder layer after the regression conv of every bbox decoder in the config.

        The regression convs are also collected in ``self._multiple_activation_layers``.
        """
        for decoder_info in self._config.bbox_decoders_info:
            reg_conv = self._hn.get_layer_by_name(decoder_info[BBoxDecodersInfo.REG_LAYER.value])
            self._multiple_activation_layers.append(reg_conv)

            decoder_layer = BboxDecoderLayer()
            decoder_layer.name = decoder_info[BBoxDecodersInfo.NAME.value]
            decoder_layer.input_shapes = reg_conv.output_shapes
            self._hn.push_layer(decoder_layer, preds=[reg_conv], calc_shapes=False)

            logger.debug(f"Added BBox Decoder layer {decoder_layer.name}")

    def _get_bbox_decoders_list(self):
        """
        Builds the SSD bbox decoder params out of the anchors in the config, sorted by the anchor names.

        Returns:
            A list with a BBoxDecoderParamsSSD per anchor name, created with the first two rows of the
            transposed anchor values.
        """
        anchors = self._config.anchors
        # transposes the anchor values of each name, so each row holds one component of all the anchors
        transposed_anchors = (
            (anchor_name, np.array(list(zip(*anchors[anchor_name])), dtype=np.float64))
            for anchor_name in sorted(anchors)
        )
        return [
            BBoxDecoderParamsSSD(anchor_name, anchor_values[0], anchor_values[1])
            for anchor_name, anchor_values in transposed_anchors
        ]

    def prepare_hn_and_weights(self, hw_consts, engine=PostprocessTarget.CPU, **kwargs):
        """
        Makes sure the class prediction layers end with a sigmoid, then continues with the parent's flow.

        Args:
            hw_consts: forwarded to the parent class.
            engine: the postprocess target, forwarded to the parent class.
            **kwargs: forwarded to the parent class.
        """
        # an output activation of a class prediction layer that isn't a sigmoid is changed to one
        self._update_output_activations_to_sigmoid()
        super().prepare_hn_and_weights(hw_consts, engine, **kwargs)

    def _update_output_activations_to_sigmoid(self):
        """
        Sets a sigmoid activation on the configured class prediction convs that don't have one.

        A conv is accepted as a class prediction layer only when its amount of output features equals
        #anchors * #classes. The name of every changed layer is appended to ``self.sigmoid_layers``.

        Raises:
            NMSConfigPostprocessException: when a configured conv without a sigmoid has a different amount
                of output features.
        """
        cls_layers_to_num_anchors = {}
        for bbox in self._config_file["bbox_decoders"]:
            configured_layer = self._hn.get_layer_by_name(bbox[BBoxDecodersInfo.CLS_LAYER.value])
            cls_layers_to_num_anchors[configured_layer] = len(bbox[BBoxDecodersInfo.H.value])

        for cls_layer, num_of_anchors in cls_layers_to_num_anchors.items():
            num_of_classes = self._config_file["classes"]
            if not (cls_layer.op == LayerType.conv and cls_layer.activation != ActivationType.sigmoid):
                # only convs that don't end with a sigmoid are changed
                continue

            if cls_layer.output_shape[-1] != num_of_anchors * num_of_classes:
                raise NMSConfigPostprocessException(
                    "Failed to identify class predictions layers, please check your NMS config json",
                )

            if cls_layer.activation:
                info_msg = f"The activation function of layer {cls_layer.name} was replaced by a Sigmoid"
            else:
                info_msg = f"A Sigmoid activation was added to layer {cls_layer.name}"
            logger.info(info_msg)
            cls_layer.activation = ActivationType.sigmoid
            self.sigmoid_layers.append(cls_layer.name)


class NMSCenternetPostProcess(NMSPostProcess):
    """NMS post-process builder for the CenterNet meta-architecture."""

    def __init__(
        self,
        hn,
        weights,
        config_file,
        engine=PostprocessTarget.NN_CORE,
        enforce_iou_threshold=True,
        bbox_decoding_only=False,
    ):
        """Initialise the base post-process, then record the CenterNet-specific settings."""
        super().__init__(hn, weights, config_file, engine, enforce_iou_threshold, bbox_decoding_only)
        self._cls_activation = ActivationType.relu
        self.meta_arch = NMSMetaArchitectures.CENTERNET
        self.inputs_keys = [
            BBoxDecodersInfo.REG_LAYER_H,
            BBoxDecodersInfo.REG_LAYER_W,
            BBoxDecodersInfo.CLS_LAYER,
        ]

    def _add_bbox_decoder_layers(self):
        """Push one ``BboxDecoderLayer`` after each pair of regression convs listed in the config.

        Before the decoder is added, output channels 0 and 1 of both regression convs are
        exchanged in place, in the kernel (last axis) and in the bias.
        """
        for decoder_info in self._config.bbox_decoders_info:
            reg_h = self._hn.get_layer_by_name(decoder_info[BBoxDecodersInfo.REG_LAYER_H.value])
            reg_w = self._hn.get_layer_by_name(decoder_info[BBoxDecodersInfo.REG_LAYER_W.value])

            # Exchange channels 0 and 1: dx <--> dy for the first conv, h <--> w for the second.
            # The fancy-indexed right-hand side is a copy, so the in-place assignment is a true swap.
            for weights_name in (reg_h.name, reg_w.name):
                kernel = self._weights[weights_name]["kernel:0"]
                kernel[:, :, :, [0, 1]] = kernel[:, :, :, [1, 0]]
                bias = self._weights[weights_name]["bias:0"]
                bias[[0, 1]] = bias[[1, 0]]

            decoder = BboxDecoderLayer()
            decoder.name = decoder_info[BBoxDecodersInfo.NAME.value]
            decoder.input_shapes = reg_h.output_shapes
            decoder.output_shapes = reg_h.output_shapes
            # The decoder's feature count is the sum of both regression convs' features.
            decoder.output_shapes[0][-1] += reg_w.output_shapes[0][-1]

            self._hn.push_layer(decoder, preds=[reg_h, reg_w], calc_shapes=False)
            logger.debug(f"Added BBox Decoder layer {decoder.name}")

    def _get_bbox_decoders_list(self):
        """Build one ``BBoxDecoderParamsCenternet`` per configured anchor, in sorted name order."""
        anchors_by_name = self._config.anchors

        def _decoder_params(name):
            # Transpose the anchor entry so that row 0 / row 1 can be picked out directly.
            rows = np.array(list(zip(*anchors_by_name[name])), dtype=np.float64)
            return BBoxDecoderParamsCenternet(name, rows[0], rows[1])

        return [_decoder_params(name) for name in sorted(anchors_by_name)]


class NMSYOLOV5PostProcess(CPUPostProcess):
    """NMS post-process builder for the YOLOv5 meta-architecture.

    ``prepare_hn_and_weights`` either delegates to the parent class (CPU engine) or
    inserts feature-multiplier, bbox-decoder and proposal-generator layers into the HN,
    followed by the concat / NMS layers created through the inherited helpers.
    """

    def __init__(
        self,
        hn,
        weights,
        config_file,
        engine=PostprocessTarget.NN_CORE,
        enforce_iou_threshold=True,
        bbox_decoding_only=False,
    ):
        """Initialise the base post-process, then record the YOLOv5-specific settings."""
        super().__init__(hn, weights, config_file, engine, enforce_iou_threshold, bbox_decoding_only)
        # Filled by _add_feature_multiplier_layer with the encoded conv layers it processed.
        self._encoded_layers = []
        self.meta_arch = NMSMetaArchitectures.YOLOV5
        self.inputs_keys = [BBoxDecodersInfo.ENCODED_LAYER]

    def _add_feature_multiplier_layer(self):
        """Push a ``FeatureMultiplierLayer`` (yolov5 type) after every encoded conv layer.

        Each new layer gets two output shapes and a power table stored in the weights
        under ``power_table:0``.
        """
        encoded_layers = [bbox[BBoxDecodersInfo.ENCODED_LAYER.value] for bbox in self._config.bbox_decoders_info]
        for conv_layer_name in encoded_layers:
            conv_layer = self._hn.get_layer_by_name(conv_layer_name)
            input_features = conv_layer.output_shapes[0][-1]
            self._encoded_layers.append(conv_layer)
            # The anchor count is taken from the first entry of the anchors mapping.
            num_of_anchors = len(next(iter(self._config.anchors.values())))
            # presumably one objectness value per anchor — TODO confirm
            num_of_objs = num_of_anchors

            classes = self._config.classes

            # The new layer is named after the conv, with "conv" replaced by "feature_multiplier".
            feature_multiplier_layer_name = conv_layer_name.replace("conv", "feature_multiplier")

            feature_multiplier_layer = FeatureMultiplierLayer()
            feature_multiplier_layer.name = feature_multiplier_layer_name
            feature_multiplier_layer.input_shapes = conv_layer.output_shapes
            # First output: same leading dims as the conv, with anchors * BBOX_PER_CHUNK features.
            feature_multiplier_layer.output_shapes = [
                conv_layer.output_shapes[0][0:-1] + [num_of_anchors * NMSLayer.BBOX_PER_CHUNK],
            ]
            # Second output: the conv's features minus the first output's features and num_of_objs.
            feature_multiplier_layer.output_shapes.append(
                conv_layer.output_shapes[0][0:-1]
                + [input_features - num_of_anchors * NMSLayer.BBOX_PER_CHUNK - num_of_objs],
            )
            feature_multiplier_layer.feature_multiplier_type = FeatureMultiplierType.yolov5
            yolo_power_table = feature_multiplier_layer.init_power_table(
                feature_multiplier_layer.yolov5(num_of_anchors, classes),
            )
            power_table = {"power_table:0": yolo_power_table}

            self._hn.push_layer(feature_multiplier_layer, preds=[conv_layer], calc_shapes=False)
            self._weights.add(feature_multiplier_layer.name, power_table)

            logger.debug(f"Added feature multiplier layer {feature_multiplier_layer_name}")

    def _add_bbox_decoder_layers(self):
        """Push one ``BboxDecoderLayer`` after the first output of every encoded conv layer."""
        for bbox in self._config.bbox_decoders_info:
            conv_layer_decoded = self._hn.get_layer_by_name(bbox[BBoxDecodersInfo.ENCODED_LAYER.value])
            # assumes the conv's first output is the feature multiplier pushed by
            # _add_feature_multiplier_layer — TODO confirm
            feature_multiplier_layer_name = conv_layer_decoded.outputs[0]
            feature_multiplier_layer = self._hn.get_layer_by_name(feature_multiplier_layer_name)
            bbox_decoder_layer = BboxDecoderLayer()
            bbox_decoder_layer.name = bbox[BBoxDecodersInfo.NAME.value]
            # The decoder consumes only the first output of the feature multiplier and keeps its shape.
            bbox_decoder_layer.input_shapes = [feature_multiplier_layer.output_shapes[0]]
            bbox_decoder_layer.output_shapes = bbox_decoder_layer.input_shapes

            self._hn.push_layer(bbox_decoder_layer, preds=[feature_multiplier_layer], calc_shapes=False)
            logger.debug(f"Added BBox Decoder layer {bbox_decoder_layer.name}")

    def _add_proposal_generator_layers(self):
        """Add a ``ProposalGeneratorLayer`` per bbox decoder, then the concat and NMS layers.

        Every proposal generator has two inputs: the bbox decoder and the layer found as the
        first output of the matching encoded conv. The generators are reduced through
        ``_create_concats`` until a single layer remains, which is passed to ``_create_nms_layer``.
        """
        bbox_conv_pairs = [
            (bbox[BBoxDecodersInfo.NAME.value], bbox[BBoxDecodersInfo.ENCODED_LAYER.value])
            for bbox in self._config.bbox_decoders_info
        ]
        # Replace each conv name with the name of that conv's first output layer.
        bbox_mult_split_pairs = [
            (bbox_conv_pair[0], self._hn.get_layer_by_name(bbox_conv_pair[1]).outputs[0])
            for bbox_conv_pair in bbox_conv_pairs
        ]
        proposal_generators = []

        for i, (bbox_decoder_layer_name, mult_split_layer_name) in enumerate(bbox_mult_split_pairs):
            mult_split_layer = self._hn.get_layer_by_name(mult_split_layer_name)
            bbox_decoder_layer = self._hn.get_layer_by_name(bbox_decoder_layer_name)
            proposal_generator_layer_name = f"proposal_generator{i}"
            proposal_generator_layer = ProposalGeneratorLayer()
            proposal_generator_layer.name = proposal_generator_layer_name
            # The bbox decoder is wired through push_layer; the second input is added as an explicit edge.
            self._hn.push_layer(proposal_generator_layer, preds=[bbox_decoder_layer], calc_shapes=False)
            self._hn.add_edge(mult_split_layer, proposal_generator_layer)
            proposal_generator_layer.input_shapes = [
                bbox_decoder_layer.output_shape,
                mult_split_layer.output_shapes[-1],
            ]
            proposal_generator_layer.output_shapes = [
                bbox_decoder_layer.output_shape,
                mult_split_layer.output_shapes[-1],
            ]
            # Register the second input on both sides of the connection by hand.
            proposal_generator_layer.input_indices.append(mult_split_layer.index)
            proposal_generator_layer.inputs.append(mult_split_layer.name)
            mult_split_layer.outputs.append(proposal_generator_layer.name)
            mult_split_layer.output_indices = [bbox_decoder_layer.index, proposal_generator_layer.index]
            proposal_generator_layer.input_division_factor = self._config.input_division_factor
            proposal_generator_layer.append_to_input_list(bbox_decoder_layer)
            proposal_generator_layer.append_to_input_list(mult_split_layer)
            proposal_generator_layer.update_output_shapes()
            logger.debug(f"Created layer {proposal_generator_layer_name}")
            proposal_generators.append(proposal_generator_layer_name)

        # Keep concatenating until exactly one layer is left to feed the NMS layer.
        new_concat_layers = self._create_concats(proposal_generators)
        while len(new_concat_layers) != 1:
            new_concat_layers = self._create_concats(new_concat_layers)
        self._create_nms_layer(new_concat_layers[0])

    def _get_bbox_decoders_list(self):
        """Build one ``BBoxDecoderParamsYOLOv5`` per configured anchor, in sorted name order.

        Each anchor entry is transposed with ``zip(*entry)`` into a float64 array; rows 0 and 1
        are passed to the params object together with the anchor's stride.
        """
        anchors = self._config.anchors
        anchors_strides = self._config.anchors_stride
        bbox_list = []
        for anchor_name in sorted(anchors):
            w = np.array(list(zip(*anchors[anchor_name])), dtype=np.float64)
            anchors_stride = anchors_strides[anchor_name]
            bbox_list.append(BBoxDecoderParamsYOLOv5(anchor_name, w[0], w[1], anchors_stride))
        return bbox_list

    def prepare_hn_and_weights(self, hw_consts, engine=PostprocessTarget.NN_CORE, **kwargs):
        """Prepare the HN and weights for YOLOv5 NMS on the requested engine.

        The layers feeding the HN output layers are first passed to
        ``_update_output_activations_to_sigmoid``. The CPU engine then delegates to the parent
        class; any other engine builds the post-process layers directly in the HN.

        Args:
            hw_consts: hardware constants; ``"NMS_ON_NN_CORE_MAX_SIZE"`` is read in the non-CPU flow.
            engine: post-process target selecting the flow.
            **kwargs: forwarded to the parent flow or to ``_load_and_validate_config_data``.
        """
        # NOTE(review): this line reads self.hn while the rest of the class uses self._hn —
        # presumably both refer to the same graph; confirm.
        layers_to_sigmoid = [self._hn.get_layer_by_name(layer.inputs[0]) for layer in self.hn.get_output_layers()]
        self._update_output_activations_to_sigmoid(layers_to_sigmoid)
        if engine == PostprocessTarget.CPU:
            # not-neural-core flow
            super().prepare_hn_and_weights(hw_consts, engine, **kwargs)
        else:
            # on-chip/ hybrid flow
            self._load_and_validate_config_data(**kwargs)
            # verify the number of bbox decoders is supported when running on the NN core
            self._verify_nn_core_bbox_num(engine)
            self._add_feature_multiplier_layer()
            self._add_bbox_decoder_layers()
            self._add_proposal_generator_layers()
            self._hn.calculate_shapes()
            self._add_bbox_decoder_weights()
            self._verify_input_division_factor_value(hw_consts["NMS_ON_NN_CORE_MAX_SIZE"])
            if engine == PostprocessTarget.AUTO:
                self.add_iou_postprocess_layer_to_hn()

    def _verify_nn_core_bbox_num(self, engine):
        """Reject NN_CORE runs that have more bbox decoders than ``YOLOV5_NN_CORE_SUPPORTED_BRANCHES``.

        Raises:
            YOLOv5PostprocessException: if ``engine`` is NN_CORE and the limit is exceeded.
        """
        bbox_num = len(self._config.bbox_decoders_info)
        if engine == PostprocessTarget.NN_CORE and bbox_num > YOLOV5_NN_CORE_SUPPORTED_BRANCHES:
            raise YOLOv5PostprocessException(
                f"Unsupported number of bbox decoders {bbox_num}. Please run the NMS "
                "postprocess on the CPU instead by using nms_postprocess(..., engine=cpu)",
            )


class NMSYOLOXPostProcess(CPUPostProcess):
    """NMS post-process builder for the YOLOX meta-architecture."""

    def __init__(
        self,
        hn,
        weights,
        config_file,
        engine=PostprocessTarget.CPU,
        enforce_iou_threshold=True,
        bbox_decoding_only=False,
    ):
        """Initialise the base post-process, then record the YOLOX-specific settings."""
        super().__init__(hn, weights, config_file, engine, enforce_iou_threshold, bbox_decoding_only)
        self._encoded_layers = []
        self.meta_arch = NMSMetaArchitectures.YOLOX
        self.inputs_keys = [
            BBoxDecodersInfo.REG_LAYER,
            BBoxDecodersInfo.OBJ_LAYER,
            BBoxDecodersInfo.CLS_LAYER,
        ]

    def prepare_hn_and_weights(self, hw_consts, engine=PostprocessTarget.CPU, **kwargs):
        """Collect the objectness and class layers, then run the parent preparation flow.

        The layers are gathered in config order (decoder by decoder, key by key) and handed
        to the parent through the ``layers_to_sigmoid`` keyword argument, overriding any value
        the caller supplied under that name.
        """
        sigmoid_keys = [BBoxDecodersInfo.OBJ_LAYER.value, BBoxDecodersInfo.CLS_LAYER.value]
        sigmoid_layers = []
        for decoder in self._config_file["bbox_decoders"]:
            for key in decoder:
                if key in sigmoid_keys:
                    sigmoid_layers.append(self._hn.get_layer_by_name(decoder[key]))
        kwargs["layers_to_sigmoid"] = sigmoid_layers
        super().prepare_hn_and_weights(hw_consts, engine, **kwargs)


class NMSYOLOV5SegPostProcess(CPUPostProcess):
    """NMS post-process builder for the YOLOv5-seg meta-architecture.

    Only the constructor is specialised; everything else comes from ``CPUPostProcess``.
    """

    def __init__(
        self,
        hn,
        weights,
        config_file,
        engine=PostprocessTarget.CPU,
        enforce_iou_threshold=True,
        bbox_decoding_only=False,
    ):
        """Initialise the base post-process, then record the YOLOv5-seg-specific settings."""
        super().__init__(hn, weights, config_file, engine, enforce_iou_threshold, bbox_decoding_only)
        self.meta_arch = NMSMetaArchitectures.YOLOV5_SEG
        self.inputs_keys = [BBoxDecodersInfo.ENCODED_LAYER]
        self._encoded_layers = []


class NMSYOLOV6PostProcess(NMSPostProcess):
    """NMS post-process builder for the YOLOv6 meta-architecture."""

    def __init__(
        self,
        hn,
        weights,
        config_file,
        engine=PostprocessTarget.NN_CORE,
        enforce_iou_threshold=True,
        bbox_decoding_only=False,
    ):
        """Initialise the base post-process, then record the YOLOv6-specific settings."""
        super().__init__(hn, weights, config_file, engine, enforce_iou_threshold, bbox_decoding_only)
        self.meta_arch = NMSMetaArchitectures.YOLOV6
        self.inputs_keys = [BBoxDecodersInfo.REG_LAYER, BBoxDecodersInfo.CLS_LAYER]

    def _add_fused_bbox_decoder_layers(self):
        """Add a ``FusedBboxDecoderLayer`` per bbox decoder, then the concat and NMS layers.

        Each fused decoder takes the regression layer and the class layer as inputs. If the
        class layer has exactly one output-layer successor, that output layer is removed from
        the HN. The fused decoders are reduced through ``_create_concats`` until a single layer
        remains, which is passed to ``_create_nms_layer``.
        """
        # Resolve all config names up front, before the graph is touched.
        decoder_specs = [
            (
                info[BBoxDecodersInfo.NAME.value],
                info[BBoxDecodersInfo.REG_LAYER.value],
                info[BBoxDecodersInfo.CLS_LAYER.value],
            )
            for info in self._config.bbox_decoders_info
        ]
        fused_names = []

        for decoder_name, reg_name, cls_name in decoder_specs:
            cls_layer = self._hn.get_layer_by_name(cls_name)
            reg_layer = self._hn.get_layer_by_name(reg_name)
            fused = FusedBboxDecoderLayer()
            fused.name = decoder_name

            # The regression layer is wired through push_layer; the class layer gets an explicit edge.
            self._hn.push_layer(fused, preds=[reg_layer], calc_shapes=False)
            self._hn.add_edge(cls_layer, fused)

            fused.input_shapes = [reg_layer.output_shape, cls_layer.output_shapes[-1]]
            fused.output_shapes = [reg_layer.output_shape, cls_layer.output_shapes[-1]]
            fused.input_indices.append(cls_layer.index)
            fused.inputs.append(cls_layer.name)
            fused.append_to_input_list(reg_layer)
            fused.append_to_input_list(cls_layer)
            fused.update_output_shapes()
            fused_names.append(decoder_name)

            cls_layer.outputs.append(fused.name)
            cls_layer.output_indices = [fused.index, fused.index]

            # removes nn_core output layers (convs)
            output_successors = [
                successor
                for successor in list(self._hn.successors(cls_layer))
                if successor.op == LayerType.output_layer
            ]
            if len(output_successors) == 1:
                (stale_output,) = output_successors
                cls_layer.outputs.remove(stale_output.name)
                self._hn.net_params.output_layers_order.remove(cls_layer.name)
                self._hn.remove_edge(cls_layer, stale_output)
                self._hn.remove_node(stale_output)
            logger.debug(f"Added BBox Decoder layer {fused.name}")

        # Keep concatenating until exactly one layer is left to feed the NMS layer.
        concat_names = self._create_concats(fused_names)
        while len(concat_names) != 1:
            concat_names = self._create_concats(concat_names)
        self._create_nms_layer(concat_names[0])

    def _get_bbox_decoders_list(self):
        """Build one ``BBoxDecoderParamsYOLOv6`` per configured anchor, in the anchors' own order."""
        strides = self._config.anchors_stride
        decoders = []
        for name in self._config.anchors:
            decoders.append(BBoxDecoderParamsYOLOv6(name, strides[name]))
        return decoders

    def prepare_hn_and_weights(self, hw_consts, engine=PostprocessTarget.NN_CORE, **kwargs):
        """Load the config, force sigmoid on the class layers and build the fused decoder graph.

        ``hw_consts`` and ``engine`` are accepted for interface compatibility but are not
        read by this implementation; ``kwargs`` go to ``_load_and_validate_config_data``.
        """
        self._load_and_validate_config_data(**kwargs)
        self._update_output_activations_to_sigmoid()
        self._add_fused_bbox_decoder_layers()
        self._hn.calculate_shapes()
        self._add_bbox_decoder_weights()

    def _update_output_activations_to_sigmoid(self):
        """Set a sigmoid activation on every class conv layer that does not already use one.

        The name of each changed layer is appended to ``self.sigmoid_layers``.
        """
        cls_names = [info["cls_layer"] for info in self._config.bbox_decoders_info]
        for name in cls_names:
            layer = self._hn.get_layer_by_name(name)
            # Only conv layers that are not already sigmoid need any handling.
            if not (layer.op == LayerType.conv and layer.activation != ActivationType.sigmoid):
                continue
            info_msg = (
                f"The activation function of layer {layer.name} was replaced by a Sigmoid"
                if layer.activation
                else f"A Sigmoid activation was added to layer {layer.name}"
            )
            logger.info(info_msg)
            layer.activation = ActivationType.sigmoid
            self.sigmoid_layers.append(layer.name)


# Registry pairing every NMS post-process class with the meta architecture it handles.
# create_nms_postprocess searches it by value to find the class for a requested meta architecture.
PP_CLASS_TO_META_ARCH = {
    NMSSSDPostProcess: NMSMetaArchitectures.SSD,
    NMSCenternetPostProcess: NMSMetaArchitectures.CENTERNET,
    NMSYOLOV5PostProcess: NMSMetaArchitectures.YOLOV5,
    NMSYOLOXPostProcess: NMSMetaArchitectures.YOLOX,
    NMSYOLOV5SegPostProcess: NMSMetaArchitectures.YOLOV5_SEG,
    NMSYOLOV6PostProcess: NMSMetaArchitectures.YOLOV6,
    NMSYOLOV8PostProcess: NMSMetaArchitectures.YOLOV8,
    NMSDAMOYOLOPostProcess: NMSMetaArchitectures.DAMOYOLO,
}


def create_nms_postprocess(
    hn,
    weights,
    config_file,
    engine,
    hw_consts,
    meta_arch,
    enforce_iou_threshold=True,
    bbox_decoding_only=False,
    dfl_on_nn_core=False,
    output_original_name=None,
):
    """Create the post-process object registered for ``meta_arch`` and prepare its HN and weights.

    The first class in ``PP_CLASS_TO_META_ARCH`` whose meta architecture equals ``meta_arch``
    is instantiated, its ``prepare_hn_and_weights`` is run with ``hw_consts``, ``engine``,
    ``dfl_on_nn_core`` and ``output_original_name``, and the instance is returned.

    Raises:
        UnsupportedMetaArchError: if no class is registered for ``meta_arch``.
    """
    postprocess_cls = next(
        (candidate for candidate, arch in PP_CLASS_TO_META_ARCH.items() if arch == meta_arch),
        None,
    )
    if postprocess_cls is None:
        raise UnsupportedMetaArchError(f"A postprocess class wasn't found for {meta_arch.value} meta architecture.")

    creator = postprocess_cls(hn, weights, config_file, engine, enforce_iou_threshold, bbox_decoding_only)
    creator.prepare_hn_and_weights(
        hw_consts,
        engine,
        dfl_on_nn_core=dfl_on_nn_core,
        output_original_name=output_original_name,
    )
    return creator
