from enum import Enum

import numpy as np

from hailo_model_optimization.acceleras.atomic_ops.bias_add_op import AddBiasOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_conv import BaseHailoConv
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_none_nn_core_layer import BaseHailoNonNNCoreLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_const import HailoConst
from hailo_model_optimization.acceleras.hailo_layers.hailo_element_wise_add import HailoElementwiseAdd
from hailo_model_optimization.acceleras.hailo_layers.hailo_element_wise_mult import HailoElementwiseMult
from hailo_model_optimization.acceleras.hailo_layers.hailo_element_wise_mult_on_mac import HailoElementwiseMultOnMac
from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.model.hailo_model.model_flow import ModelFlow
from hailo_model_optimization.acceleras.utils.acceleras_definitions import DeadLayersRemovalPolicy, LayerType
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasException
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm

# Layer classes whose kernel is inspected when looking for dead layers.
# Every class listed here must expose ``conv_op.kernel`` (read in DeadLayersRemoval.get_removal_action).
SUPPORTED_LAYERS = [
    BaseHailoConv,
]


class RemovalAction(Enum):
    """How DeadLayersRemoval handles a given layer."""

    # Drop the layer and propagate an all-zero input to its successors.
    remove = "remove"
    # Replace the layer by a constant layer holding activation(bias).
    const = "const"
    # Element-wise mult fed by zeros: also drop its other inputs, then behave like ``remove``.
    remove_pred = "remove_pred"
    # Element-wise add fed by zeros: drop only the add layer itself.
    remove_add = "remove_add"
    # Keep the layer untouched.
    do_nothing = "do_nothing"


class DeadLayersRemoval(OptimizationAlgorithm):
    """
    Remove layers that have no effect on the network output.

    A layer of one of the SUPPORTED_LAYERS types is considered dead when the maximal absolute
    value of its kernel is below ``dead_layers_removal.threshold``. Such a layer is handled in
    one of two ways:

    * If its bias (after the activation) is also below the threshold, its output is treated as
      all-zeros. The layer is removed together with the predecessor chain that only feeds it.
      The zero input is then propagated to its successors, which are removed as well:

      - element-wise mult layers also drop their other inputs;
      - element-wise add layers are removed without propagating further.

    * Otherwise, the layer is replaced by a HailoConst layer holding activation(bias).

    A pending removal that would touch a model input or output node is skipped entirely.
    When ``dead_layers_removal.validate_change`` is enabled, the model is run on the same random
    input before and after the change. An AccelerasException is raised if the outputs differ.
    """

    def __init__(
        self,
        model: HailoModel,
        model_config,
        logger_level,
        **kwargs,
    ):
        super().__init__(model, model_config, logger_level=logger_level, name="Dead Layers Removal", **kwargs)
        # Layers (and the edges between them) scheduled for removal in the current step.
        self._flow_to_remove = ModelFlow()
        # Layers scheduled to be replaced by a HailoConst layer in the current step.
        self._layer_to_move_to_const = []

    def _setup(self):
        pass

    def should_skip_algo(self):
        """
        Return True when the algorithm is disabled or no layer needs to be removed or replaced.
        """
        if self._model_config.dead_layers_removal.policy == DeadLayersRemovalPolicy.disabled:
            return True
        return not self._dead_layer_candidates()

    def get_algo_config(self):
        return self._model_config.globals

    def _run_int(self):
        """
        Remove or replace every dead layer, optionally validating that the model output is unchanged.

        Raises:
            AccelerasException: if validation is enabled and the outputs before and after the
                change are not close.
        """
        validate = self._model_config.dead_layers_removal.validate_change != DeadLayersRemovalPolicy.disabled
        if validate:
            ref_input, ref_output = self._infer_model_random_data()
        for layer in self._dead_layer_candidates():
            # The layer may already have been removed while handling a previous candidate.
            if layer.full_name not in self._model.layers.keys():
                continue
            # Both calls are always made; each one fills the pending removal state.
            remove_layer_predecessors = self._remove_layer_predecessors(layer, source_layer=layer)
            remove_layer_successors = self._remove_layer_successors(layer)
            if remove_layer_predecessors and remove_layer_successors:
                self._remove_layers()
            else:
                self._reset_pending_changes()
                self._logger.warning(
                    f"{layer.full_name} has almost zero weights and should be removed. "
                    "The automatic removal failed. Try to remove the layer manually, "
                    "because it might fail the optimization",
                )
        if validate:
            # Reuse the reference input so that both runs are comparable.
            _, new_output = self._infer_model_random_data(ref_input)
            if not self._check_similar(ref_output, new_output):
                raise AccelerasException("The model output is not the same after the optimization")

    def _dead_layer_candidates(self):
        """
        Return ``self.layers_to_remove + self.layers_to_const``.

        The removal action is evaluated only once per layer, instead of once per property.
        """
        to_remove = []
        to_const = []
        for layer in self._model.layers.values():
            action = self.get_removal_action(layer)
            if action == RemovalAction.remove:
                to_remove.append(layer)
            elif action == RemovalAction.const:
                to_const.append(layer)
        return to_remove + to_const

    def _reset_pending_changes(self):
        """Drop every removal and const replacement scheduled in the current step."""
        self._flow_to_remove = ModelFlow()
        self._layer_to_move_to_const = []

    def _is_io_node(self, lname):
        """Return True when ``lname`` is one of the model's input or output nodes."""
        return lname in self._model.flow.input_nodes or lname in self._model.flow.output_nodes

    def _remove_layers(self):
        """
        Apply the pending changes, then clear them.

        The layers of ``self._flow_to_remove`` are removed in topological order. The layers of
        ``self._layer_to_move_to_const`` are replaced by const layers.

        If any pending layer that is still in the model is a model input or output node, no
        layer is removed at all. The model is never left partially edited. The const
        replacements are still applied in that case.
        """
        removal_order = list(self._flow_to_remove.toposort())
        touches_io = any(
            lname in self._model.layers.keys() and self._is_io_node(lname) for lname in removal_order
        )
        if not touches_io:
            const_to_remove = []
            for lname in removal_order:
                if lname not in self._model.layers.keys():
                    continue
                layer = self._model.layers[lname]
                if isinstance(layer, HailoConst):
                    # Const layers are removed only after the loop below finishes.
                    const_to_remove.append(lname)
                else:
                    # connect_succ_and_pred is requested only for multi-input layers (e.g. element-wise add).
                    self._model.remove_layer(layer, connect_succ_and_pred=layer.num_inputs > 1)
                self._model_config.remove_layer_from_all_configs(lname)
                self._logger.info(f"Remove layer {lname} because it has no effect on the network output")
            for lname in const_to_remove:
                self._model.remove_layer(self._model.layers[lname])
        for layer in self._layer_to_move_to_const:
            if layer.full_name in self._model.layers.keys():
                self._add_const_layer(layer)

        self._reset_pending_changes()

    def _infer_model_random_data(self, random_input=None):
        """
        Run ``self._model`` on random data.

        Args:
            random_input: list of input arrays to reuse. When None, uniform [0, 1) arrays are drawn
                with batch size 1 and the shape of each model input node.

        Returns:
            tuple: ``(random_input, output)``.
        """
        if random_input is None:
            input_shapes = [[1, *self._model.layers[name].input_shape[1:]] for name in self._model.flow.input_nodes]
            random_input = [np.random.rand(*input_shape) for input_shape in input_shapes]
        output = self._model(random_input)
        return random_input, output

    def _check_similar(self, ref_output, new_output):
        """
        Return True when both outputs are similar.

        Similar means they have the same number of entries and every pair of entries is
        element-wise close (``np.allclose`` with ``atol=1e-5``).
        """
        ref_output = list(ref_output)
        new_output = list(new_output)
        # zip() would silently ignore extra or missing outputs.
        if len(ref_output) != len(new_output):
            return False
        return all(np.allclose(ref, new, atol=1e-5) for ref, new in zip(ref_output, new_output))

    def _remove_layer_predecessors(self, layer, source_layer):
        """
        Schedule the predecessor chains of ``layer`` for removal, up to the first split.

        Returns:
            bool: False as soon as one of the chains cannot be removed, True otherwise.
        """
        preds = self._model.flow.predecessors_sorted(layer.full_name)
        for pred in preds:
            pred_layer = self._model.layers[pred]
            if not self._remove_layers_chain(pred_layer, source_layer):
                return False
        return True

    def _remove_layer_successors(self, layer, input_is_zero=False):
        """
        Schedule ``layer`` and, recursively, the successors affected by its zero output.

        Args:
            layer: the layer to handle.
            input_is_zero: True when ``layer`` receives the all-zero output of a removed layer.

        Returns:
            bool: False when ``layer`` is a non NN-core layer, True otherwise.
        """
        if isinstance(layer, BaseHailoNonNNCoreLayer):
            return False
        action = self.get_removal_action(layer, input_is_zero)
        if action == RemovalAction.remove_pred:
            # Multiplying by zeros makes the other inputs useless; schedule them for removal too.
            # NOTE(review): the result of this call is ignored.
            self._remove_layer_predecessors(layer, layer)
        if action == RemovalAction.remove_add:
            # Only the add layer itself is scheduled; nothing is propagated further.
            self._add_layer_to_remove(layer)
            return True
        if action == RemovalAction.const:
            self._layer_to_move_to_const.append(layer)
            return True
        if action in (RemovalAction.remove, RemovalAction.remove_pred):
            successors = self._model.flow.successors_sorted(layer.full_name)
            self._add_layer_to_remove(layer)
            for succ in successors:
                # NOTE(review): failures reported by successors are ignored here.
                self._remove_layer_successors(self._model.layers[succ], input_is_zero=True)
            return True
        return True

    def _add_layer_to_remove(self, layer):
        """
        Add ``layer`` to the removal flow.

        Every edge of the model flow between ``layer`` and a node already in the removal flow is
        copied to it as well.
        """
        self._flow_to_remove.add_node(layer.full_name)
        for node in self._flow_to_remove.nodes:
            if self._model.flow.has_edge(node, layer.full_name):
                self._flow_to_remove.add_edge(node, layer.full_name)
            if self._model.flow.has_edge(layer.full_name, node):
                self._flow_to_remove.add_edge(layer.full_name, node)

    def _remove_layers_chain(self, layer, source_layer):
        """
        Schedule ``layer`` and its predecessors for removal, walking back until a split.

        A split is a layer with a successor that is neither ``source_layer`` nor already
        scheduled.

        Returns:
            bool: whether the chain can be removed. At a split this is True only when the split
            layer has a single output.
        """
        other_successors = [
            suc
            for suc in self._model.flow.successors_sorted(layer.full_name)
            if suc not in self._flow_to_remove.nodes and suc != source_layer.full_name
        ]
        if other_successors:
            return layer.num_outputs == 1
        preds = self._model.flow.predecessors_sorted(layer.full_name)
        self._add_layer_to_remove(layer)
        for pred in preds:
            pred_layer = self._model.layers[pred]
            if not self._remove_layers_chain(pred_layer, source_layer):
                return False
        return True

    def _add_const_layer(self, layer):
        """
        Replace ``layer`` by a HailoConst layer and move the layer's precision config to it.

        The const data is activation(bias), broadcast over the layer's output shape
        (batch dimension excluded).
        """
        const_suffix = "_const_replacement"
        bias = layer.bias_add_op.bias.numpy()
        # With a (near) zero kernel the layer output is activation(bias) everywhere.
        weights = np.ones(shape=layer.output_shape[1:]) * bias
        weights = layer.act_op(weights)
        shape = [-1, *weights.shape]
        name = layer.full_name + const_suffix
        self._logger.info(f"Layer {layer.full_name} has near zeros weights. Replacing with {name}")
        successors = list(self._model.flow.successors_sorted(layer.full_name))
        hn = {
            "type": LayerType.CONST_INPUT.value,
            "input": [],
            "output": successors,
            "input_shapes": [shape],
            "output_shapes": [shape for _ in successors],
        }
        new_const_layer = HailoConst.from_hn(name, hn, self._logger)
        new_const_layer.import_weights({"const_data": weights})
        self._model.replace_layer(new_const_layer, layer, use_new_name=True)
        precision_config = self._model_config.precision_config.layers[layer.full_name]
        self._model_config.precision_config.layers[new_const_layer.full_name] = precision_config
        self._model_config.remove_layer_from_all_configs(layer.full_name)
        # A const layer takes no input: drop any edge still feeding it.
        for pred in self._model.flow.predecessors_sorted(new_const_layer.full_name):
            self._model.flow.remove_edge(pred, new_const_layer.full_name)
        self._model.layers[name].import_precision_config(precision_config, self.optimization_target)

    @property
    def layers_to_remove(self):
        """Layers whose removal action (with a non-zero input) is ``RemovalAction.remove``."""
        return [
            layer for layer in self._model.layers.values() if self.get_removal_action(layer) == RemovalAction.remove
        ]

    @property
    def layers_to_const(self):
        """Layers whose removal action (with a non-zero input) is ``RemovalAction.const``."""
        return [layer for layer in self._model.layers.values() if self.get_removal_action(layer) == RemovalAction.const]

    def get_removal_action(self, layer, input_is_zero=False) -> RemovalAction:
        """
        Decide how ``layer`` should be handled.

        Args:
            layer: the layer to classify.
            input_is_zero: True when the layer receives the all-zero output of a removed layer.

        Returns:
            RemovalAction: one of the following.

            * ``remove_pred`` / ``remove_add``: element-wise mult / add layers fed by zeros.
            * ``do_nothing``: any other multi-input layer.
            * ``const`` / ``remove``: single-input layers fed by zeros, or supported layers whose
              kernel is below the threshold. ``const`` is chosen when the activated bias is
              significant.
            * ``do_nothing``: everything else.
        """
        if isinstance(layer, (HailoElementwiseMult, HailoElementwiseMultOnMac)) and input_is_zero:
            return RemovalAction.remove_pred
        if layer.__class__ is HailoElementwiseAdd and input_is_zero:
            return RemovalAction.remove_add
        if layer.num_inputs > 1:
            return RemovalAction.do_nothing
        if input_is_zero:
            return RemovalAction.const if self._should_be_const(layer) else RemovalAction.remove
        for ltype in SUPPORTED_LAYERS:
            if isinstance(layer, ltype):
                kernel_max = np.max(np.abs(layer.conv_op.kernel.numpy()))
                if kernel_max < self._model_config.dead_layers_removal.threshold:
                    return RemovalAction.const if self._should_be_const(layer) else RemovalAction.remove
        return RemovalAction.do_nothing

    def _should_be_const(self, layer):
        """
        Return True when the layer should be replaced by a const layer.

        That is the case when the layer has an AddBiasOp and max(|activation(bias)|) exceeds
        ``dead_layers_removal.threshold``. The activation is applied only if the layer has one.
        """
        bias = None
        for op in layer.atomic_ops:
            if isinstance(op, AddBiasOp):
                bias = op.bias.numpy()
                break
        if bias is None:
            return False
        if layer.activation_atomic_op is not None:
            bias = layer.activation_atomic_op(bias)
        return np.max(np.abs(bias)) > self._model_config.dead_layers_removal.threshold

    def finalize_global_cfg(self, algo_config):
        pass
