from hailo_model_optimization.acceleras.atomic_ops.base_non_arithmetic_op import BaseNonArithmeticAtomicOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_none_nn_core_layer import BaseHailoNonNNCoreLayer
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_single_atomic import BaseHailoSingleAtomic
from hailo_model_optimization.acceleras.hailo_layers.hailo_conv_a16_pre_act_sum import HailoConvA16PreActSum
from hailo_model_optimization.acceleras.hailo_layers.hailo_conv_a16_quant_weight_group import (
    HailoConvA16W8QuantWeightGroup,
)
from hailo_model_optimization.acceleras.hailo_layers.hailo_precision_split import HailoPrecisionSplitPixels
from hailo_model_optimization.acceleras.hailo_layers.hailo_reduce_sum_a16_pre_act_sum import HailoReduceSumA16PreActSum
from hailo_model_optimization.acceleras.utils.acceleras_definitions import ZP_LOW_SPLIT_PRECISION_PIXEL


class SetOutputSplitPrecisionZP:
    """
    Set the split-precision output zero point on the layers that feed split-precision inputs.

    The algorithm searches for layers with split-precision input and sets the output split-precision zero point of
    their predecessors accordingly. Predecessors that are single-atomic layers wrapping a non-arithmetic atomic op are
    climbed through, so the zero point is set on the first layers upstream that are not such layers.
    """

    @classmethod
    def correct_model(cls, model):
        """
        Set ``output_split_precision_zp`` to ``ZP_LOW_SPLIT_PRECISION_PIXEL`` on every layer to correct.

        The layers are the ones returned by :meth:`get_layers_to_correct`. Layers that are ``BaseHailoNonNNCoreLayer``
        instances are skipped and left unchanged.

        Args:
            model: model exposing ``layers`` (layer name -> layer) and ``flow`` (the layer graph).
        """
        for lname in cls.get_layers_to_correct(model):
            layer = model.layers[lname]
            if isinstance(layer, BaseHailoNonNNCoreLayer):
                continue
            layer.output_split_precision_zp = ZP_LOW_SPLIT_PRECISION_PIXEL

    @classmethod
    def get_layers_to_correct(cls, model):
        """
        Collect the names of the layers whose output split-precision zero point should be set.

        Every layer in ``model.flow`` is checked (in topological order). For each layer with split-precision input, its
        first predecessor (in ``predecessors_sorted`` order) is climbed through with
        :meth:`climb_until_no_single_atomic`, and the resulting layer names are collected.

        Args:
            model: model exposing ``layers`` (layer name -> layer) and ``flow`` (the layer graph).

        Returns:
            set: names of the layers to correct.
        """
        # Layer types whose input is consumed in split precision.
        split_precision_input_types = (
            HailoConvA16PreActSum,
            HailoPrecisionSplitPixels,
            HailoReduceSumA16PreActSum,
            HailoConvA16W8QuantWeightGroup,
        )
        layers_to_correct = set()
        for lname in model.flow.toposort():
            if not isinstance(model.layers[lname], split_precision_input_types):
                continue
            # Only the first sorted predecessor is considered -- presumably the data input carrying the
            # split-precision tensor. TODO(review): confirm this holds for multi-input layers.
            candidate_lname = model.flow.predecessors_sorted(lname)[0]
            layers_to_correct.update(cls.climb_until_no_single_atomic(model, candidate_lname))
        return layers_to_correct

    @classmethod
    def climb_until_no_single_atomic(cls, model, lname):
        """
        Walk upstream from ``lname`` past single-atomic layers that wrap a non-arithmetic atomic op.

        Starting at ``lname``, any ``BaseHailoSingleAtomic`` layer whose ``atomic_op`` is a
        ``BaseNonArithmeticAtomicOp`` is replaced by its predecessors; every other layer reached is collected. A
        climbed-through layer with no predecessors contributes nothing. Each layer is visited at most once, so when
        several paths reconverge on the same layer it is neither re-traversed nor reported twice.

        Args:
            model: model exposing ``layers`` (layer name -> layer) and ``flow`` (the layer graph).
            lname: name of the layer to start climbing from.

        Returns:
            list: names of the layers reached, without duplicates, in discovery order.
        """
        to_visit = [lname]
        visited = set()
        layer_names_to_correct = []
        while to_visit:
            current = to_visit.pop()
            # Without this guard, reconverging paths would re-walk the same upstream sub-graph repeatedly.
            if current in visited:
                continue
            visited.add(current)
            layer = model.layers[current]
            if isinstance(layer, BaseHailoSingleAtomic) and isinstance(layer.atomic_op, BaseNonArithmeticAtomicOp):
                # Non-arithmetic op: keep climbing to the layers that feed it.
                to_visit.extend(model.flow.predecessors_sorted(current))
            else:
                layer_names_to_correct.append(current)
        return layer_names_to_correct
