from hailo_model_optimization.acceleras.atomic_ops.base_non_arithmetic_op import BaseNonArithmeticAtomicOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_none_nn_core_layer import BaseHailoNonNNCoreLayer
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_single_atomic import BaseHailoSingleAtomic
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerPrecisionConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import DataPath


class SetSignedOutput:
    """
    Propagate signed data paths through the model graph.

    The algorithm searches for layers with a signed input (``DataPath.LAYER_IN_WEIGHTS``) and sets the
    output data path of their predecessors to signed (``DataPath.LAYER_OUT_WEIGHTS``) as well.
    Non-arithmetic single-atomic layers are climbed through to reach the first other producer. When such
    a layer receives a signed input, the layer itself and all producers feeding it are corrected too.
    """

    @staticmethod
    def _is_non_arithmetic_single_atomic(layer):
        """Return True if ``layer`` is a single-atomic layer whose atomic op is non-arithmetic."""
        return isinstance(layer, BaseHailoSingleAtomic) and isinstance(layer.atomic_op, BaseNonArithmeticAtomicOp)

    @staticmethod
    def _reimport_precision_config(layer, optimization_target):
        """
        Re-import the layer's current precision config, if it has one.

        Presumably this rebuilds the layer's precision setup against its updated data paths; confirm
        against ``import_precision_config``.
        """
        layer_config = layer.get_layer_precision_config()
        if layer_config is not None:
            layer.import_precision_config(LayerPrecisionConfig(**layer_config), optimization_target)

    @classmethod
    def correct_model(cls, model):
        """
        Make every layer that feeds a signed input produce a signed output, in place.

        Each layer returned by ``get_layers_to_correct`` gets a ``LAYER_OUT_WEIGHTS`` output data path.
        Every input of its successors that it feeds gets ``LAYER_IN_WEIGHTS``. The precision config of
        each touched layer is re-imported. A successor that is a non-arithmetic single-atomic layer is
        itself queued for correction, together with all producers feeding it.

        Args:
            model: the model to correct. Its ``layers`` mapping, ``flow`` graph and
                ``optimization_target`` are used.

        Raises:
            ValueError: if the flow lists a successor that does not list the layer among its predecessors.
        """
        layers_to_correct = cls.get_layers_to_correct(model)
        handled_layers = set()
        optimization_target = model.optimization_target
        while layers_to_correct:
            lname = layers_to_correct.pop()
            handled_layers.add(lname)
            layer = model.layers[lname]
            layer.set_output_data_path(DataPath.LAYER_OUT_WEIGHTS)
            cls._reimport_precision_config(layer, optimization_target)
            for succ_name in model.flow.successors_sorted(lname):
                succ = model.layers[succ_name]
                # The same producer may feed several inputs of one successor; every such input now
                # receives the signed tensor, not only the first one.
                input_indices = [
                    index
                    for index, pred_name in enumerate(model.flow.predecessors_sorted(succ_name))
                    if pred_name == lname
                ]
                if not input_indices:
                    raise ValueError(f"{lname} is listed as a predecessor of {succ_name} but is not one of its inputs")
                for input_index in input_indices:
                    succ.set_input_data_path(DataPath.LAYER_IN_WEIGHTS, input_index)
                cls._reimport_precision_config(succ, optimization_target)
                if cls._is_non_arithmetic_single_atomic(succ):
                    # Correct the non-arithmetic layer's own output and every producer feeding it,
                    # skipping layers that were already processed.
                    new_layers_to_correct = [succ_name] + cls.climb_until_no_single_atomic(model, succ_name)
                    layers_to_correct.update(
                        candidate for candidate in new_layers_to_correct if candidate not in handled_layers
                    )

    @classmethod
    def get_layers_to_correct(cls, model):
        """
        Collect the layers whose output must become signed.

        Every layer that is not a ``BaseHailoNonNNCoreLayer`` is scanned. Each of its inputs with a
        ``DataPath.LAYER_IN_WEIGHTS`` data path marks the corresponding predecessor for correction.
        Predecessors that are non-arithmetic single-atomic layers are climbed through
        (see ``climb_until_no_single_atomic``).

        Args:
            model: the model whose ``layers`` and ``flow`` are inspected.

        Returns:
            set: names of the layers to correct.
        """
        layers_to_correct = set()
        for lname in model.flow.toposort():
            layer = model.layers[lname]
            if isinstance(layer, BaseHailoNonNNCoreLayer):
                continue
            inputs_data_path = layer.get_inputs_data_path()
            if DataPath.LAYER_IN_WEIGHTS not in inputs_data_path:
                continue
            predecessors = model.flow.predecessors_sorted(lname)
            # Handle every signed input. list.index would only find the first one.
            for index, data_path in enumerate(inputs_data_path):
                if data_path == DataPath.LAYER_IN_WEIGHTS:
                    layers_to_correct.update(cls.climb_until_no_single_atomic(model, predecessors[index]))
        return layers_to_correct

    @classmethod
    def climb_until_no_single_atomic(cls, model, lname):
        """
        Walk up the graph from ``lname`` through non-arithmetic single-atomic layers.

        Args:
            model: the model whose ``layers`` and ``flow`` are inspected.
            lname: name of the layer to start from.

        Returns:
            list: names of the layers reached that are not non-arithmetic single-atomic layers.
            This is ``[lname]`` when ``lname`` is not such a layer. Each name appears at most once.
        """
        pending = [lname]
        visited = set()
        layer_names_to_correct = []
        while pending:
            current = pending.pop()
            # With several paths back to the same ancestor (diamonds), walk it only once.
            if current in visited:
                continue
            visited.add(current)
            if cls._is_non_arithmetic_single_atomic(model.layers[current]):
                pending.extend(model.flow.predecessors_sorted(current))
            else:
                layer_names_to_correct.append(current)
        return layer_names_to_correct
