from typing import Dict

import networkx as nx
import numpy as np

from hailo_model_optimization.acceleras.atomic_ops.nms_op import NMSOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_none_nn_core_layer import BaseHailoNonNNCoreLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_concat import HailoConcat
from hailo_model_optimization.acceleras.hailo_layers.hailo_fused_bbox_decoder import HailoFusedBboxDecoder
from hailo_model_optimization.acceleras.hailo_layers.hailo_nms import HailoNMS
from hailo_model_optimization.acceleras.hailo_layers.hailo_proposal_generator import HailoProposalGenerator
from hailo_model_optimization.acceleras.model.hailo_model.equiv_flow import EquivFlow
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    EncodingMatchType,
    FeaturePolicy,
    IOVectorPolicy,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import (
    AccelerasImplementationError,
    AlgoErrorHint,
    MatchingAlgoError,
)
from hailo_model_optimization.acceleras.utils.opt_utils import limvals_to_zp_scale
from hailo_model_optimization.acceleras.utils.to_qnpz_utils import qp_to_limvals
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm


class MatchingAlgo(OptimizationAlgorithm):
    """
    Matching Algo is an algorithm that, given a HailoModel, constructs an EquivFlow from the flow graph
    of the model. It iterates over the groups of matching components of the EquivFlow (each matching
    component is a set of edges of the original graph), decides what type of matching is required, and
    propagates the scales/zp to the rest of the component.

    The groups are obtained from ``EquivFlow.get_groups_components`` (see the EquivFlow documentation).
    The important concepts are:
        1. matching_components - a subgraph of the edges graph, induced on the edges of the flow_graph
           that must share the same scales/zp.
        2. source_layers: ``self.equiv_match.source_layers(matching_component)`` - the layers whose output
           edges belong to the component; their output scales/zp are the ones that get matched.
        3. consumer_layers: ``self.equiv_match.consumer_layers(matching_component)`` - the layers that
           consume the component's edges; their input requirements (``scalar_scale_map`` /
           ``scalar_zp_map``) decide whether matching is needed.

    Args:
        model : HailoModel
        model_config : ModelConfig Object with configuration.
        logger_level : logging level forwarded to the base OptimizationAlgorithm.
        logger : logger

    Attributes:
        equiv_match : EquivFlow
        scalar_scale_map : per layer, whether its input scale must be scalar (the same value for all channels)
        scalar_zp_map : per layer, whether its input zp must be scalar (the same value for all channels)
        input_encoding_vector : whether vector (per-channel) encoding is allowed on input layers
        output_encoding_vector : whether vector (per-channel) encoding is allowed on output layers

    """

    equiv_match: EquivFlow
    scalar_scale_map: Dict[str, bool]
    scalar_zp_map: Dict[str, bool]
    input_encoding_vector: bool
    output_encoding_vector: bool

    def __init__(self, model, model_config, logger_level, logger=None):
        super().__init__(model, model_config, name="Scale Matching", logger_level=logger_level, logger=logger)

    def must_detection_head_matching_group(self, matching_component_group):
        """
        Return True if any component in the group is consumed by a HailoNMS layer.

        Components that feed NMS need the special detection-head scale matching.
        """
        # Membership test only, so the (source, consumer) edge nodes can be visited in any order.
        return any(
            isinstance(self._model.layers[cons], HailoNMS)
            for matching_component in matching_component_group
            for _, cons in matching_component.nodes
        )

    def must_scale_matching_group(self, matching_component_group):
        """Return True if at least one component of the group requires scale matching."""
        return any(self._must_scale_matching(matching_component) for matching_component in matching_component_group)

    def _has_concat_norm(self, matching_component_group):
        """
        A hacky check that detects a concat belonging to a layer norm.

        Returns True only for a single-component group in which some consumer is a HailoConcat whose
        full_name contains "norm".
        """
        if len(matching_component_group) != 1:
            return False
        matching_component = matching_component_group[0]
        for _, cons in matching_component.nodes:
            layer = self._model.layers[cons]
            if isinstance(layer, HailoConcat) and "norm" in layer.full_name:
                return True
        return False

    def must_zp_matching_group(self, matching_component_group):
        """Return True if the group requires zero-point matching (never for a layer-norm concat)."""
        if self._has_concat_norm(matching_component_group):
            return False
        return any(self._must_zp_matching(matching_component) for matching_component in matching_component_group)

    def _must_scale_matching(self, matching_component):
        """
        Check whether the source layers of a component need proper scale matching.

        A component containing a spatial concat always needs it. Otherwise, a component with a single
        source layer never needs it. In all other cases, matching is needed if at least one consumer layer
        requires a scalar input scale.

        Example:
            (conv1, concat_1), (conv2, concat_1)
            (concat_1, softmax)

        Args:
            matching_component: the matching component

        Returns: True if scale matching is needed, False otherwise.

        """
        if self._has_spatial_concat(matching_component):
            return True
        if len(self.equiv_match.source_layers(matching_component)) == 1:
            return False
        consumer_layers = self.equiv_match.consumer_layers(matching_component)
        return any(self.scalar_scale_map[out_layer] for out_layer in consumer_layers)

    def _has_spatial_concat(self, matching_component):
        """
        Check whether any source of an edge in the component is a spatial HailoConcat.
        """
        return any(
            isinstance(self._model.layers[u], HailoConcat) and self._model.layers[u].spatial_concat
            for (u, _) in matching_component.nodes
        )

    def _must_zp_matching(self, matching_component):
        """
        Check whether the component needs zero-point matching.

        Args:
            matching_component: equiv component

        Returns: True if the component has more than one source layer and at least one consumer layer
            requires a scalar input zp.

        """
        if len(self.equiv_match.source_layers(matching_component)) == 1:
            return False
        consumer_layers = self.equiv_match.consumer_layers(matching_component)
        return any(self.scalar_zp_map[out_layer] for out_layer in consumer_layers)

    def zp_matching_group(self, matching_component_group):
        """
        Make all source layers of the group (a list of matching components) share the same zero point.

        Args:
            matching_component_group: list[matching_component]
        """
        self._info_matching_component_group(matching_component_group, "zp_matching")
        layers = self.equiv_match.source_layers_group(matching_component_group)
        layers_with_out_ind = self.get_layers_output_index(layers, matching_component_group)
        self._sources_zp_matching_group(layers_with_out_ind)

    def _sources_zp_matching_group(self, layers_outputs):
        """
        Set a common zero point on the given (layer_name, output_index) pairs.

        Each output scale is rescaled by ``get_new_scale_vector`` to account for the zp change.
        """
        zp = self._get_zero_point(layers_outputs)
        for layer_name, out_ind in layers_outputs:
            source_layer = self._model.layers[layer_name]
            zp_old = source_layer.output_zero_points[out_ind]
            scale_vector = source_layer.output_scales[out_ind]

            bits = source_layer.get_output_lossy_elements()[out_ind].bits
            bins = 2**bits - 1
            new_scale = self.get_new_scale_vector(zp_old, zp, scale_vector, bins)

            source_layer.set_output_scale(new_scale, out_ind)
            source_layer.set_output_zero_point(zp, out_ind)

    def detection_head_matching_group(self, matching_component_group):
        """
        Perform scale matching for the post-processing (bbox decoder) layers of a detection head.

        Two algorithms are supported. The first is for non-fused post-processing layers, in which a
        proposal layer is followed by a bbox decoder layer (that is, two separate layers). The second is
        for a fused post-processing layer that combines (fuses) the bbox decoder and the proposal layers
        into one layer.

        Note that the current implementation supports either the non-fused or the fused architecture, but
        not both in the same matching_component element.

        Raises:
            AccelerasImplementationError: if a matching component mixes HailoFusedBboxDecoder sources with
                other sources.
        """
        detection_head_layers = []
        for matching_component in matching_component_group:
            sources = self.equiv_match.source_layers(matching_component)
            is_fused = [isinstance(self._model.layers[src], HailoFusedBboxDecoder) for src in sources]
            all_fused_bbox_decoder = all(is_fused)
            # Mixed sources (some fused, some not) are unsupported.
            if any(is_fused) and not all_fused_bbox_decoder:
                raise AccelerasImplementationError(
                    "detection_head_matching_group supports matching_component with bbox decoder and proposal layers, or fused post-processing"
                    " layers, but not both in the same matching_component.",
                )

            matching_component_head_layers = (
                self.detection_head_matching_group_fused(sources)
                if all_fused_bbox_decoder
                else self.detection_head_matching_group_non_fused(sources)
            )

            detection_head_layers.extend(matching_component_head_layers)

        layers_with_out_ind = self.get_layers_output_index(detection_head_layers, matching_component_group)
        self._sources_scale_matching(layers_with_out_ind)

    def detection_head_matching_group_non_fused(self, sources):
        """
        Order the sources of a non-fused detection head as bbox inputs followed by score inputs.

        Each source must have exactly one HailoProposalGenerator successor. A source that is that proposal
        layer's first predecessor is a bbox input; one that is its second predecessor is a score input.

        Raises:
            MatchingAlgoError: if a source does not have exactly one proposal generator successor, or is
                neither the first nor the second input of that proposal generator.
        """
        bbox_layers = []
        score_layers = []

        for lname in sources:
            successor_names = self._model.flow.successors_sorted(lname)
            proposal_layers = [
                self._model.layers[successor_name]
                for successor_name in successor_names
                if isinstance(self._model.layers[successor_name], HailoProposalGenerator)
            ]
            if len(proposal_layers) != 1:
                raise MatchingAlgoError(
                    f"Received a wrong layer in detection head flow- "
                    f"source layer: {lname}, successor layers: {successor_names}, "
                    "expected proposal generator layer",
                )
            proposal_layer = proposal_layers[0]

            proposal_inputs = self._model.flow.predecessors_sorted(proposal_layer.full_name)
            if lname == proposal_inputs[0]:
                bbox_layers.append(lname)
            elif lname == proposal_inputs[1]:
                score_layers.append(lname)
            else:
                raise MatchingAlgoError("Received a wrong layer in detection head sources")

        # Currently we will do scale matching for all NMS inputs together,
        # In the future we want to separate the scales of the bbox and the scores
        return bbox_layers + score_layers

    def detection_head_matching_group_fused(self, sources):
        """
        Perform scale matchings for fused_bbox_decoders.

        Args:
            sources (List[str]): the names of all the layers in one matching_component.

        Returns:
            List[str]: the source names, in order, after validating their type.

        Raises:
            MatchingAlgoError: if any source element is not of type HailoFusedBboxDecoder.

        """
        fused_detection_head_layers = []

        for lname in sources:
            source_layer = self._model.layers[lname]
            if not isinstance(source_layer, HailoFusedBboxDecoder):
                raise MatchingAlgoError(
                    "This function assumes that all source elements are of type"
                    f" HailoFusedBboxDecoder, but got lname: {lname}.",
                )
            fused_detection_head_layers.append(lname)

        return fused_detection_head_layers

    def _get_limvals(self, layers_outputs):
        """
        Get the matching (min, max) limvals for the layers, honoring strong force_range settings.
        """
        layers_force_range_scale_dict = {
            layer_name: list(self._model.layers[layer_name].strong_force_range)
            for layer_name, _ in layers_outputs
            if self._model.layers[layer_name].strong_force_range
        }
        return self._handle_generic_cases(self._calc_limvals, layers_outputs, layers_force_range_scale_dict)

    def _get_zero_point(self, layers_outputs):
        """
        Get the matching zp for the layers, honoring strong force_range settings.
        """
        layers_force_range_zp_dict = {
            layer_name: self._model.layers[layer_name].output_zero_points[out_ind]
            for layer_name, out_ind in layers_outputs
            if self._model.layers[layer_name].strong_force_range
        }
        return self._handle_generic_cases(self._calc_zp, layers_outputs, layers_force_range_zp_dict)

    def _calc_zp(self, layers_outputs):
        """Calculate the matched zp for the source layers as the ceiling of their mean zp.

        Args:
           layers_outputs (list): a list of tuples of layer_names and output_indexes
        Returns:
            the zero point (result of ``np.ceil``, i.e. a float)
        """
        zp_vals = [self._model.layers[layer_name].output_zero_points[out_ind] for layer_name, out_ind in layers_outputs]
        return np.ceil(np.mean(zp_vals))

    def _calc_limvals(self, layers_outputs):
        """Calculate the matched limvals for the source layers by taking their overall min and max.

        Args:
            layers_outputs (list): a list of tuples of layer_names and output_indexes

        Raises:
            MatchingAlgoError: if ``layers_outputs`` is empty (there are no quant elements)

        Returns:
           (list) : [min, max] - the wanted limvals
        """
        limvals = []
        for layer_name, out_ind in layers_outputs:
            source_layer = self._model.layers[layer_name]
            output_scale_scalar = source_layer.get_scalar_vector(source_layer.output_scales[out_ind])
            output_zero_point = source_layer.output_zero_points[out_ind]
            # Use the lossy element of the output that is actually being matched (not always output 0).
            quant_element = source_layer.get_output_lossy_elements()[out_ind]
            limvals.append(
                qp_to_limvals(
                    (output_zero_point, output_scale_scalar),
                    quant_element.bits,
                    is_symertric=quant_element.signed,
                ),
            )

        if not limvals:
            raise MatchingAlgoError("quant element is empty")
        zipped_list = list(zip(*limvals))
        return [np.min(zipped_list[0]), np.max(zipped_list[1])]

    def _handle_generic_cases(self, callback_function, layers_outputs, layers_force_range_dict):
        """Resolve the matched value for the source layers, taking force_range into account.

        1. no force_range - compute the value with ``callback_function``
        2. exception - conflicting force_range values across layers
        3. default behaviour - all forced layers agree; use that forced value

        Args:
            callback_function (callable): computes the wanted value from ``layers_outputs``
            layers_outputs (list): a list of tuples of layer_names and output_indexes
            layers_force_range_dict (dict): layer_name -> forced value, only for layers with force_range

        Raises:
            AlgoErrorHint: if the forced values conflict.

        Returns:
            the resolved value (zp or [min, max] limvals, depending on the callback)
        """
        forced_layer_names = list(layers_force_range_dict.keys())
        forced_values = list(layers_force_range_dict.values())

        # Case 1 - no force range
        if not forced_values:
            return callback_function(layers_outputs)

        # Case 2 - conflicting force_range values
        first_value = forced_values[0]
        if len(forced_layer_names) > 1 and not all(value == first_value for value in forced_values):
            solutions = [
                f"quantization_param([{layer_name}], force_range_out={list(self._model.layers[layer_name].strong_force_range)}, weak_force_range_out=enabled)"
                for layer_name in forced_layer_names
            ]
            explanations = [
                "this option may ignore the force_range_out in the layers due other constraints",
            ] * len(solutions)
            general_info = (
                f"there are conflicting values of  force_range_out in the following layers {forced_layer_names}"
            )
            raise AlgoErrorHint(general_info, solutions, explanations)

        # Case 3 - heuristic - all forced layers agree, use that value for every layer
        self._logger.verbose(
            f"force_range was set for the following layers: {list(set(key for key, _ in layers_outputs))} given by force range of {forced_layer_names}"
        )

        return first_value

    def _sources_scale_matching(self, layers_outputs):
        """
        Give all (layer_name, output_index) pairs a common scalar scale (repeated per channel) and zp.

        The quantization element of the first pair's output determines the zp/scale conversion.
        """
        first_layer_name, first_out_ind = layers_outputs[0]
        quant_element = self._model.layers[first_layer_name].get_output_lossy_elements()[first_out_ind]
        inp_limvals = self._get_limvals(layers_outputs)
        new_zp, new_scale, _ = limvals_to_zp_scale(inp_limvals, quant_element)

        for layer_name, out_ind in layers_outputs:
            source_layer = self._model.layers[layer_name]
            repeated_scale = np.repeat(new_scale, len(source_layer.output_scales[out_ind]))
            source_layer.set_output_scale(repeated_scale, out_ind)
            source_layer.set_output_zero_point(new_zp, out_ind)

    def scale_matching_group(self, matching_component_group):
        """
        Make all source layers of the group (a list of matching components) share the same scale.

        Args:
            matching_component_group: list[matching_component]
        """
        self._info_matching_component_group(matching_component_group)
        layers = self.equiv_match.source_layers_group(matching_component_group)
        layers_with_out_ind = self.get_layers_output_index(layers, matching_component_group)
        self._sources_scale_matching(layers_with_out_ind)

    def get_layers_output_index(self, layers, matching_component_group):
        """
        Return (layer_name, resolved_output_index) for every out-edge of ``layers`` that belongs to the group.
        """
        matching_group_nodes = set()
        for matching_comp in matching_component_group:
            matching_group_nodes.update(matching_comp.nodes)
        layers_with_out_ind = []
        for layer in layers:
            for out_edge in self._model.flow.out_edges(layer):
                if out_edge in matching_group_nodes:
                    out_ind = self._model.flow.get_edge_output_index(*out_edge)
                    out_ind = self._model.layers[layer].resolve_output_index(out_ind)
                    layers_with_out_ind.append((layer, out_ind))
        return layers_with_out_ind

    def _info_matching_component_group(self, matching_component_group, string="scale_matching"):
        """Debug-log the (topologically sorted) edges of every component in the group."""
        for matching_component in matching_component_group:
            a = list(nx.lexicographical_topological_sort(matching_component))
            self._logger.debug(f"must {string} {a}")

    @classmethod
    def get_new_scale_vector(cls, zp_old, zp, scale_vector, bins):
        """Rescale ``scale_vector`` for a zero-point change from ``zp_old`` to ``zp`` (see ``_get_ratio``)."""
        ratio = cls._get_ratio(zp_old, zp, bins)
        return ratio * scale_vector

    @staticmethod
    def _get_ratio(zp_old, zp, bins):
        """
        Return the factor to multiply the old scale by when the zero point moves from ``zp_old`` to ``zp``,
        keeping one end of the represented range fixed.

        With ``rmin = -zp * scale`` and ``rmax = (bins - zp) * scale``:

        * ``zp > zp_old`` - keep rmax fixed:
            (bins - zp) * scale = (bins - zp_old) * scale_old
            ==>>  scale = scale_old * (bins - zp_old) / (bins - zp)
        * ``zp < zp_old`` - keep rmin fixed:
            zp * scale = zp_old * scale_old
            ==>>  scale = scale_old * zp_old / zp

        For 0 < zp < bins the ratio is >= 1 in both cases, i.e. the represented range only grows.

        Raises:
            MatchingAlgoError: if the required division is by zero (``zp == 0`` when keeping rmin, or
                ``zp >= bins`` when keeping rmax).
        """
        if zp == zp_old:
            return 1

        if zp < zp_old:
            if zp == 0:
                raise MatchingAlgoError(
                    f"Cannot rescale output scale to zero point {zp} (old zero point {zp_old}): "
                    "the new zero point must be non-zero",
                )
            return zp_old / zp

        if zp >= bins:
            # bins - zp would be <= 0, producing an infinite or negative scale.
            raise MatchingAlgoError(
                f"Cannot rescale output scale to zero point {zp} (old zero point {zp_old}): "
                f"the new zero point must be smaller than {bins}",
            )
        return (bins - zp_old) / (bins - zp)

    def enforce_constraints_group(self, matching_component_group, training=False):
        """Call the model's enforce_constraints on the sorted edges of each component in the group."""
        for matching_component in matching_component_group:
            edges_list = list(nx.lexicographical_topological_sort(matching_component))
            self._model.enforce_constraints(edges_list, training=training, create_ratio=True)

    @staticmethod
    def prepare_nms_scales(prev_layer):
        """
        Build the output-scale vector expected by NMS from ``prev_layer.output_scale``.

        Entries are laid out in groups of 5: 4 box entries followed by 1 score entry. The box scale
        (taken from entry 0) and the score scale (taken from entry 4) are converted from the layer's
        output bit-width to ``NMSOp.BOXES_BITS`` / ``NMSOp.SCORES_BITS`` respectively.
        Returns a float32 array with the same shape as ``prev_layer.output_scale``.
        """

        def convert_scales(old_scale, old_bits, new_bits):
            return old_scale * (2**old_bits) / (2**new_bits)

        # Get the new scales
        old_bits = prev_layer.get_output_lossy_elements()[0].bits
        score_scale = convert_scales(prev_layer.output_scale[4], old_bits, NMSOp.SCORES_BITS)
        boxes_scale = convert_scales(prev_layer.output_scale[0], old_bits, NMSOp.BOXES_BITS)

        # Prepare the indexes of boxes and scores
        proposals = prev_layer.output_scale.shape[0] // 5
        scores_idx = np.repeat([False, True], repeats=[4, 1])
        scores_idx = np.tile(scores_idx, [proposals])
        boxes_idx = np.logical_not(scores_idx)

        # Set the scale
        output_scale = np.zeros(shape=prev_layer.output_scale.shape, dtype=np.float32)
        output_scale[scores_idx] = score_scale
        output_scale[boxes_idx] = boxes_scale
        return output_scale

    def should_skip_algo(self):
        """Return None (falsy); presumably read by the base class as "do not skip" — confirm."""
        pass

    def _get_required_match_type(self, matching_component_group):
        """Decide which EncodingMatchType applies to the group (detection head > scale > zp > none)."""
        if self.must_detection_head_matching_group(matching_component_group):
            return EncodingMatchType.DETECTION_HEAD_MATCH
        if self.must_scale_matching_group(matching_component_group):
            return EncodingMatchType.SCALE_MATCH
        if self.must_zp_matching_group(matching_component_group):
            return EncodingMatchType.ZERO_POINT_MATCH
        return EncodingMatchType.NO_MATCH

    def match_components_group(self, matching_component_group, training=False):
        """
        Apply the required matching to the group, then enforce the model constraints on its edges.

        Returns: the EncodingMatchType that was applied.
        """
        match_type = self._get_required_match_type(matching_component_group)
        if match_type == EncodingMatchType.DETECTION_HEAD_MATCH:
            self.detection_head_matching_group(matching_component_group)
        elif match_type == EncodingMatchType.SCALE_MATCH:
            self.scale_matching_group(matching_component_group)
        elif match_type == EncodingMatchType.ZERO_POINT_MATCH:
            self.zp_matching_group(matching_component_group)
        else:
            self._info_matching_component_group(matching_component_group, "dont_need_matching")

        # WIP: _set_variables is currently a no-op placeholder for training.
        if training:
            self._set_variables(matching_component_group, match_type)
        self.enforce_constraints_group(matching_component_group, training=training)
        return match_type

    def update_output_scales(self):
        """Enforce IO encoding on every output layer that is not a non-NN-core layer."""
        for lname in self._model.flow.output_nodes:
            layer = self._model.layers[lname]
            if isinstance(layer, BaseHailoNonNNCoreLayer):
                continue
            layer.enforce_io_encoding()

    def _run_int(self):
        for matching_component_group in self.equiv_match.get_groups_components():
            self.match_components_group(matching_component_group)

        self.update_output_scales()

    def _setup(self):
        algo_config = self.get_algo_config()
        self.input_encoding_vector = algo_config.input_encoding_vector != FeaturePolicy.disabled
        self.output_encoding_vector = algo_config.output_encoding_vector != FeaturePolicy.disabled
        self._set_io_property()
        self.equiv_match = EquivFlow.from_hailo_model(self._model)
        self.scalar_scale_map, self.scalar_zp_map = self.define_layers()

    def define_layers(self):
        """Return (scalar_scale_map, scalar_zp_map) computed by ``_property_map``."""
        scalar_scale_map = self._property_map(IOVectorPolicy.SCALAR_INPUT_SCALE)
        scalar_zp_map = self._property_map(IOVectorPolicy.SCALAR_INPUT_ZP)

        return scalar_scale_map, scalar_zp_map

    def _set_io_property(self):
        """Force scalar encoding on input/output layers when vector encoding is disabled for them."""
        for layer_name in self._model.flow.input_nodes:
            layer = self._model.layers[layer_name]
            layer.force_scalar_encoding = not self.input_encoding_vector
        for layer_name in self._model.flow.output_nodes:
            layer = self._model.layers[layer_name]
            layer.force_scalar_encoding = not self.output_encoding_vector

    def _property_map(self, property_map):
        """
        Build a map from layer name to whether that layer's *input* scale/zp must be scalar.

        Algorithm:
        1. Initialize every layer (in topological order):
            1.1 input map - input/output layers use ``_is_scalar_encoding``; all other layers use
                ``_get_input_scalar_property``.
            1.2 output map - input/output layers use ``_is_scalar_encoding``; all other layers use
                ``_get_output_scalar_property``.
        2. Go over all layers in reversed topological order, and for each predecessor of the current layer:
            2.1 the predecessor's output property becomes True if the current layer's input property is
                True (it is never reset to False).
            2.2 if the predecessor is an inter-layers preserver (see ``_inter_layers_preserver``), its input
                property is overwritten with its updated output property, so the requirement propagates
                further upstream.

        Because the traversal is in reverse topological order, every layer has its output property updated
        by all of its successors before the layer itself is visited.

        Args:
            property_map: IOVectorPolicy.SCALAR_INPUT_SCALE / IOVectorPolicy.SCALAR_INPUT_ZP

        Returns: dict mapping layer name to its input-scalar property.

        """
        scalar_input_map = dict()
        scalar_output_map = dict()
        for layer_name in self._model.flow.toposort():
            layer = self._model.layers[layer_name]
            if layer_name in self._model.flow.output_nodes or layer_name in self._model.flow.input_nodes:
                scalar_output_map[layer_name] = self._is_scalar_encoding(layer, property_map)
                scalar_input_map[layer_name] = self._is_scalar_encoding(layer, property_map)
            else:
                scalar_output_map[layer_name] = self._get_output_scalar_property(layer, property_map)
                scalar_input_map[layer_name] = self._get_input_scalar_property(layer, property_map)
        layers_toposort = reversed(list(self._model.flow.toposort()))
        for layer_name in layers_toposort:
            for predecessor_name in self._model.flow.predecessors_sorted(layer_name):
                # Based on the current layer's *input* property, update the predecessor's *output* property
                # (once True it stays True).
                scalar_output_map[predecessor_name] = (
                    scalar_input_map[layer_name] or scalar_output_map[predecessor_name]
                )
                predecessor = self._model.layers[predecessor_name]

                # Preservers must propagate their output requirement to their input.
                if self._inter_layers_preserver(predecessor):
                    scalar_input_map[predecessor_name] = scalar_output_map[predecessor_name]
        return scalar_input_map

    @classmethod
    def _is_scalar_encoding(cls, layer, property_map):
        """
        Return the scalar property of an input/output layer from its encoding:
        1. scale - based on layer.force_scalar_encoding
        2. zp - currently always True (vector zp is not supported yet)

        Non-NN-core layers are always scalar.

        Raises:
            MatchingAlgoError: if ``property_map`` is not a known IOVectorPolicy.
        """
        if isinstance(layer, BaseHailoNonNNCoreLayer):
            return True

        if property_map == IOVectorPolicy.SCALAR_INPUT_SCALE:
            return layer.force_scalar_encoding

        if property_map == IOVectorPolicy.SCALAR_INPUT_ZP:
            # for now we still dont support zp to be a vector- will be changed in near future
            return True

        raise MatchingAlgoError("IOVectorPolicy is not defined")

    @classmethod
    def _get_input_scalar_property(cls, layer, property_map):
        """
        Return whether the layer's input scale/zp should be scalar.

        Raises:
            MatchingAlgoError: if ``property_map`` is not a known IOVectorPolicy.
        """
        if isinstance(layer, BaseHailoNonNNCoreLayer):
            return True
        if property_map == IOVectorPolicy.SCALAR_INPUT_SCALE:
            # Layers that do not consume input scales must receive a scalar input scale
            # (preservers may later be updated by _property_map).
            return not layer.consumer_input_scale

        if property_map == IOVectorPolicy.SCALAR_INPUT_ZP:
            # now all layers must have input_zp but in the future it may change
            return True

        raise MatchingAlgoError("IOVectorPolicy is not defined")

    @classmethod
    def _get_output_scalar_property(cls, layer, property_map):
        """
        Return whether the layer's output scale/zp should be scalar.

        Raises:
            MatchingAlgoError: if ``property_map`` is not a known IOVectorPolicy.
        """
        if isinstance(layer, BaseHailoNonNNCoreLayer):
            return True

        if property_map == IOVectorPolicy.SCALAR_INPUT_SCALE:
            # Non-homogeneous layers must produce a scalar output scale.
            return not layer.homogeneous

        if property_map == IOVectorPolicy.SCALAR_INPUT_ZP:
            # now all layers must have input_zp but in the future it may change
            return True

        raise MatchingAlgoError("IOVectorPolicy is not defined")

    @staticmethod
    def _inter_layers_preserver(layer):
        """
        Return True if the layer's output property must be propagated to its input property, i.e. it does
        not consume input scales and is homogeneous. Non-NN-core layers are never preservers.
        """
        if isinstance(layer, BaseHailoNonNNCoreLayer):
            return False
        return (not layer.consumer_input_scale) and layer.homogeneous

    def log_config(self):
        pass

    def get_algo_config(self):
        return self._model_config.globals

    def _set_variables(self, comp, match_type: EncodingMatchType = EncodingMatchType.NO_MATCH):
        """
        Work-in-progress placeholder for training-time handling of the matched encodings; currently a no-op.

        TODO: for training, turn the output scales of the component's source layers into trainable
        variables (``add_weight(name="output_scale", trainable=True, ...)``), sharing one initial value
        across the sources when zp or scale matching was enforced.

        Args:
            comp: the matching component group.
            match_type: the EncodingMatchType that was applied to the group.
        """

    def finalize_global_cfg(self, algo_config):
        pass
