#!/usr/bin/env python
from collections import OrderedDict

import numpy as np
import scipy

from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    DeadChannelsRemovalPolicy,
    LayerHandlerType,
    QuantizationAlgorithms,
)
from hailo_sdk_client.numeric_translator.quantization_tools import get_weights
from hailo_sdk_client.post_fuser.algorithms.exceptions import DeadChannelsRemovalException, IllegalRemovalAction
from hailo_sdk_client.post_fuser.algorithms.fuser_algorithm import FuserAlgorithm
from hailo_sdk_common.hailo_nn.hn_definitions import ActivationType, LayerType


class DeadChannelsRemoval(FuserAlgorithm):
    """Post-fuser algorithm that deletes "dead" channels from the model.

    A channel is treated as dead when every kernel weight of that channel is zero and its bias, after the layer
    activation, is zero. Work is done per equivalence set from ``self._model.iter_equiv_sets``: output channels are
    deleted from the set's source layer and its element-wise bouncers, the matching input channels are deleted from
    consumers and featurewise layers, and feature slices in between are re-indexed. Both ``self._params`` and the HN
    node shapes are updated.
    """

    NAME = "dead_channels_removal"

    def __init__(self, model, params, model_config, hw_arch, **kwargs):
        super().__init__(model, params, model_config, hw_arch, **kwargs)
        # Per-run bookkeeping; both are (re)initialized in ``_setup``.
        self._real_dead_channels_by_layer = None
        self._indices_to_remove = None

    def get_algo_config(self):
        """Return the ``dead_channels_removal`` section of the model config."""
        return self._model_config.dead_channels_removal

    def _setup(self):
        """Reset the per-run bookkeeping.

        ``_real_dead_channels_by_layer`` maps a layer name to the indices of every dead output channel found in it.
        ``_indices_to_remove`` maps each removal kind ("input", "output", "special_treatment") to an ordered mapping
        from layer to the array of channel indices that will be deleted from that layer.
        """
        self._real_dead_channels_by_layer = {}
        self._indices_to_remove = {
            "input": OrderedDict(),
            "output": OrderedDict(),
            "special_treatment": OrderedDict(),
        }

    def _run_int(self):
        """Find the dead channels of the whole graph, remove them, then recompute the layer shapes."""
        self._model.update_input_lists()
        self._find_channels_to_remove_from_graph()
        self._remove_channels_from_graph()
        self._model.calculate_shapes()

    def should_skip_algo(self):
        """Skip the algorithm when its configured policy is ``disabled``."""
        return self.get_algo_config().policy == DeadChannelsRemovalPolicy.disabled

    def _find_channels_to_remove_from_graph(self):
        """Collect, for every supported equivalence set, the channel indices that should be removed.

        Sets with a non-empty ``outputs`` or ``unsupported`` list, or containing a slice whose step is not 1, are
        skipped (logged at debug level).
        """
        for equiv_set in self._model.iter_equiv_sets(QuantizationAlgorithms.dead_channels_removal):
            has_unsupported_slice = self._has_slice_with_stride(equiv_set.transparents)
            if equiv_set.outputs or equiv_set.unsupported or has_unsupported_slice:
                self._logger.debug(
                    f"Can't remove dead channels from equiv of {equiv_set.source.layer.name} because it "
                    f"has unsupported layers",
                )
                continue
            self._find_channels_to_remove_from_equiv(equiv_set)

    def _remove_channels_from_graph(self):
        """Apply the removals collected in ``self._indices_to_remove`` to the params and the HN nodes.

        Output channels are removed first, then input channels (consumers keep their bias, featurewise layers have
        their bias trimmed as well), then slices are re-indexed.

        Raises:
            DeadChannelsRemovalException: for an input-removal layer with an unexpected handler type, or a
                special-treatment layer that is not a slice.
        """
        for layer, indices_to_remove in self._indices_to_remove["output"].items():
            self._logger.debug(f"Removing {len(indices_to_remove)} channels from layer {layer.name}")
            self._change_feature_out(layer, indices_to_remove)
        for layer, indices_to_remove in self._indices_to_remove["input"].items():
            handler_type = layer.get_dead_channels_removal_handler_type().handler_type
            if handler_type == LayerHandlerType.consumer:
                self._change_feature_in(layer, indices_to_remove, remove_bias=False)
            elif handler_type == LayerHandlerType.featurewise:
                self._change_feature_in(layer, indices_to_remove, remove_bias=True)
            else:
                raise DeadChannelsRemovalException(
                    f"Unexpected layer type for input channels removal {layer.op.value},"
                    f" handler_type {handler_type}",
                )
        for layer, indices_to_remove in self._indices_to_remove["special_treatment"].items():
            if layer.op == LayerType.slice:
                self._change_hn_node_slice(layer, indices_to_remove)
            else:
                raise DeadChannelsRemovalException(
                    f"Unexpected layer type for special_treatment channels removal {layer.op.value}",
                )

    @staticmethod
    def _has_slice_with_stride(transparents):
        """Return True if any transparent layer is a slice whose ``features_slice[2]`` (the step) is not 1."""
        return any(
            transparent.layer.op == LayerType.slice and transparent.layer.features_slice[2] != 1
            for transparent in transparents
        )

    def _change_hn_node_slice(self, layer, indices_to_remove):
        """Re-index a feature slice after channels were removed from its input.

        ``features_slice[0]`` and ``features_slice[1]`` act as the [start, end) channel bounds. The start moves
        back by the number of removed channels below it. The end moves back by the number removed below the end.
        The slice output shrinks by the number of removed channels inside the range.

        Raises:
            DeadChannelsRemovalException: if the slice step (``features_slice[2]``) is not 1.
        """
        if layer.features_slice[2] != 1:
            raise DeadChannelsRemovalException(
                f"Slice {layer.name} with stride is not supported with dead channels removal",
            )
        less_than = np.where(indices_to_remove < layer.features_slice[0])[0]
        in_range = (
            np.where((indices_to_remove >= layer.features_slice[0]) & (indices_to_remove < layer.features_slice[1]))
        )[0]
        layer.features_slice[0] = layer.features_slice[0] - len(less_than)
        layer.features_slice[1] = layer.features_slice[1] - (len(less_than) + len(in_range))
        # Only the count of ``in_range`` matters to the shape update; the slice has no kernel to trim.
        self._change_hn_node_feature_out(in_range, layer, has_kernel=False)

    def _get_const_params(self, layer_name, activation):
        """Return the constant(s) the activation callable needs besides its input.

        Returns:
            A single value read from ``self._params[layer_name]``, an ``(alpha, beta)`` tuple for hardsigmoid, or
            None when the activation takes no constant.
        """
        if activation == ActivationType.leaky:
            return self._params[layer_name].get("leaky_alpha")
        if activation == ActivationType.threshold:
            return float(self._params[layer_name].get("activation_threshold"))
        if activation == ActivationType.biased_delta:
            return self._params[layer_name].get("activation_delta_bias")
        if activation == ActivationType.clip:
            # NOTE(review): clip is evaluated with ``relu(x, clip_min)``, which zeroes every value <= clip_min. That
            # matches a real clip's zero set only when clip_min == 0 (clip_max is not considered) - TODO confirm.
            return self._params[layer_name].get("clip_min")
        if activation == ActivationType.less:
            return self._params[layer_name].get("activation_less_values")
        if activation == ActivationType.hardsigmoid:
            return self._params[layer_name].get("hardsigmoid_alpha"), self._params[layer_name].get("hardsigmoid_beta")
        if activation == ActivationType.greater:
            return self._params[layer_name].get("activation_greater_values")
        if activation == ActivationType.swish:
            return self._params[layer_name].get("swish_beta")
        return None

    def _find_channels_to_remove_from_equiv(self, equiv_set):
        """
        Record in ``self._indices_to_remove`` the channels to delete for one equivalence set.

        The dead-channel mask is looked up through each member's ``source_indices``. The source and element-wise
        bouncers get output removals, consumers and featurewise layers get input removals, and slices get special
        treatment. If every channel is dead, channel 0 is kept so the layer is not emptied.
        """
        dead_channels_mask = self._get_dead_channels_mask_of_equiv(equiv_set)
        if not np.any(dead_channels_mask):
            return

        source_layer = equiv_set.source.layer
        self._logger.debug(
            f"Changing dead params for layer {source_layer.name} because there are "
            f"{np.sum(dead_channels_mask)} dead channels",
        )
        if np.all(dead_channels_mask):
            # Removing every channel would leave the layer without features, so channel 0 is kept.
            self._logger.debug(f"All the channels of layer {source_layer.name} are dead - we will keep the first one")
            dead_channels_mask[0] = False

        self._get_indices_to_remove(equiv_set.source, dead_channels_mask, "output")
        for equiv_consumer in equiv_set.consumers:
            self._get_indices_to_remove(equiv_consumer, dead_channels_mask, "input")
        for equiv_featurewise in equiv_set.featurewise:
            self._get_indices_to_remove(equiv_featurewise, dead_channels_mask, "input")
        for ew_bouncer in equiv_set.ew_bouncers:
            self._get_indices_to_remove(ew_bouncer, dead_channels_mask, "output")
        for transparent in equiv_set.transparents:
            if transparent.layer.op in [LayerType.slice]:
                self._get_indices_to_remove(transparent, dead_channels_mask, "special_treatment")

    def _get_dead_channels_mask_of_equiv(self, equiv_set):
        """Return the boolean mask of channels that are dead across the whole equivalence set.

        Starts from the source layer's own dead channels. A channel stays dead only if:
        - every featurewise layer maps it to a zero bias after activation, and
        - every element-wise bouncer also finds it dead.
        """
        source_layer = equiv_set.source
        dead_channels_mask = self._get_dead_channels_mask_of_layer(source_layer)
        if not np.any(dead_channels_mask):
            return dead_channels_mask
        for featurewise_equiv in equiv_set.featurewise:
            featurewise = featurewise_equiv.layer
            bias = self._params[featurewise.name]["bias"]
            bias_after_activation = self._get_bias_after_activation(bias, featurewise, featurewise.name)
            bias_dead_channels_mask = bias_after_activation == 0
            layer_indices = self._as_index_array(featurewise_equiv.layer_indices)
            source_indices = self._as_index_array(featurewise_equiv.source_indices)
            dead_channels_mask[source_indices] &= bias_dead_channels_mask[layer_indices]
        for ew_bouncer in equiv_set.ew_bouncers:
            # NOTE(review): the returned mask is indexed by the bouncer's own channels and replaces the running mask;
            # this assumes bouncer and source channels line up - TODO confirm.
            dead_channels_mask = self._get_dead_channels_mask_of_layer(ew_bouncer, dead_channels_mask)
        return dead_channels_mask

    def _get_dead_channels_mask_of_layer(self, equiv_layer, pred_dead_channels_mask=None):
        """
        Return the mask of output channels of ``equiv_layer`` that can be removed.

        A channel is dead when all of its kernel weights are zero and its bias after the layer activation is zero.
        Every dead channel of the layer is recorded (once per layer name) in ``self._real_dead_channels_by_layer``
        before the mask is narrowed by ``pred_dead_channels_mask``.

        Args:
            equiv_layer: the current equivalence-set member; its ``layer`` weights are read with ``get_weights``.
            pred_dead_channels_mask: optional mask of channels dead in the previously inspected layers. It is read
                through this member's ``source_indices``, and a channel stays dead only if it is dead in both.

        Returns:
            numpy array of booleans, one per output channel of the layer.
        """
        layer = equiv_layer.layer
        bias, kernel = get_weights(layer.name, self._params, layer.op)

        # A kernel contributes nothing to an output channel only when *every* weight of that channel is zero;
        # a single zero weight is not enough.
        if len(kernel.shape) == 2:
            kernel_dead_channels_mask = np.all(kernel == 0, axis=(0,))
        else:
            kernel_dead_channels_mask = np.all(kernel == 0, axis=(0, 1, 2))

        bias_after_activation = self._get_bias_after_activation(bias, layer, layer.name)
        bias_dead_channels_mask = bias_after_activation == 0
        dead_channels_mask = np.logical_and(kernel_dead_channels_mask, bias_dead_channels_mask)

        # The dead channels of a layer and the channels the algorithm removes can differ. This saves every channel
        # that is dead with respect to the layer, even if it cannot be removed from the kernel.
        if layer.name not in self._real_dead_channels_by_layer:
            self._real_dead_channels_by_layer[layer.name] = np.where(dead_channels_mask)[0]
        if pred_dead_channels_mask is not None:
            pred_dead_channels_mask = pred_dead_channels_mask[self._as_index_array(equiv_layer.source_indices)]
            dead_channels_mask[self._as_index_array(equiv_layer.layer_indices)] &= pred_dead_channels_mask

        return dead_channels_mask

    def _get_bias_after_activation(self, bias, layer, layer_name):
        """Return ``bias`` passed through the activation of ``layer`` (linear when it has no ``activation``).

        Raises:
            IllegalRemovalAction: raised by ``get_function_by_activation`` for an unsupported activation.
        """
        activation = ActivationType.linear
        if hasattr(layer, "activation"):
            activation = layer.activation
        const = self._get_const_params(layer_name, activation)
        if isinstance(const, tuple) and len(const) == 2:
            bias_after_activation = get_function_by_activation(activation)(bias, const[0], const[1])
        elif const is not None:
            bias_after_activation = get_function_by_activation(activation)(bias, const)
        else:
            bias_after_activation = get_function_by_activation(activation)(bias)

        return bias_after_activation

    def _get_indices_to_remove(self, equiv_layer, dead_channels_mask, input_or_output):
        """Append the dead channels of ``equiv_layer`` to ``self._indices_to_remove[input_or_output]``.

        ``dead_channels_mask`` is read through the member's ``source_indices``. The matching ``layer_indices`` are
        what get recorded. Indices already recorded for this layer are skipped. ``input_or_output`` is one of
        "input", "output" or "special_treatment".
        """
        layer, source_indices, layer_indices = self._get_updated_equiv_set(equiv_layer, input_or_output)
        consumer_dead_channels_mask = dead_channels_mask[source_indices]
        if not np.any(consumer_dead_channels_mask):
            return
        source_indices_to_remove = source_indices[consumer_dead_channels_mask]
        indices_to_remove = layer_indices[np.where(np.isin(source_indices, source_indices_to_remove))]
        self._indices_to_remove[input_or_output][layer] = np.concatenate(
            [self._indices_to_remove[input_or_output][layer], indices_to_remove],
        )

    def _change_feature_out(self, layer, indices_to_remove):
        """Remove output channels from both the HN node and the params of ``layer``."""
        self._change_hn_node_feature_out(indices_to_remove, layer)
        self._change_params_feature_out(indices_to_remove, layer)

    def _change_params_feature_out(self, indices_to_remove, layer):
        """Delete the given output channels from the layer's kernel (last axis) and bias in ``self._params``."""
        bias, kernel = get_weights(layer.name, self._params, layer.op)

        if len(kernel.shape) == 2:
            self._params[f"{layer.name}/kernel:0"] = np.delete(kernel, indices_to_remove, axis=1)
        else:
            self._params[f"{layer.name}/kernel:0"] = np.delete(kernel, indices_to_remove, axis=3)
        self._params[f"{layer.name}/bias:0"] = np.delete(bias, indices_to_remove)

    def _change_hn_node_feature_out(self, indices_to_remove, layer, has_kernel=True):
        """Shrink the last dim of every output shape of ``layer`` by ``len(indices_to_remove)``.

        When ``has_kernel`` is True, the kernel's last dim is updated to the same channel count.
        """
        num_channels = layer.output_shape[-1] - len(indices_to_remove)
        for output_shape, layer_out_name in zip(layer.output_shapes, layer.outputs):
            self._logger.debug(
                f"Changed out_shape of {layer.name} from {output_shape[-1]} to {num_channels} because of"
                f" {layer_out_name}",
            )
            output_shape[-1] = num_channels
        if has_kernel:
            layer.kernel_shape[-1] = num_channels

    def _get_consumer_predecessor(self, consumer):
        """Return the single predecessor of ``consumer``, ignoring its ew-add connections.

        Raises:
            DeadChannelsRemovalException: if, after that filtering, there is not exactly one predecessor.
        """
        preds = list(self._model.predecessors(consumer))
        if consumer.ew_add_enabled:
            preds = [pred for pred in preds if pred not in consumer.ew_add_connections]
        if len(preds) != 1:
            raise DeadChannelsRemovalException(
                f"Consumer layer '{consumer.original_names[0]}' should have exactly 1 predecessor "
                f"but has {len(preds)}",
            )
        return preds[0]

    def _change_feature_in(self, layer, indices_to_remove, remove_bias):
        """Remove input channels from the params and HN node of ``layer`` (and its bias when ``remove_bias``)."""
        predecessor = self._get_consumer_predecessor(layer)
        num_channels = self._change_params_feature_in(indices_to_remove, layer, predecessor, remove_bias)
        self._change_hn_node_feature_in(layer, predecessor, num_channels)

    def _change_params_feature_in(self, indices_to_remove, consumer, predecessor, remove_bias):
        """Delete input channels from ``consumer``'s kernel (and optionally its bias) in ``self._params``.

        Returns:
            The new input-channel count of the kernel. For a dense kernel fed by a rank-4 output, this is the new
            number of kernel rows.
        """
        bias, kernel = get_weights(consumer.name, self._params, consumer.op)

        # batch_norm kernels keep their input channels on the last axis; other layers on the second-to-last.
        shape_index = -2 if consumer.op != LayerType.batch_norm else -1
        num_channels = kernel.shape[shape_index] - len(indices_to_remove)

        if len(kernel.shape) == 2:
            consumer_index = predecessor.outputs.index(consumer.name)
            if len(predecessor.output_shapes[consumer_index]) == 4:
                # Dense layer fed by a spatial (rank-4) output: channels are interleaved in the flattened rows.
                num_channels = self._change_conv_to_dense_feature_in(
                    predecessor,
                    kernel,
                    consumer.name,
                    indices_to_remove,
                )
            else:
                self._params[f"{consumer.name}/kernel:0"] = np.delete(kernel, indices_to_remove, axis=0)
        else:
            self._params[f"{consumer.name}/kernel:0"] = np.delete(kernel, indices_to_remove, axis=2)

        if remove_bias:
            self._params[f"{consumer.name}/bias:0"] = np.delete(bias, indices_to_remove)

        return num_channels

    def _change_hn_node_feature_in(self, successor, predecessor, num_of_channels):
        """Set the input channel count of ``successor`` (input shape and kernel shape) to ``num_of_channels``."""
        if num_of_channels is None:
            return
        axis = -1 if successor.op == LayerType.batch_norm else -2
        self._logger.debug(
            f"Changed in_shape of {successor.name} because of {predecessor.name} from "
            f"{successor.kernel_shape[axis]} to {num_of_channels}",
        )
        ind_input = successor.inputs.index(predecessor.name)
        successor.input_shapes[ind_input][-1] = num_of_channels
        successor.kernel_shape[axis] = num_of_channels

    def _get_updated_equiv_set(self, equiv_layer, input_or_output):
        """Return ``(layer, source_indices, layer_indices)`` for ``equiv_layer`` as integer index arrays.

        Index pairs whose layer index is already recorded for this layer are dropped. If the layer has no entry
        yet, an empty int32 entry is created in ``self._indices_to_remove[input_or_output]``.
        """
        layer = equiv_layer.layer
        source_indices = self._as_index_array(equiv_layer.source_indices)
        layer_indices = self._as_index_array(equiv_layer.layer_indices)
        if layer not in self._indices_to_remove[input_or_output]:
            self._indices_to_remove[input_or_output][layer] = np.array([], dtype="int32")
        else:
            source_indices = source_indices[~np.isin(layer_indices, self._indices_to_remove[input_or_output][layer])]
            layer_indices = layer_indices[~np.isin(layer_indices, self._indices_to_remove[input_or_output][layer])]
        return layer, source_indices, layer_indices

    def _change_conv_to_dense_feature_in(self, layer, kernel, layer_param_key, indices_to_remove):
        """Delete input channels from a dense kernel that follows a rank-4 output of ``layer``.

        The kernel rows are reshaped to ``(h, w, channels, out)`` using ``layer.output_shape``. This assumes the
        rows follow that C-order flattening. The channels are deleted on axis 2, and the rows are flattened back.

        Returns:
            The new number of kernel rows.
        """
        _, h, w, _ = layer.output_shape
        reshaped_kernel = np.reshape(kernel, (h, w, -1, kernel.shape[-1]))
        changed_kernel = np.delete(reshaped_kernel, indices_to_remove, axis=2)
        reshaped_back_kernel = np.reshape(changed_kernel, (-1, kernel.shape[-1]))
        self._params[f"{layer_param_key}/kernel:0"] = reshaped_back_kernel
        return reshaped_back_kernel.shape[0]

    @staticmethod
    def _as_index_array(indices):
        """Convert an index sequence to an integer numpy array.

        ``np.array([])`` defaults to float64, which numpy rejects as an index, so the dtype is forced.
        """
        return np.array(indices, dtype=int)


def get_function_by_activation(activation_type):
    """Look up the numpy callable that evaluates ``activation_type`` element-wise.

    Some activations are mapped to a simpler callable that shares the same zero set:
    - relu6 and relu1 use ``relu_n``, which applies no upper bound.
    - threshold and clip use ``relu`` with a threshold argument.

    Raises:
        IllegalRemovalAction: if ``activation_type`` has no registered callable.
    """
    dispatch = {
        # identity / piecewise-linear
        ActivationType.linear: linear,
        ActivationType.relu: relu,
        ActivationType.relu6: relu_n,
        ActivationType.relu1: relu_n,
        ActivationType.leaky: leaky_relu,
        ActivationType.threshold: relu,
        ActivationType.clip: relu,
        ActivationType.hardswish: hardswish,
        ActivationType.hardsigmoid: hardsigmoid,
        # smooth
        ActivationType.elu: elu,
        ActivationType.sigmoid: sigmoid,
        ActivationType.exp: np.exp,
        ActivationType.tanh: np.tanh,
        ActivationType.softplus: softplus,
        ActivationType.silu: silu,
        ActivationType.swish: swish,
        ActivationType.mish: mish,
        ActivationType.gelu: gelu,
        # plain element-wise math
        ActivationType.log: np.log,
        ActivationType.sqrt: np.sqrt,
        ActivationType.inv_pos: np.reciprocal,
        ActivationType.inv_sqrt: inv_sqrt,
        ActivationType.minus_inv_pos: minus_inv_pos,
        # indicator / comparison
        ActivationType.biased_delta: biased_delta,
        ActivationType.delta: delta,
        ActivationType.less: less,
        ActivationType.greater: greater,
    }

    handler = dispatch.get(activation_type)
    if handler is None:
        raise IllegalRemovalAction(f"{activation_type} activation is not supported with dead channels removal")
    return handler


def sigmoid(x):
    """Logistic sigmoid ``1 / (1 + exp(-x))``, element-wise."""
    denominator = 1 + np.exp(-x)
    return np.reciprocal(denominator)


def relu(x, threshold=0):
    """Zero every element of ``x`` not strictly greater than ``threshold``; other elements pass through unchanged."""
    keep_mask = x > threshold
    return x * keep_mask


def relu_n(x):
    """ReLU without an upper bound: keep positive elements of ``x`` and zero the rest (same as ``relu(x)``)."""
    positive_mask = x > 0
    return x * positive_mask


def linear(x):
    """Identity activation: hand back ``x`` itself, unmodified."""
    result = x
    return result


def leaky_relu(x, alpha):
    """Leaky ReLU: ``x`` where ``x > 0``, ``alpha * x`` where ``x < 0``, and 0 at 0."""
    positive_part = x * (x > 0)
    negative_part = x * alpha * (x < 0)
    return positive_part + negative_part


def elu(x):
    """ELU with unit alpha: ``x`` for ``x > 0`` and ``exp(x) - 1`` for ``x < 0``, element-wise.

    The exponential is evaluated on ``min(x, 0)`` so large positive inputs do not overflow to ``inf``. An overflow
    would turn the result into ``nan`` through ``inf * 0``. ``expm1`` keeps small negative inputs from rounding to
    exactly zero.
    """
    positive_part = x * (x > 0)
    negative_part = np.expm1(np.minimum(x, 0)) * (x < 0)
    return positive_part + negative_part


def biased_delta(x, delta=-1.0):
    """Return ``delta`` where ``x`` is non-zero and 0 where ``x`` is zero, element-wise (NaN propagates)."""
    magnitude = np.abs(x)
    return delta * np.sign(magnitude)


def delta(x):
    """Zero indicator: 1 where ``x`` is zero and 0 elsewhere, element-wise (NaN propagates)."""
    magnitude = np.abs(x)
    return 1 - np.sign(magnitude)


def softplus(x):
    """Numerically stable softplus ``log(1 + exp(x))``, computed as ``max(x, 0) + log1p(exp(-|x|))``."""
    return np.maximum(x, 0) + np.log1p(np.exp(-np.abs(x)))


def silu(x):
    """SiLU: ``x * sigmoid(x)`` with the logistic sigmoid written inline."""
    gate = 1 / (1 + np.exp(-x))
    return x * gate


def swish(x, const):
    """Swish with slope ``const``: ``x * sigmoid(const * x)`` with the sigmoid written inline."""
    gate = 1 / (1 + np.exp(-(const * x)))
    return x * gate


def mish(x):
    """Mish: ``x * tanh(softplus(x))``, with softplus in its stable ``log1p(exp(-|x|)) + max(x, 0)`` form."""
    softplus_x = np.log1p(np.exp(-np.abs(x))) + np.maximum(x, 0)
    return x * np.tanh(softplus_x)


def less(x, const):
    """Return float32 1.0 where ``x < const`` and 0.0 elsewhere, element-wise."""
    below = np.less(x, const)
    return below.astype(np.float32)


def greater(x, const):
    """Return float32 1.0 where ``x > const`` and 0.0 elsewhere, element-wise."""
    above = np.greater(x, const)
    return above.astype(np.float32)


def gelu(x):
    """Exact (erf-based) GELU: ``0.5 * x * (1 + erf(x / sqrt(2)))``, element-wise."""
    # A bare ``import scipy`` does not guarantee ``scipy.special`` is loaded on older SciPy releases,
    # so import the function explicitly.
    from scipy.special import erf

    return 0.5 * x * (1 + erf(x / np.sqrt(2)))


def inv_sqrt(x):
    """Element-wise ``1 / sqrt(x)``."""
    root = np.sqrt(x)
    return np.reciprocal(root)


def minus_inv_pos(x):
    """Element-wise ``-1 / x``, computed as the reciprocal of ``-x``."""
    negated = -x
    return np.reciprocal(negated)


def hardswish(x):
    """Hard-swish: ``x * clip(x + 3, 0, 6) / 6``, element-wise."""
    gate = np.clip(x + 3, 0.0, 6.0)
    return x * gate / 6


def hardsigmoid(x, alpha, beta):
    """Hard-sigmoid: ``clip(alpha * x + beta, 0, 1)``, element-wise."""
    affine = alpha * x + beta
    return np.clip(affine, 0.0, 1.0)
