import numpy as np

from hailo_sdk_client.post_fuser.algorithms.normalization_optimizer import NormalizationOptimizer
from hailo_sdk_common.hailo_nn.hn_definitions import ActivationType, LayerType


class MulByScalarAfterMatMulFolding(NormalizationOptimizer):
    """
    Folds "mul by scalar" normalization layers that directly follow a matmul layer.

    A normalization layer qualifies when its bias is all zeros and its kernel holds a single repeated value
    (the scalar). The scalar is folded into the parameters of one of the matmul's predecessors, the
    normalization layer is bypassed in the graph (matmul is connected straight to the normalization's
    successors) and finally removed from the model.
    """

    NAME = "mul_by_scalar_after_matmul_folding"

    def _run_int(self):
        """Run the inherited post-layer normalization folding, then the matmul folding, then refresh input lists."""
        self._fold_post_layer_normalization_layers()
        self._fold_mul_by_scalar_after_matmul()
        self._model.update_input_lists()

    def export_statistics(self):
        """This algorithm does not export any statistics."""
        pass

    def _fold_mul_by_scalar_after_matmul(self):
        """
        Allow folding of mul by scalar (normalization with bias==0 and kernel==kernel[0]) layers to valid preds.

        For every normalization layer approved by ``_can_fold_mul_by_scalar_after_matmul``:
          * its original names are added to the chosen matmul predecessor,
          * every successor is rewired to read from the matmul instead of the normalization layer,
          * the scalar is folded into the chosen matmul predecessor's params,
          * the model's output layers order is patched if it referenced the folded layer,
          * the layer is removed from the model (after the iteration is done).
        """
        layers_to_remove = []
        # iterate over a snapshot since edges are modified while looping
        for layer in list(self._model):
            if layer.op != LayerType.normalization:
                continue

            preds = list(self._model.predecessors(layer))
            can_fold, matmul_pred, scalar = self._can_fold_mul_by_scalar_after_matmul(layer, preds)
            if not can_fold:
                continue

            # a positive answer guarantees that the single predecessor is the matmul
            matmul = preds[0]
            succs = list(self._model.successors(layer))

            for name in layer.original_names:
                matmul_pred.add_original_name(name)

            for succ in succs:
                if layer.name in matmul.outputs:
                    # first rewired successor takes over the output slot that pointed at the folded layer
                    matmul.replace_output_shape(layer.name, layer.output_shape)
                    matmul.replace_output_index(layer.index, succ.index)
                    matmul.replace_output_layer(layer.name, succ.name)
                else:
                    # the slot was already taken over, any further successor becomes an additional output
                    matmul.append_output_shape(layer.output_shape)
                    matmul.append_output_index(succ.index)
                    matmul.append_output_layer(succ.name)

                succ.replace_input_shape(layer.name, matmul.output_shape)
                succ.replace_input_index(layer.index, matmul.index)
                succ.replace_input_layer(layer.name, matmul.name)

                self._model.remove_edge(layer, succ)
                self._model.add_edge(matmul, succ)

            # move params from scalar mul to matmul predecessor
            # (work on a shallow copy so the params stored for the folded layer are left untouched)
            src_params = dict(self._params[layer.name])
            features = matmul_pred.output_features
            src_params["kernel:0"] = np.ones([1, 1, features, 1]) * scalar
            src_params["bias:0"] = np.zeros([features])
            self._fold_bn_or_normalization_params(layer, matmul_pred, src_params, pre_layer_bn=False)
            self._model.remove_edge(matmul, layer)

            output_layers_order = self._model.net_params.output_layers_order
            for i, output in enumerate(output_layers_order):
                if layer.name == output:
                    output_layers_order[i] = matmul.name

            self._logger.debug(
                f"Folded {layer.op.value} layer {layer.name} onto predecessor layer {matmul_pred.name}",
            )
            layers_to_remove.append(layer)

        for layer_to_remove in layers_to_remove:
            self._model.remove_layer(layer_to_remove)

    def _can_fold_mul_by_scalar_after_matmul(self, layer, preds):
        """
        Check whether the given normalization layer is a mul by scalar that can be folded over its matmul pred.

        Args:
            layer: the normalization layer that is a candidate for folding.
            preds: list of the predecessors of ``layer``.

        Returns:
            A ``(can_fold, matmul_pred, scalar)`` tuple. When folding is possible, ``matmul_pred`` is the first
            predecessor of the matmul that can absorb the scalar and ``scalar`` is the repeated kernel value.
            Otherwise the tuple is ``(False, None, None)``.
        """
        not_foldable = (False, None, None)

        # exactly one predecessor is required (an empty list is rejected as well, instead of raising IndexError)
        if len(preds) != 1:
            return not_foldable

        pred = preds[0]

        if pred.op != LayerType.matmul:
            return not_foldable

        if pred.transpose_output_width_features or pred.spatial_flatten_output:
            return not_foldable

        if layer.activation != ActivationType.linear or layer.bn_enabled or layer.ew_add_enabled:
            return not_foldable

        if layer.dynamic_weights or layer.transpose_output_width_features or layer.spatial_flatten_output:
            return not_foldable

        # can't fold normalization over a layer that has more than one output (the second output expects
        # non-normalized tensor)
        if len(list(self._model.successors(pred))) > 1:
            return not_foldable

        # matmul activation is always linear
        if pred.activation != ActivationType.linear:
            return not_foldable

        src_params = self._params[layer.name]

        # a non-zero bias means this is not a pure multiplication
        bias = src_params["bias:0"].flatten()
        if np.any(bias != 0):
            return not_foldable

        # all the kernel values must be equal to one another for the layer to be a mul by a single scalar
        kernel = src_params["kernel:0"].flatten()
        if kernel.size == 0:
            return not_foldable
        scalar = kernel[0]
        if np.any(kernel != scalar):
            return not_foldable

        # the first matmul predecessor that is able to absorb the scalar is chosen
        for matmul_pred in self._model.predecessors(pred):
            if self._is_foldable_matmul_pred(matmul_pred, scalar):
                return True, matmul_pred, scalar

        return not_foldable

    def _is_foldable_matmul_pred(self, matmul_pred, scalar):
        """Return True if the given predecessor of the matmul can absorb a multiplication by ``scalar``."""
        if matmul_pred.op not in [
            LayerType.conv,
            LayerType.dw,
            LayerType.deconv,
            LayerType.dense,
            LayerType.normalization,
        ]:
            return False

        # layer_disparity is checked for conv / dw / deconv only
        if matmul_pred.op not in [LayerType.dense, LayerType.normalization] and matmul_pred.layer_disparity > 1:
            return False

        if matmul_pred.dynamic_weights:
            return False

        if matmul_pred.transpose_output_width_features or matmul_pred.spatial_flatten_output:
            return False

        # can't fold normalization over a layer that has more than one output (the second output expects
        # non-normalized tensor)
        if len(list(self._model.successors(matmul_pred))) > 1:
            return False

        if matmul_pred.bn_enabled or matmul_pred.ew_add_enabled:
            return False

        return self._is_valid_pred_activation(matmul_pred.activation, scalar)

    @staticmethod
    def _is_valid_pred_activation(activation, scalar):
        """
        Check whether a scalar multiplication can be moved to before the given activation.

        Linear activation is always accepted. relu, leaky and biased_delta are accepted for a strictly positive
        scalar only (presumably since they commute only with a positive scaling - TODO confirm for
        biased_delta). Any other activation is rejected.
        """
        if activation == ActivationType.linear:
            return True

        if scalar > 0 and activation in [ActivationType.relu, ActivationType.leaky, ActivationType.biased_delta]:
            return True

        return False
