from hailo_model_optimization.acceleras.atomic_ops.conv_stripped_op import ConvStrippedOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_conv import BaseHailoConv
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_none_nn_core_layer import BaseHailoNonNNCoreLayer
from hailo_model_optimization.acceleras.utils.acceleras_definitions import ActivationType, EmulationType

# Activations treated as "discrete": a conv layer is switched to float64 when its
# own activation, or one of its considered successors' activations, is listed here.
DISCRETE_ACTIVATIONS = (
    ActivationType.BIASED_DELTA,
    ActivationType.DELTA,
)

# Atomic op types whose emulation type is switched to EmulationType.DOUBLE.
SUPPORTED_OPS = (
    ConvStrippedOp,
)


class SetFloat64:
    """
    Move convolutions associated with discrete activations to float64 computation.

    The algorithm scans the model for conv layers (``BaseHailoConv``) associated with a discrete activation
    (one of ``DISCRETE_ACTIVATIONS``). It then switches those layers' supported atomic ops (``SUPPORTED_OPS``) to
    ``EmulationType.DOUBLE`` so the computation runs in higher precision.

    Note that this does not affect the layer precision mode (a8_w8_a8, a16_w16_a8, etc.). It only makes the
    convolution be computed with float64 instead of float32.
    """

    @classmethod
    def correct_model(cls, model):
        """
        Apply double-precision emulation to every layer selected by ``get_layers_to_correct``.

        The model is modified in place. Each atomic op of a selected layer that is an instance of
        ``SUPPORTED_OPS`` gets ``set_type_emulation(EmulationType.DOUBLE)``; other ops are left untouched.

        Args:
            model: Model exposing ``layers`` (indexed by layer name) and ``flow``.
        """
        # The full selection is computed before any op is modified.
        names_to_fix = cls.get_layers_to_correct(model)
        for name in names_to_fix:
            matching_ops = (
                candidate_op
                for candidate_op in model.layers[name].atomic_ops
                if isinstance(candidate_op, SUPPORTED_OPS)
            )
            for atomic_op in matching_ops:
                atomic_op.set_type_emulation(EmulationType.DOUBLE)

    @classmethod
    def get_layers_to_correct(cls, model):
        """
        Find the conv layers whose convolution should be computed in float64.

        A layer is selected when it is a ``BaseHailoConv`` and any of these activations is in
        ``DISCRETE_ACTIVATIONS``:

        - the layer's own activation;
        - the activation of any successor that is not a ``BaseHailoNonNNCoreLayer``.

        ``None`` activations never match.

        Args:
            model: Model exposing ``flow.toposort()``, ``flow.successors_sorted(name)`` and ``layers[name]``.

        Returns:
            list: Names of the selected layers, in ``model.flow.toposort()`` order.
        """
        selected = []
        for name in model.flow.toposort():
            current = model.layers[name]
            if isinstance(current, BaseHailoConv):
                # Successors that are BaseHailoNonNNCoreLayer instances are not considered.
                considered_successors = []
                for succ_name in model.flow.successors_sorted(name):
                    succ_layer = model.layers[succ_name]
                    if not isinstance(succ_layer, BaseHailoNonNNCoreLayer):
                        considered_successors.append(succ_layer)

                # Own activation first, then the successors' activations in sorted order.
                activations = [current.get_activation_name()]
                activations.extend(succ.get_activation_name() for succ in considered_successors)

                if any(act in DISCRETE_ACTIVATIONS for act in activations if act is not None):
                    selected.append(name)
        return selected
