import copy
from abc import ABC, abstractmethod
from collections import OrderedDict

from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerPrecisionConfig,
    LayerTranslationConfig,
)
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    BiasMode,
    LayerHandlerType,
    PostprocessTarget,
    PrecisionMode,
    QuantizationAlgorithms,
    QuantizationDeprecatedParam,
)
from hailo_sdk_common.compatibility import integer_types
from hailo_sdk_common.hailo_nn.exceptions import HailoNNLayerParamsException, ProtobufExportError, UnsupportedModelError
from hailo_sdk_common.hailo_nn.hn_definitions import DefuseType, LayerType
from hailo_sdk_common.hailo_nn.hn_layers.layer_common import (
    INDEX_NOT_SET,
    MODEL_SCRIPT_LAYER_PREFIX,
    get_act_short_description,
    get_groups_short_description,
)
from hailo_sdk_common.hailo_nn.hn_layers_params import CompilationParams, DefuseParams, QuantizationParams
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification
from hailo_sdk_common.logger.logger import default_logger
from hailo_sdk_common.tools.models_translator_helper import valid_orig_name


class Layer(ABC):
    # Require weights in order to make the weights assertions more strict
    _REQUIRES_NATIVE_WEIGHTS = None
    _REQUIRES_QUANTIZED_WEIGHTS = None
    _IS_REAL_LAYER = False
    _IS_RANK3_SUPPORTED = False
    next_insertion_order = 0
    ORIGINAL_NAME_LIMIT = 1000
    ORIGINAL_NAME_LIST_LIMIT = 10

    def __init__(self):
        """Create a layer with empty connectivity, no shapes and default parameters.

        The op (``_op``) and name are left as None here; presumably concrete layer
        classes / model parsers assign them afterwards.
        """
        self._logger = default_logger()
        self.index = INDEX_NOT_SET
        self._name = None
        self._original_names = []
        # Shapes are kept as lists of shape lists, one entry per input/output.
        self._input_shapes = []
        self._output_shapes = []
        # Neighbouring layers and their graph indices.
        self._input_layers = []
        self._output_layers = []
        self._input_indices = []
        self._output_indices = []
        self._fused_indices = []
        # Ordered input names of the original (source-framework) vertex, and the input
        # layers recorded by their position in that order (see add_input_by_vertex).
        self._input_vertex_order = []
        self._inputs_by_vertex_order = {}
        # Upper bound checked by set_input_shapes validation (a falsy value disables it).
        self._number_of_inputs_supported = 1
        self._op = None
        self._is_artificial = False
        self._output_copies = 1
        self._precision_config = LayerPrecisionConfig()
        self._translation_config = LayerTranslationConfig()
        self._compilation_params = CompilationParams()
        self.set_default_compilation_params()
        self._defuse_params = DefuseParams()
        self.set_default_defuse_params()
        # A class-wide counter gives each constructed instance its own value, used both
        # as its hash and as its insertion (creation) order; bump it for the next layer.
        self._hash = Layer.next_insertion_order
        self._insertion_order = Layer.next_insertion_order
        self._dynamic_weights = False
        Layer.next_insertion_order += 1
        # Filled in by the ``name`` setter from a "scope/name" style name.
        self._name_without_scope = None
        self._scope = None
        self._transposed = False
        self._engine = PostprocessTarget.NN_CORE
        self._group_sizes = None
        self._in_emulation_graph = True
        self._block_info = None
        self._base_layer = None

    @property
    def precision_config(self) -> LayerPrecisionConfig:
        return self._precision_config

    @property
    def translation_config(self) -> LayerTranslationConfig:
        return self._translation_config

    @classmethod
    def create(cls, original_name, input_vertex_order, output_shapes=None):
        """Build a new layer of this class for a parsed graph node.

        Args:
            original_name: Name of the source node; recorded as an original name.
            input_vertex_order: Ordered input names of the source node.
            output_shapes: Optional output shape(s). Each rank-3 shape [b, w, f] is
                expanded in place to [b, 1, w, f] when the class supports rank-3 shapes.

        Raises:
            UnsupportedModelError: If a rank-3 output shape is given but this class does
                not support rank-3 shapes.
        """
        new_layer = cls()
        new_layer.add_original_name(original_name)
        new_layer.input_vertex_order = input_vertex_order
        if not output_shapes:
            return new_layer

        new_layer.output_shapes = output_shapes
        for shape in new_layer.output_shapes:
            if len(shape) != 3:
                continue
            if not cls._IS_RANK3_SUPPORTED:
                raise UnsupportedModelError(
                    f"1D form is not supported in layer {original_name} of type {cls.__name__}.",
                )
            # Treat the shape as an image with a single row by inserting a unit height.
            shape.insert(1, 1)
            new_layer._logger.debug(
                f"The output shape of node {original_name} has been changed from {[shape[0], shape[2], shape[3]]} to {shape}, "
                "assuming an image with a single row.",
            )
        return new_layer

    def _get_shape_info(self):
        output = ""
        if self.output_shapes is not None:
            output += f", OutShapes={self.output_shapes!s}"
        if self.input_shapes is not None:
            output += f", InShapes={self.input_shapes!s}"
        return output

    def __str__(self):
        """Return a detailed one-line description of the layer for logs and errors.

        Optional attributes that only some layer subclasses define (kernel shape,
        strides, padding, dilations, resize method, pixels mode, engine) are appended
        only when present and not None.
        """
        orig_name = self._original_names[-1] if self._original_names else None
        # A bare Layer has _op = None; this string is embedded in error messages
        # (e.g. _verify_exportable), so it must not raise on its own.
        op_value = self._op.value if self._op is not None else None
        description = (
            f"{self.name}(index={self.index}, hash={hash(self):d}, type={type(self).__name__}, "
            f"original_name={orig_name}, op={op_value}){self._get_shape_info()}"
        )
        if getattr(self, "_kernel_shape", None) is not None:
            description += f", KernelShape={self._kernel_shape}"
        if getattr(self, "_strides", None) is not None:
            description += f", Strides={self._strides}"
        if getattr(self, "_padding", None) is not None:
            description += f", Padding={self._padding.value.lower()}"
        if getattr(self, "_dilations", None) is not None:
            description += f", Dilations={self._dilations}"
        # Guarded against None like the attributes above: ``.value`` on None would make
        # __str__ itself raise AttributeError.
        if getattr(self, "_method", None) is not None:
            description += f", Method={self._method.value}"
        if getattr(self, "_resize_bilinear_pixels_mode", None) is not None:
            description += f", PixelsMode={self._resize_bilinear_pixels_mode.value}"
        fused_model_script_layer_names = self._get_fused_model_script_layer_names()
        if fused_model_script_layer_names:
            description += f", FusedModelScriptLayers=[{','.join(fused_model_script_layer_names)}]"
        if hasattr(self, "groups"):
            description += get_groups_short_description(self)
        if self.bn_enabled:
            description += " +BN"
        if hasattr(self, "activation"):
            description += get_act_short_description(self)
        if getattr(self, "_engine", None) and self._engine != PostprocessTarget.NN_CORE:
            description += f", Engine={self._engine.value}"
        return description

    def _get_fused_model_script_layer_names(self):
        fused_model_script_layer_names = []
        if self._original_names:
            for original_name in self._original_names:
                if original_name.startswith(MODEL_SCRIPT_LAYER_PREFIX):
                    model_script_layer_name = original_name[len(MODEL_SCRIPT_LAYER_PREFIX) + 1 :]
                    if model_script_layer_name != self.name_without_scope:
                        fused_model_script_layer_names.append(model_script_layer_name)

        return fused_model_script_layer_names

    # Ordering hook for nx.algorithms.dag.lexicographical_topological_sort.
    def __lt__(self, other):
        """Order by graph index; fall back to creation order while either index is unset."""
        both_indexed = INDEX_NOT_SET not in (self.index, other.index)
        if both_indexed:
            return self.index < other.index
        return self._insertion_order < other._insertion_order

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return self._hash

    def __deepcopy__(self, memo):
        """Deep-copy the layer, giving the copy a fresh logger.

        Every other instance attribute, including ``_hash`` and ``_insertion_order``,
        is deep-copied onto a freshly constructed instance, so the copy keeps the
        original's hash and insertion order.

        ``memo`` is passed to every nested ``copy.deepcopy`` call. Objects shared
        between attributes stay shared in the copy, and references back to this
        layer resolve to the copy instead of recursing.
        """
        new_layer = type(self)()
        # Register before copying attributes so self-references map to the copy.
        memo[id(self)] = new_layer
        self._logger.debug("Copying Layer...")
        for attr, value in vars(self).items():
            if attr == "_logger":
                new_layer._logger = default_logger()
            else:
                setattr(new_layer, attr, copy.deepcopy(value, memo))
        self._logger.debug("Layer was copied.")
        return new_layer

    def is_real_layer(self):
        """
        "Real layer" is a layer that the user thinks of as an actual layer in her network, not
        including inputs, outputs, Defusing etc. This is used to estimate how much control overhead
        our hardware adds over the "ideal hardware".
        """
        return type(self)._IS_REAL_LAYER

    @property
    def short_description(self):
        description = self.name
        if hasattr(self, "kernel_short_description"):
            description += self.kernel_short_description
        if hasattr(self, "groups"):
            description += get_groups_short_description(self)
        if self.bn_enabled:
            description += " +BN"
        if hasattr(self, "activation"):
            description += get_act_short_description(self)
        if hasattr(self, "_engine") and self._engine and self._engine != PostprocessTarget.NN_CORE:
            description += f" +{self._engine.value}"
        fused_model_script_layer_names = self._get_fused_model_script_layer_names()
        for fused_layer_name in fused_model_script_layer_names:
            description += f" +{fused_layer_name}"

        return description

    @property
    def insertion_order(self):
        # currently, identical to hash, but defined separately in purpose, because insertion
        # order should be implemented this way even if hash implementation will change in the
        # future
        return self._insertion_order

    @property
    def requires_native_weights(self):
        if type(self)._REQUIRES_NATIVE_WEIGHTS is None:
            self._logger.warning(
                f"Layer {self.name} of type {self.op.value} does not specify whether native weights are required. Assuming "
                "False.",
            )
            return False

        return type(self)._REQUIRES_NATIVE_WEIGHTS

    @property
    def requires_quantized_weights(self):
        if type(self)._REQUIRES_QUANTIZED_WEIGHTS is None:
            self._logger.warning(
                f"Layer {self.name} of type {self.op.value} does not specify whether quantized weights"
                " are required. Assuming False.",
            )
            return False

        return type(self)._REQUIRES_QUANTIZED_WEIGHTS

    def to_hn(self, should_get_default_params=False):
        """Serialize the layer into an HN-style ordered dict.

        Args:
            should_get_default_params: Forwarded to the compilation/defuse params
                serializers.

        Returns:
            OrderedDict with the layer's type, shapes, original names and parameters.
            The ``name`` is not included here.

        Raises:
            ProtobufExportError: If the layer lacks an index, name or shapes.
        """
        self._verify_exportable()
        # Key insertion order is part of the produced output; keep it stable.
        result = OrderedDict()
        result["type"] = self._op.value
        # NOTE(review): placeholders; presumably filled in by the caller that knows the graph.
        result["input"] = None
        result["output"] = None
        result["input_shapes"] = self._input_shapes
        # Optional fields are emitted only when they carry information.
        if self._output_shapes:
            result["output_shapes"] = self._output_shapes
        if self._original_names:
            result["original_names"] = self._original_names
        if self._is_artificial:
            result["is_artificial"] = self._is_artificial

        self._compilation_params.to_hn(result, should_get_default_params)
        # TODO: do we want to include translation_config? probably not...
        result["quantization_params"] = self.precision_config.raw_dict()
        self._defuse_params.to_hn(result, should_get_default_params)
        # NN_CORE is the default engine, so it is left implicit.
        if self._engine and self._engine != PostprocessTarget.NN_CORE:
            result["engine"] = self._engine.value
        if not self.in_emulation_graph:
            result["in_emulation_graph"] = self.in_emulation_graph
        if self._base_layer:
            result["base_layer"] = self._base_layer

        return result

    def _serialize_shapes_pb(self, shapes_to_serialize, shapes_pb):
        """Append each shape as a new entry of the repeated protobuf field ``shapes_pb``.

        Rank-2 shapes fill only ``width``. Any other shape must unpack as four
        values (leading dim, height, width, features). The leading dimension is not
        serialized.
        """
        for dims in shapes_to_serialize:
            entry = shapes_pb.add()
            if len(dims) != 2:
                _, entry.height, entry.width, entry.features = dims
            else:
                _, entry.width = dims

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize the layer into a ``ProtoNetworkNode`` protobuf message.

        Args:
            pb_wrapper: Supplies the protobuf module and the engine translation table.
            is_multi_scope: Use the scoped name when True, the unscoped name otherwise.

        Returns:
            The populated ``ProtoNetworkNode``.

        Raises:
            ProtobufExportError: If the layer lacks an index, name or shapes.
        """
        self._verify_exportable()
        node = pb_wrapper.integrated_hw_graph_base_pb2.ProtoNetworkNode()
        node.name = self.name if is_multi_scope else self.name_without_scope
        node.index = self.index
        node.is_artificial = self._is_artificial
        node.input_indices.extend(self.input_indices)
        node.successors_indices.extend(self.output_indices)
        node.original_names.extend(self._prepare_original_names())
        # A missing engine is exported as the default NN core.
        engine = self._engine or PostprocessTarget.NN_CORE
        node.engine = pb_wrapper.ENGINE_TYPE_TO_PB[engine]
        self._serialize_shapes_pb(self.input_shapes, node.input_shapes)
        self._serialize_shapes_pb(self.output_shapes, node.output_shapes)

        # Each parameter group writes its own fields onto the node.
        self._compilation_params.to_pb(node, pb_wrapper)
        quantization = QuantizationParams()
        quantization.set(self.precision_config.raw_dict())
        quantization.to_pb(node, pb_wrapper)
        self._defuse_params.to_pb(node, pb_wrapper)
        return node

    def _prepare_original_names(self):
        if len(self._original_names) > self.ORIGINAL_NAME_LIST_LIMIT:
            self._logger.debug(
                f"Original names list for layer {self.name} is too long "
                f"(>{self.ORIGINAL_NAME_LIST_LIMIT}), Trimming it to {self.ORIGINAL_NAME_LIST_LIMIT}",
            )
        original_names = self._original_names[: self.ORIGINAL_NAME_LIST_LIMIT]
        new_original_names = []
        for name in original_names:
            if len(name) > self.ORIGINAL_NAME_LIMIT:
                self._logger.debug(
                    f'Original name for layer {self.name} is too long (>{self.ORIGINAL_NAME_LIMIT}): "{name}", Trimming it to {self.ORIGINAL_NAME_LIMIT}',
                )
            new_original_names.append(name[: self.ORIGINAL_NAME_LIMIT])
        return new_original_names

    def get_compilation_params(self, key=None):
        """Return compilation param ``key`` (with no key, presumably the whole params mapping)."""
        params = self._compilation_params
        return params.get(key)

    def set_compilation_params(self, **kws):
        """Pass the keyword arguments, as one dict, to the compilation params' ``set``."""
        params = self._compilation_params
        params.set(kws)

    def set_default_compilation_params(self):
        """Apply the compilation params' defaults via ``set_default_params``."""
        params = self._compilation_params
        params.set_default_params()

    def update_compilation_params(self, **kws):
        """Merge ``kws`` into the current compilation params and store the merged result."""
        params = self._compilation_params
        merged = params.get()
        merged.update(kws)
        params.set(merged)

    def get_defuse_params(self, key=None):
        """Return defuse param ``key`` (with no key, presumably the whole params mapping)."""
        params = self._defuse_params
        return params.get(key)

    def set_defuse_params(self, **kws):
        """Pass the keyword arguments, as one dict, to the defuse params' ``set``."""
        params = self._defuse_params
        params.set(kws)

    def set_default_defuse_params(self):
        """Apply the defuse params' defaults via ``set_default_params``."""
        params = self._defuse_params
        params.set_default_params()

    def update_defuse_params(self, **kws):
        """Merge ``kws`` into the current defuse params and store the merged result."""
        params = self._defuse_params
        merged = params.get()
        merged.update(kws)
        params.set(merged)

    def sort_inputs(self):
        """Return a cmp-style comparator for ordering this layer's input layers.

        When the layer has an index and known input indices, inputs are ordered by
        their position in ``input_indices``. Otherwise the deterministic default
        ordering is used.
        """
        if not self.input_indices or self.index == INDEX_NOT_SET:
            return self._default_sort()

        def _by_input_position(layer1, layer2):
            # Read input_indices at call time so later updates are respected.
            positions = self.input_indices
            return 1 if positions.index(layer1.index) > positions.index(layer2.index) else -1

        return _by_input_position

    def sort_outputs(self):
        # if the outputs order doesn't really matter, it should still be deterministic
        return self._default_sort()

    def add_input_by_vertex(self, input_layer, input_name=None, input_vertex=None):
        sorted_index = self._find_input_vertex_index(input_name=input_name, input_vertex=input_vertex)
        self._inputs_by_vertex_order[sorted_index] = input_layer

    def _find_input_vertex_index(self, input_name=None, input_vertex=None):
        """Return the position of an input among ``input_vertex_order``.

        The input is identified by ``input_vertex`` (its ``name`` is used) or, when
        no vertex is given, by ``input_name``; the key is sanitized with
        ``valid_orig_name``. An entry matches when it equals the key, when its part
        before the first ':' equals the key, or when it equals ``"{key}:{position}"``.
        For vertices with an ``output`` attribute, entries ``"{key}:{out}"`` and
        ``"{vertex.name}:{key}:{out}"`` are also accepted (edge cases from the
        keras2onnx producer framework).

        Raises:
            UnsupportedModelError: If no entry matches.
        """
        if self._input_vertex_order:
            # The lookup key does not depend on the loop position; compute it once.
            key = valid_orig_name(input_vertex.name) if input_vertex else valid_orig_name(input_name)
            for index, value in enumerate(self._input_vertex_order):
                if key in (value.split(":")[0], value) or f"{key}:{index}" == value:
                    return index
                # The fallback lookups scan the whole list independently of ``index``,
                # so they only need to run once. Running them right after the first
                # direct check keeps the established precedence.
                # NOTE(review): a fallback match therefore wins over a direct match at a
                # later position - confirm that precedence is intended.
                if index == 0:
                    fallback_index = self._find_fallback_input_vertex_index(key, input_vertex)
                    if fallback_index is not None:
                        return fallback_index

        # Report whichever identifier the caller used (input_vertex is None for name lookups).
        missing = input_vertex if input_vertex else input_name
        raise UnsupportedModelError(f"Cannot find input vertex {missing} at {self.full_name_msg}")

    def _find_fallback_input_vertex_index(self, key, input_vertex):
        """Match keras2onnx-style "{key}:{out}" / "{vertex.name}:{key}:{out}" entries; None if absent."""
        if not (input_vertex and hasattr(input_vertex, "output")):
            return None
        for out in input_vertex.output:
            fallback_key = f"{key}:{out}"
            if fallback_key in self._input_vertex_order:
                return self._input_vertex_order.index(fallback_key)

            fallback_key2 = f"{input_vertex.name}:{key}:{out}"
            if fallback_key2 in self._input_vertex_order:
                return self._input_vertex_order.index(fallback_key2)
        return None

    @property
    def inputs_by_vertex_order(self):
        """
        Returns: The inputs of the layer sorted by the order of inputs of the layer's original vertex.
        Some of the vertex's inputs in the vertex graph can be variable nodes and not actual layers, so we can't blindly
        use the vertex inputs. Only actual layers were added to the dict, and by saving them by index we preserve their
        order.
        """
        items_sorted_by_index = sorted(self._inputs_by_vertex_order.items(), key=lambda item: item[0])
        return [inp for key, inp in items_sorted_by_index]

    @property
    def input_vertex_order(self):
        return self._input_vertex_order

    @input_vertex_order.setter
    def input_vertex_order(self, input_order):
        # orig name might contain unsupported chars, like ';', replace with default
        self._input_vertex_order = [valid_orig_name(inp) for inp in input_order]

    def add_original_name(self, name, reverse_insertion=False):
        if name:
            # orig name might contain unsupported chars, like ';', replaced with default
            valid_name = valid_orig_name(name)
            if valid_name not in self._original_names:
                if reverse_insertion:
                    self._original_names.insert(0, valid_name)
                else:
                    self._original_names.append(valid_name)

    def _update_original_names(self, original_names, append=True):
        # orig name might contain unsupported chars, like ';', replace with default
        names_to_add = [valid_orig_name(name) for name in original_names]
        new_original_names = self._original_names + names_to_add if append else names_to_add + self._original_names
        # Add new original without duplicates
        new_original_names_seen = set()
        self._original_names = [
            x for x in new_original_names if not (x in new_original_names_seen or new_original_names_seen.add(x))
        ]

    def _parse_shapes_from_pb(self, shapes_pb):
        """Convert protobuf shape entries back into shape lists with a -1 leading dimension.

        An entry with both ``height`` and ``features`` becomes a rank-4 shape. For
        external input/output layers, an entry with ``features`` becomes
        ``[-1, features]``. Anything else becomes ``[-1, width]``.
        """

        def _to_shape(shape_pb):
            if shape_pb.HasField("height") and shape_pb.HasField("features"):
                return [-1, shape_pb.height, shape_pb.width, shape_pb.features]
            is_external_io = self.op in [LayerType.external_input_layer, LayerType.external_output_layer]
            if is_external_io and shape_pb.HasField("features"):
                return [-1, shape_pb.features]
            return [-1, shape_pb.width]

        return [_to_shape(shape_pb) for shape_pb in shapes_pb]

    @staticmethod
    def _default_sort():
        return lambda layer1, layer2: 1 if layer1 > layer2 else -1

    def _verify_exportable(self):
        """Raise ProtobufExportError unless the index, name and both shape lists are set."""
        exportable = (
            self.index != INDEX_NOT_SET
            and self.name is not None
            and self.input_shapes is not None
            and self.output_shapes is not None
        )
        if not exportable:
            raise ProtobufExportError(f"{self} Missing layer index, name, input_shape or output_shapes")

    def move_params(self, layer):
        self.block_info = layer.block_info
        if layer.original_names:
            self._update_original_names(layer.original_names, append=layer.op not in [LayerType.external_pad])

    @property
    def output_copies(self):
        return self._output_copies

    @output_copies.setter
    def output_copies(self, copies):
        self._output_copies = copies if copies != 0 else 1

    def _calc_output_shape(self):
        return self.input_shape

    def _check_valid_shape(self, shape):
        if any(not isinstance(dim, integer_types) for dim in shape):
            raise UnsupportedModelError(
                f"Unexpected dimension in shape {shape} at {self.full_name_msg}. "
                f"Dimension must be of type 'int' or 'long'",
            )
        if any(dim == 0 for dim in shape):
            raise UnsupportedModelError(f"Unexpected zero dimension in shape {shape} at {self.full_name_msg}")

    def update_output_shapes(self, **kwargs):
        output_shape = self._calc_output_shape()
        self.output_shapes = [output_shape[:] for _ in range(self.output_copies)]

    def _append_output_shape(self, output_shape):
        self._check_valid_shape(output_shape)
        self.append_output_shape(output_shape)

    def pred_layer_output_shape(self, input_layer, validate=True):
        return input_layer._get_output_shape(layer_name=self.name, layer_index=self.index, validate=validate)

    @property
    def output_shapes(self):
        return self._output_shapes

    @output_shapes.setter
    def output_shapes(self, output_shapes):
        if not isinstance(output_shapes, list):
            raise UnsupportedModelError(
                f"Unexpected output_shapes at {self.full_name_msg}, "
                f"output_shapes={output_shapes}, type={type(output_shapes)}",
            )
        if output_shapes and all(isinstance(dim, integer_types) for dim in output_shapes):
            output_shapes = [output_shapes]
        elif any(not isinstance(shape, list) for shape in output_shapes):
            raise UnsupportedModelError(
                f"Unexpected output_shapes at {self.full_name_msg}, "
                f"output_shapes={output_shapes} type={type(output_shapes)}",
            )
        self._output_shapes = []
        for shape in output_shapes:
            self._append_output_shape(shape)

    @property
    def output_shape(self):
        return self._get_output_shape(validate=True)

    def _get_output_shape(self, validate=True, layer_name=None, layer_index=None):
        if len(self._output_shapes) == 0:
            return None

        if any(self._output_shapes[0] != shape for shape in self._output_shapes):
            raise UnsupportedModelError(
                f"Trying to access the output_shape property of a layer with distinct "
                f"output_shapes isn't supported. Use the output_shapes property at "
                f"{self.full_name_msg}, output_shapes={self._output_shapes}",
            )

        return self._output_shapes[0]

    def reshape_input(self, input_shape):
        # The default is no reshapes
        return input_shape

    def _append_input_shape(self, input_shape):
        if len(input_shape) == 2:
            input_shape = [input_shape[0], 1, 1, input_shape[1]]

        self._check_valid_shape(input_shape)
        input_shape = self.reshape_input(input_shape)
        self._input_shapes.append(input_shape)

    @property
    def input_shapes(self):
        return self._input_shapes

    @input_shapes.setter
    def input_shapes(self, input_shapes):
        if not isinstance(input_shapes, list):
            raise UnsupportedModelError(
                f"Unexpected input_shapes at {self.full_name_msg}, input_shapes={input_shapes} "
                f"(type={type(input_shapes)})",
            )
        if input_shapes and all(isinstance(dim, integer_types) for dim in input_shapes):
            input_shapes = [input_shapes]
        elif any(not isinstance(shape, list) for shape in input_shapes):
            raise UnsupportedModelError(
                f"Unexpected input_shapes at {self.full_name_msg}, input_shapes={input_shapes} "
                f"(type={type(input_shapes)})",
            )
        self.set_input_shapes(input_shapes)

    def set_input_shapes(self, input_shapes, validate=True):
        """Replace the layer's input shapes.

        Each shape's leading dimension is overwritten with -1 in place. The shape is
        then normalized, validated and appended.

        Args:
            input_shapes: List of shape lists.
            validate: When True, check the container type and the number of inputs first.

        Raises:
            UnsupportedModelError: If validation is enabled and ``input_shapes`` is not a
                list of lists, or holds more shapes than the layer supports.
        """
        if validate:
            # Reject when the container is not a list OR any element is not a list. The
            # previous ``and`` meant a list argument could never be rejected here.
            if not isinstance(input_shapes, list) or any(not isinstance(shape, list) for shape in input_shapes):
                raise UnsupportedModelError(
                    f"Unexpected input_shapes at {self.full_name_msg}, "
                    f"input_shapes={input_shapes} (type={type(input_shapes)})",
                )
            if self.number_of_inputs_supported and len(input_shapes) > self.number_of_inputs_supported:
                raise UnsupportedModelError(
                    f"Unexpected input_shapes at {self.full_name_msg}, "
                    f"input_shapes={input_shapes} (type={type(input_shapes)})",
                )

        self._input_shapes = []
        for shape in input_shapes:
            # Normalize the leading dimension to -1 (mutates the caller's shape list).
            shape[0] = -1
            self._append_input_shape(shape)

    @property
    def input_shape(self):
        if len(self._input_shapes) == 0:
            return None

        if any(self._input_shapes[0] != shape for shape in self._input_shapes):
            raise UnsupportedModelError(
                f"Trying to access the input_shape property of a layer with distinct "
                f"input_shapes isn't supported. Use the input_shapes property at "
                f"{self.full_name_msg}, input_shapes={self._input_shapes}",
            )

        return self._input_shapes[0]

    def _validate_all_dims_are_equal(self, shapes, dim):
        if any(shapes[0][dim] != shape[dim] for shape in shapes):
            raise UnsupportedModelError(f"Shapes at dim #{dim} aren't equal at {self.full_name_msg}, shapes={shapes}")

    def _get_shape_single_dim(self, shapes, dim, validate=True):
        if validate:
            self._validate_all_dims_are_equal(shapes, dim)
        return shapes[0][dim]

    @property
    def name(self):
        return self._name

    @property
    def full_name(self):
        return self._name

    @name.setter
    def name(self, name):
        self._name = name
        self._compilation_params.set_layer_name(name)
        self._defuse_params.set_layer_name(name)
        name_parts = name.split("/", 1)
        if len(name_parts) == 2:
            self._scope, self._name_without_scope = name_parts
        else:
            self._name_without_scope = name_parts[-1]

    @input_shape.setter
    def input_shape(self, input_shape):
        # Replace all inputs with the single given shape (normalized and validated on append).
        self._input_shapes = []
        self._append_input_shape(input_shape)

    @property
    def input_height(self):
        if len(self._input_shapes[0]) == 2:
            return 1
        return self._get_shape_single_dim(self._input_shapes, 1)

    @property
    def input_width(self):
        if len(self._input_shapes[0]) == 2:
            return 1
        return self._get_shape_single_dim(self._input_shapes, 2)

    @property
    def input_features(self):
        return self._get_shape_single_dim(self._input_shapes, -1)

    @property
    def output_height(self):
        if len(self._output_shapes[0]) == 2:
            return 1
        return self._get_shape_single_dim(self._output_shapes, 1)

    @property
    def output_width(self):
        if len(self._output_shapes[0]) == 2:
            return 1
        return self._get_shape_single_dim(self._output_shapes, 2)

    @property
    def output_features(self):
        return self._get_shape_single_dim(self._output_shapes, -1)

    @property
    def macs(self):
        return 0

    @property
    def ops(self):
        return 0

    @property
    def weights(self):
        return 0

    @property
    def inputs(self):
        return self._input_layers

    @inputs.setter
    def inputs(self, inputs):
        self._input_layers = inputs

    @property
    def outputs(self):
        return self._output_layers

    @outputs.setter
    def outputs(self, outputs):
        self._output_layers = outputs

    @property
    def input_indices(self):
        return self._input_indices

    @input_indices.setter
    def input_indices(self, indices):
        self._input_indices = indices

    @property
    def output_indices(self):
        return self._output_indices

    @output_indices.setter
    def output_indices(self, indices):
        self._output_indices = indices

    @property
    def op(self):
        return self._op

    @op.setter
    def op(self, op):
        self._op = op

    @property
    def fused_indices(self):
        return self._fused_indices

    @fused_indices.setter
    def fused_indices(self, fused_indices):
        self._fused_indices = fused_indices

    @property
    def engine(self):
        return self._engine

    @engine.setter
    def engine(self, engine):
        self._engine = engine

    @property
    def in_emulation_graph(self):
        return self._in_emulation_graph

    @in_emulation_graph.setter
    def in_emulation_graph(self, in_emulation_graph):
        self._in_emulation_graph = in_emulation_graph

    def add_fused_index(self, index):
        self._fused_indices.append(index)

    @property
    def ew_add_enabled(self):
        return False

    @property
    def bn_enabled(self):
        return False

    @property
    def original_names(self):
        return self._original_names

    @original_names.setter
    def original_names(self, original_names):
        self._original_names = original_names

    @property
    def is_artificial(self):
        return self._is_artificial

    @property
    def compilation_params(self):
        return self._compilation_params.get()

    @property
    def defuse_params(self):
        return self._defuse_params.get()

    @property
    def defuse_type(self):
        return self._defuse_params.get("defuse_type")

    @property
    def defuse_types(self):
        return self._defuse_params.get("defuse_types")

    @property
    def defuse_features(self):
        return self._defuse_params.get("defuse_features")

    @property
    def defuse_input_width(self):
        return self._defuse_params.get("defuse_input_width")

    @property
    def defuse_features_offset(self):
        return self._defuse_params.get("defuse_features_offset")

    @property
    def defuse_width_offset(self):
        return self._defuse_params.get("defuse_width_offset")

    @property
    def defuse_ew_add_input_width(self):
        return self._defuse_params.get("defuse_ew_add_input_width")

    @property
    def defuse_ew_add_width_offset(self):
        return self._defuse_params.get("defuse_ew_add_width_offset")

    @property
    def defuse_input_shapes(self):
        return self._defuse_params.get("defuse_input_shapes")

    @property
    def defuse_output_shapes(self):
        return self._defuse_params.get("defuse_output_shapes")

    @property
    def feature_split(self):
        return self._defuse_params.get("feature_split")

    @property
    def number_of_inputs_supported(self):
        return self._number_of_inputs_supported

    def append_input_index(self, index):
        self._input_indices.append(index)

    def append_input_layer(self, input_layer_name):
        self._input_layers.append(input_layer_name)

    def append_input_shape(self, input_shape):
        self._input_shapes.append(input_shape)

    def append_input_shapes(self, input_shapes):
        for shape in input_shapes:
            self._input_shapes.append(shape)

    def append_output_layer(self, output_layer_name):
        self._output_layers.append(output_layer_name)

    def append_output_index(self, output_index):
        self._output_indices.append(output_index)

    def append_output_shape(self, output_shape):
        if len(output_shape) == 2:
            output_shape = [output_shape[0], 1, 1, output_shape[1]]
        self._output_shapes.append(output_shape)

    def replace_input_layer(self, old_name, new_name):
        if old_name in self._input_layers:
            idx = self._input_layers.index(old_name)
            self._input_layers[idx] = new_name

    def replace_input_index(self, old_index, new_index):
        if old_index in self._input_indices:
            idx = self._input_indices.index(old_index)
            self._input_indices[idx] = new_index

    def replace_input_shape(self, old_name, new_input_shape):
        if old_name in self._input_layers:
            idx = self._input_layers.index(old_name)
            self._input_shapes[idx] = new_input_shape

    def replace_output_layer(self, old_name, new_name):
        if old_name in self._output_layers:
            idx = self._output_layers.index(old_name)
            self._output_layers[idx] = new_name

    def replace_output_index(self, old_index, new_index):
        if old_index in self._output_indices:
            idx = self._output_indices.index(old_index)
            self._output_indices[idx] = new_index

    def replace_output_shape(self, old_name, new_output_shape):
        if old_name in self._output_layers:
            idx = self._output_layers.index(old_name)
            self._output_shapes[idx] = new_output_shape

    @classmethod
    def from_hn(cls, hn):
        """Build a layer of type ``cls`` from its HN description.

        Args:
            hn: Mapping describing the layer.
                - Required: ``name``, plus one of ``input_shapes`` / ``input_shape``.
                - Optional: ``output_shapes``, ``original_names``, ``is_artificial``, ``output``,
                  ``input``, ``engine``, ``in_emulation_graph`` and ``base_layer``.

        Returns:
            A new ``cls`` instance populated from ``hn``.

        Raises:
            UnsupportedModelError: If ``hn`` has neither ``input_shapes`` nor ``input_shape``.
        """
        layer = cls()
        layer.name = hn["name"]
        if "input_shapes" in hn:
            layer.input_shapes = hn["input_shapes"]
        elif "input_shape" in hn:
            layer.input_shape = hn["input_shape"]
        else:
            raise UnsupportedModelError(f"Layer '{hn['name']}' is missing 'input_shape'/'input_shapes' field")
        # The list-valued fields below are copied. append_*/replace_* mutate these lists in
        # place, and sharing them with ``hn`` would leak those edits into the caller's description.
        if "output_shapes" in hn:
            layer._output_shapes = list(hn["output_shapes"])
        if "original_names" in hn:
            layer._update_original_names(hn["original_names"])
        if "is_artificial" in hn:
            layer._is_artificial = hn["is_artificial"]
        if "output" in hn:
            layer._output_layers = list(hn["output"])
        if "input" in hn:
            layer._input_layers = list(hn["input"])
        if "engine" in hn:
            layer._engine = PostprocessTarget(hn["engine"])
        if "in_emulation_graph" in hn:
            layer.in_emulation_graph = hn["in_emulation_graph"]
        if "base_layer" in hn:
            layer._base_layer = hn["base_layer"]

        return layer

    def set_legacy_precision_config(self, legacy_config):
        """
        The purpose of this logic is to load the deprecated quantization params from the hn.
        A user should never reach this code, but some of our test set the bias mode / precision mode
        using the old keys. Those test should be fixed, and for the meanwhile I raise deprecation warning
        """
        if not legacy_config:
            return
        if QuantizationDeprecatedParam.use_16bit_bias in legacy_config:
            self._logger.deprecation_warning(
                f"quantization_params key {QuantizationDeprecatedParam.use_16bit_bias} "
                f"will be deprecated in the future. "
                f"Please use bias_mode instead",
            )
            use_16bit_bias = legacy_config[QuantizationDeprecatedParam.use_16bit_bias]
            if use_16bit_bias is True or use_16bit_bias == "enabled":
                bias_mode_val = BiasMode.double_scale_initialization
            else:
                bias_mode_val = BiasMode.single_scale_decomposition
            if self.precision_config.bias_mode is None:
                self.precision_config.bias_mode = bias_mode_val
            else:
                raise HailoNNLayerParamsException(
                    f"Layer {self.name} has both bias_mode and use_16bit_bias quantization_params. "
                    f"The fields are mutually exclusive, please use only bias_mode.",
                )

        if QuantizationDeprecatedParam.use_4bit_weights in legacy_config:
            self._logger.deprecation_warning(
                f"quantization_params key {QuantizationDeprecatedParam.use_4bit_weights} "
                f"will be deprecated in the future. "
                f"Please use precision_mode instead",
            )
            use_4bit_weights = legacy_config[QuantizationDeprecatedParam.use_4bit_weights]
            if use_4bit_weights is True or use_4bit_weights == "enabled":
                precision_mode_val = PrecisionMode.a8_w4
            else:
                precision_mode_val = PrecisionMode.a8_w8
            if self.precision_config.precision_mode is None:
                self.precision_config.precision_mode = precision_mode_val
            else:
                raise HailoNNLayerParamsException(
                    f"Layer {self.name} has both precision_mode and use_4bit_weights quantization_params. "
                    f"The fields are mutually exclusive, please use only precision_mode.",
                )

        if QuantizationDeprecatedParam.exponential_mode_4bit_weights in legacy_config:
            self._logger.deprecation_warning(
                f"quantization_params key {QuantizationDeprecatedParam.exponential_mode_4bit_weights} "
                f"will be deprecated in the future. "
                f"Please use precision_mode instead",
            )
            exp_weights = legacy_config[QuantizationDeprecatedParam.exponential_mode_4bit_weights]
            if exp_weights is True or exp_weights == "enabled":
                precision_mode_val = PrecisionMode.a8_w4_exp
            else:
                precision_mode_val = PrecisionMode.a8_w8
            if self.precision_config.precision_mode is None:
                self.precision_config.precision_mode = precision_mode_val
            else:
                raise HailoNNLayerParamsException(
                    f"Layer {self.name} has both precision_mode and use_4bit_weights quantization_params. "
                    f"The fields are mutually exclusive, please use only precision_mode.",
                )

    @classmethod
    def from_pb(cls, pb, pb_wrapper):
        """Build a layer of type ``cls`` from its protobuf message.

        Args:
            pb: The layer's protobuf message. Its ``name``, ``index``, ``is_artificial``,
                ``original_names``, ``input_shapes``, ``output_shapes``, ``input_indices``,
                ``successors_indices`` and ``engine`` fields are read here.
            pb_wrapper: Passed through to the compilation / quantization / defuse param parsers.
                Its ``ENGINE_PB_TO_TYPE`` mapping translates ``pb.engine``.

        Returns:
            A new ``cls`` instance populated from ``pb``.
        """
        layer = cls()
        layer.name = pb.name
        layer.index = pb.index
        layer._is_artificial = pb.is_artificial
        layer.original_names.extend(pb.original_names)

        # When the message carries no shapes, the layer keeps whatever its constructor set.
        if pb.input_shapes:
            layer.input_shapes = layer._parse_shapes_from_pb(pb.input_shapes)
        if pb.output_shapes:
            layer.output_shapes = layer._parse_shapes_from_pb(pb.output_shapes)

        layer._compilation_params.from_pb(pb, pb_wrapper)
        # Quantization params are parsed into a temporary QuantizationParams and stored on the
        # layer only in their precision-config form.
        quant_params = QuantizationParams()
        quant_params.from_pb(pb, pb_wrapper)
        layer._precision_config = quant_params.to_precision_config()
        layer._defuse_params.from_pb(pb, pb_wrapper)

        # Predecessor indices come from ``input_indices``; successor indices from ``successors_indices``.
        layer.input_indices.extend(pb.input_indices)
        layer.output_indices.extend(pb.successors_indices)
        # Engine values missing from ENGINE_PB_TO_TYPE fall back to the NN core.
        layer.engine = pb_wrapper.ENGINE_PB_TO_TYPE.get(pb.engine, PostprocessTarget.NN_CORE)
        return layer

    @classmethod
    def from_layer(cls, old_layer):
        """Create a new ``cls`` instance carrying over the common state of ``old_layer``.

        Copied state:
            - name, index and artificial flag
            - original names and shapes
            - graph connectivity (indices and layer names)
            - compilation / precision / translation / defuse configuration
            - dynamic-weights, output-copies, engine and block-info metadata

        Args:
            old_layer: The layer to copy from. It is not modified.

        Returns:
            The new layer.
        """
        layer = cls()
        layer.name = old_layer.name
        layer.index = old_layer.index
        layer._is_artificial = old_layer.is_artificial
        # List-valued state is copied element-wise into the new layer's own lists (a shallow copy;
        # assumes these properties return the underlying lists — TODO confirm).
        layer.original_names.extend(old_layer.original_names)
        layer.input_shapes.extend(old_layer.input_shapes)
        layer.output_shapes.extend(old_layer.output_shapes)
        # NOTE(review): assigned by reference, not copied. If mutable, both layers share it.
        layer.input_vertex_order = old_layer.input_vertex_order
        layer.dynamic_weights = old_layer.dynamic_weights
        # Configuration objects are rebuilt from the old layer's values rather than shared.
        layer._compilation_params.override_params_from_kwargs(**old_layer.compilation_params)
        layer._precision_config = LayerPrecisionConfig(**old_layer.precision_config.raw_dict())
        layer._translation_config = LayerTranslationConfig(**old_layer.translation_config.raw_dict())
        layer._defuse_params.override_params_from_kwargs(**old_layer.defuse_params)
        layer.input_indices.extend(old_layer.input_indices)
        layer.inputs.extend(old_layer.inputs)
        layer.output_indices.extend(old_layer.output_indices)
        layer.outputs.extend(old_layer.outputs)
        layer.fused_indices.extend(old_layer.fused_indices)
        # NOTE(review): output_copies and block_info are shared by reference as well.
        layer.output_copies = old_layer.output_copies
        layer.engine = old_layer.engine
        layer.block_info = old_layer.block_info
        return layer

    @property
    def defuse_name(self):
        return self._defuse_params.get("defuse_name")

    @defuse_name.setter
    def defuse_name(self, name):
        self._defuse_params.set_defuse_name(name)

    @property
    def defuse_name_without_scope(self):
        defuse_name = self.defuse_name
        return defuse_name.split("/")[-1] if defuse_name else defuse_name

    def is_defused(self):
        return self.defuse_name != "" and self.defuse_type is not DefuseType.none

    def _validate_no_zeros_in_shape(self, shape):
        if any(dim == 0 for dim in shape):
            raise UnsupportedModelError(f"Unexpected zero dimension in shape {shape} in layer {self.full_name_msg}")

    @property
    def name_without_scope(self):
        return self._name_without_scope

    @abstractmethod
    def get_equalization_handler_type(self, predecessor=None):
        pass

    @abstractmethod
    def get_params_sorter_handler_type(self, predecessor=None):
        pass

    @abstractmethod
    def get_dead_channels_removal_handler_type(self, predecessor=None):
        pass

    @abstractmethod
    def ibc_supported(self):
        pass

    def get_quarot_handler_type(self, predecessor=None):
        """Default QuaRot classification: handler type ``unsupported``, not a source.

        ``predecessor`` is ignored here. Layers that support QuaRot presumably override this.
        """
        handler = LayerHandlerType.unsupported
        return EquivClassification(handler, is_source=False)

    @property
    def finetune_supported(self):
        return True

    def get_algo_callback(self, algo):
        """Return the bound classifier that answers quantization algorithm ``algo`` for this layer.

        Raises:
            KeyError: If ``algo`` has no registered classifier.
        """
        callbacks = {
            QuantizationAlgorithms.equalization: self.get_equalization_handler_type,
            QuantizationAlgorithms.params_sorter: self.get_params_sorter_handler_type,
            QuantizationAlgorithms.dead_channels_removal: self.get_dead_channels_removal_handler_type,
            QuantizationAlgorithms.ibc: self.ibc_supported,
            QuantizationAlgorithms.quarot: self.get_quarot_handler_type,
        }
        callback = callbacks[algo]
        return callback

    @property
    def scope(self):
        return self._scope

    @property
    def dynamic_weights(self):
        return self._dynamic_weights

    @dynamic_weights.setter
    def dynamic_weights(self, dynamic_weights):
        self._dynamic_weights = dynamic_weights

    @property
    def hn_name(self):
        """Name of this layer's op as written in HN.

        ``LayerType.dense`` maps to ``"fc"`` and ``LayerType.bias_add`` to ``"bias_add"``.
        Every other op uses its ``.value``.
        """
        op = self.op
        # Evaluated unconditionally, as dict.get's default argument always is.
        default_name = op.value
        hn_name_overrides = {
            LayerType.dense: "fc",
            LayerType.bias_add: "bias_add",
        }
        return hn_name_overrides.get(op, default_name)

    @property
    def transposed(self):
        return self._transposed

    @transposed.setter
    def transposed(self, transposed):
        self._transposed = transposed

    @property
    def full_name_msg(self):
        msg = f"{self.op.value.replace('_', ' ')} layer"
        if self.name:
            msg += f" {self.name}"
        if self.original_names:
            msg += f" (translated from {self.original_names[-1]})"

        return msg

    @property
    def group_sizes(self):
        if self._group_sizes is None and hasattr(self, "groups"):
            return [1] * self.groups
        return self._group_sizes

    @group_sizes.setter
    def group_sizes(self, group_sizes):
        self._group_sizes = group_sizes

    def is_zippable(self, other):
        """
        Allow zipping two layers as long as they share the parameters except output features
        Ratio between output features should be consistent along the zipped network - this is checked in the zipper
        """
        if self.op != other.op:
            return False
        if len(self.input_shapes) != len(other.input_shapes):
            return False
        if any(shape1[:-1] != shape2[:-1] for shape1, shape2 in zip(self.input_shapes, other.input_shapes)):
            return False
        if any(self.output_shapes[0] != shape for shape in self.output_shapes):
            return False
        if any(other.output_shapes[0] != shape for shape in other.output_shapes):
            return False
        if self.output_shape[:-1] != other.output_shape[:-1]:
            return False
        if self.dynamic_weights != other.dynamic_weights:
            return False
        if self.engine != other.engine:
            return False
        return True

    def _is_allowed_from_dense_input(self, input_shape):
        # In the case of concat from dense, it's predecessor shape must only have features (no width nor height)
        return len(input_shape) == 2 or input_shape[1] == input_shape[2] == 1

    def is_from_dense(self, validate=True):
        # SDK-9087- remove after in_format_type and out_format_type implemented
        is_any_flat_predecessor = any(
            len(self.pred_layer_output_shape(pred, validate)) == 2 for pred in self._input_list
        )
        if is_any_flat_predecessor:
            for layer in self._input_list:
                input_shape = self.pred_layer_output_shape(layer, validate)
                if not self._is_allowed_from_dense_input(input_shape):
                    raise UnsupportedModelError(
                        f"Unsupported mix of input layers with different output format shape at {self.full_name_msg}",
                    )
        return is_any_flat_predecessor

    @property
    def block_info(self):
        return self._block_info

    @block_info.setter
    def block_info(self, details):
        self._block_info = details

    @property
    def base_layer(self):
        return self._base_layer

    @base_layer.setter
    def base_layer(self, base_layer):
        self._base_layer = base_layer
