import numpy as np

from hailo_model_optimization.acceleras.utils.acceleras_definitions import QuantizationAlgorithms
from hailo_sdk_client.numeric_translator.quantization_tools import get_weights
from hailo_sdk_client.sdk_backend.sdk_backend_exceptions import BackendQuantizationException
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.logger.logger import default_logger

# Equivalence-set key identifying the params-sorter algorithm.
# NOTE(review): not referenced by the code visible in this module (ParamsSorter.sort_params passes
# QuantizationAlgorithms.params_sorter directly); presumably exported for external callers — verify.
PARAMS_SORTER_EQUIV = QuantizationAlgorithms.params_sorter


class ParamsSorterError(Exception):
    """Exception type reserved for parameter-sorting failures.

    Note: the ParamsSorter code in this module raises BackendQuantizationException
    rather than this type.
    """


class ParamsSorter:
    """Reorder layer parameters so output channels are ordered by weight magnitude.

    Sorting is only attempted when at least one layer in the network has
    ``precision_config.quantization_groups > 1``. For each equivalence set yielded by
    ``hailo_nn.iter_equiv_sets(QuantizationAlgorithms.params_sorter)``, the source
    layer's output channels are permuted by ascending max-abs kernel weight. The same
    permutation is then applied to:

    * the input features of consumer and featurewise layers, and
    * the output features of the set's element-wise "bouncer" layers.
    """

    def __init__(self, hailo_nn):
        self._hailo_nn = hailo_nn
        # Parameter dict being rewritten; keys look like "<layer>/kernel:0" and "<layer>/bias:0".
        self._sorted_params = {}
        # Layer name -> set of output-feature indices already permuted during this sort_params call.
        self._sorted_feature_out = {}
        self._logger = default_logger()

    @staticmethod
    def _has_quantization_groups(layer):
        """Return True if ``layer`` is configured with more than one quantization group."""
        qgroups = layer.precision_config.quantization_groups
        return qgroups is not None and qgroups > 1

    def sort_params(self, params):
        """Return a copy of ``params`` with channels sorted for quantization-group layers.

        Args:
            params: Mapping from parameter names (e.g. ``"<layer>/kernel:0"``) to
                array-like weights. The mapping is shallow-copied, not mutated.

        Returns:
            The sorted parameter dict, or ``None`` when no layer has more than one
            quantization group (nothing to sort).

        Raises:
            BackendQuantizationException: If a selected source layer does not require
                native weights, or a layer's output features would be sorted twice.
        """
        self._sorted_params = dict(params)
        self._sorted_feature_out = {}

        layers_with_groups = {
            layer for layer in self._hailo_nn.stable_toposort() if self._has_quantization_groups(layer)
        }
        if not layers_with_groups:
            return None

        sorted_layers = set()
        for equiv_set in self._hailo_nn.iter_equiv_sets(QuantizationAlgorithms.params_sorter):
            if self._should_skip_param_sort(equiv_set):
                continue
            if self._should_stop_param_sort(equiv_set):
                break
            self._sort_equiv(equiv_set)
            sorted_layers.add(equiv_set.source.layer)
            # ew_bouncers have their output features permuted as well (see _sort_equiv),
            # so they must not be reported as unsortable below.
            sorted_layers.update(ewb.layer for ewb in equiv_set.ew_bouncers)

        unsorted_layers = layers_with_groups - sorted_layers
        if unsorted_layers:
            layers_names = [layer.name for layer in unsorted_layers]
            self._logger.debug(f"The layers {layers_names} have quantization groups but can't be sorted")

        return self._sorted_params

    def _should_stop_param_sort(self, equiv_set):
        """Validate that the chosen source layer of ``equiv_set`` can be sorted.

        Reaching this check means the set was selected for sorting, so an unsupported
        source is treated as an error instead of a silent skip. It currently never
        requests an early stop (always returns False when it does not raise).

        Raises:
            BackendQuantizationException: If the source layer does not require native weights.
        """
        # TODO: update source
        layer = equiv_set.source.layer
        if not layer.requires_native_weights:
            raise BackendQuantizationException(f"Cannot sort params for layer: {layer.name}")
        return False

    def _should_skip_param_sort(self, equiv_set):
        """Return True when ``equiv_set`` should be left unsorted.

        A set is skipped when:

        * neither its source nor any ew_bouncer uses quantization groups;
        * it contains unsupported or output layers;
        * it contains layers flagged as ``skip``; or
        * its source is a deconv.
        """
        # TODO: update source
        layer = equiv_set.source.layer
        members = [equiv_set.source, *equiv_set.ew_bouncers]
        if not any(self._has_quantization_groups(member.layer) for member in members):
            return True
        if equiv_set.unsupported or equiv_set.outputs:
            unsupported = [eq_layer.layer.name for eq_layer in equiv_set.unsupported]
            outputs = [eq_layer.layer.name for eq_layer in equiv_set.outputs]
            self._logger.debug(
                f"Cannot sort params for layer: {layer.name}. Equivalence class contains: "
                f"Unsupported - {unsupported}, Outputs - {outputs}",
            )
            return True
        if equiv_set.skip:
            layer_types = {eq_layer.layer.op for eq_layer in equiv_set.skip}
            self._logger.debug(
                f"Layer {layer.name} cannot be sorted, but quantization groups are allowed. "
                f"Current equiv class contains layers of type: {layer_types}",
            )
            return True
        if layer.op == LayerType.deconv:
            # TODO: Why is deconv not supported?
            self._logger.debug(f"Layer {layer.name} is deconv, which is not supported as source")
            return True
        return False

    def _sort_equiv(self, equiv_set):
        """Apply one channel permutation, derived from the source layer, across ``equiv_set``."""
        # TODO: update source
        source = equiv_set.source.layer
        sort_order = self._get_feature_sort_order(equiv_set)
        self._logger.debug(f"Sorting params based on layer {source.name}")
        self._sort_weights_feature_out(
            source,
            sort_order,
            equiv_set.source.source_indices,
            equiv_set.source.layer_indices,
        )

        for equiv_layer in equiv_set.consumers:
            # TODO: resolve_indices
            _, consumer, source_indices, indices = equiv_layer.get_as_tuple()
            self._logger.debug(f"Sorting params for consumer layer {consumer.name}")
            self._sort_weights_feature_in(source, consumer, sort_order, source_indices, indices, use_bias=False)
        for equiv_layer in equiv_set.featurewise:
            # TODO: resolve_indices
            _, fw, source_indices, indices = equiv_layer.get_as_tuple()
            self._logger.debug(f"Sorting params for featurewise layer {fw.name}")
            self._sort_weights_feature_in(source, fw, sort_order, source_indices, indices, use_bias=True)
        for equiv_layer in equiv_set.ew_bouncers:
            _, ewb, source_indices, indices = equiv_layer.get_as_tuple()
            # TODO: update source
            self._logger.debug(f"Sorting params for ew_bouncer layer {ewb.name}")
            self._sort_weights_feature_out(ewb, sort_order, source_indices, indices)

    def _sort_weights_feature_out(self, layer, sort_order, source_indices, indices):
        """Permute the output features (last kernel axis and the bias) of ``layer``.

        Raises:
            BackendQuantizationException: If any of ``indices`` were already permuted for
                this layer during the current ``sort_params`` call.
        """
        bias, kernel = get_weights(layer.name, self._sorted_params, layer.op)
        resolved_sort_order = self._get_resolved_order(layer, sort_order, source_indices, indices, is_producer=True)

        existing_indices = self._sorted_feature_out.get(layer.name, set())
        if set(indices) & existing_indices:
            raise BackendQuantizationException(f"Cannot sort layer: {layer.name} twice")

        self._sorted_params[f"{layer.name}/kernel:0"] = kernel[..., resolved_sort_order]
        self._sorted_params[f"{layer.name}/bias:0"] = bias[resolved_sort_order]
        self._sorted_feature_out[layer.name] = existing_indices | set(indices)

    def _get_feature_sort_order(self, equiv_set):
        """Return the permutation that orders the source's output channels by ascending max |weight|."""
        layer = equiv_set.source.layer
        _, kernel = get_weights(layer.name, self._sorted_params, layer.op)

        # The last kernel axis holds output features (it is the axis permuted in
        # _sort_weights_feature_out). Reduce all other axes: axis 0 for 2-D dense
        # kernels, axes (0, 1, 2) for 4-D conv kernels.
        reduce_axes = tuple(range(len(kernel.shape) - 1))
        channel_max_weights = np.max(np.abs(kernel), axis=reduce_axes)
        return np.argsort(channel_max_weights)

    def _sort_weights_feature_in(self, source, consumer, sort_order, source_indices, indices, use_bias):
        """Permute the input-feature axis of ``consumer``'s kernel to follow ``source``'s new order.

        Dense kernels are indexed on axis 0; other kernels on axis 2. When ``use_bias``
        is set (featurewise layers), the bias is permuted too, but only on the
        non-dense path.
        """
        # TODO: is source.op intentional here?
        bias, consumer_kernel = get_weights(consumer.name, self._sorted_params, source.op)
        is_conv_to_dense = (len(consumer_kernel.shape) == 2) and (source.op == LayerType.conv)
        new_sort_order = self._get_resolved_order(
            consumer,
            sort_order,
            source_indices,
            indices,
            is_conv_to_dense=is_conv_to_dense,
        )

        if consumer.op == LayerType.dense:
            # NOTE(review): bias is not permuted for dense consumers even if use_bias is True.
            self._sorted_params[f"{consumer.name}/kernel:0"] = consumer_kernel[new_sort_order, :]
        else:
            self._sorted_params[f"{consumer.name}/kernel:0"] = consumer_kernel[:, :, new_sort_order, :]
            if use_bias:
                self._sorted_params[f"{consumer.name}/bias:0"] = bias[new_sort_order]

    @staticmethod
    def _scatter_order(num_features, positions, values):
        """Return ``np.arange(num_features)`` with ``values`` written at ``positions``."""
        order = np.arange(num_features)
        order[positions] = values
        return order

    def _get_resolved_order(
        self,
        layer_to_resolve,
        sort_order,
        source_indices,
        consumer_indices,
        is_conv_to_dense=False,
        is_producer=False,
    ):
        """Translate the source's channel order into a gather index for ``layer_to_resolve``.

        Only the positions listed in ``consumer_indices`` are permuted; every other
        feature keeps its identity position. ``source_indices[k]`` on the source
        corresponds to ``consumer_indices[k]`` on ``layer_to_resolve``.

        Args:
            layer_to_resolve: Layer whose weights will be indexed with the result.
            sort_order: Permutation of the source's output features.
            source_indices: Source output features that map onto this layer.
            consumer_indices: Matching feature positions on ``layer_to_resolve``.
            is_conv_to_dense: Expand the channel order over the dense layer's
                predecessor output (height x width x features).
            is_producer: Size the order by ``output_features`` instead of ``input_features``.

        Returns:
            1-D integer array used to index the layer's feature axis.
        """
        # Build the set once so the filter below is linear rather than quadratic.
        source_index_set = set(source_indices)
        # Source indices in their new (sorted) order.
        sorted_source_indices = [i for i in sort_order if i in source_index_set]
        # Map them to the corresponding positions on the layer being resolved.
        source_to_consumer = dict(zip(source_indices, consumer_indices))
        # Explicit integer dtype: an empty list would otherwise become a float array,
        # which numpy rejects as an index.
        sorted_consumer_indices = np.asarray(
            [source_to_consumer[i] for i in sorted_source_indices],
            dtype=np.intp,
        )
        consumer_positions = np.asarray(consumer_indices, dtype=np.intp)

        if is_conv_to_dense:
            pred = self._get_dense_predecessor(layer_to_resolve)
            height, width, features = pred.output_height, pred.output_width, pred.output_features
            channel_order = self._scatter_order(features, consumer_positions, sorted_consumer_indices)
            # The dense input is treated as an (H, W, C)-ordered flattening of the
            # predecessor output, so the channel permutation is applied at every
            # spatial position.
            flat_positions = np.arange(height * width * features).reshape(height, width, features)
            return flat_positions[:, :, channel_order].reshape(-1)

        num_features = layer_to_resolve.output_features if is_producer else layer_to_resolve.input_features
        return self._scatter_order(num_features, consumer_positions, sorted_consumer_indices)

    def _get_dense_predecessor(self, dense):
        """Return the single predecessor of ``dense``.

        Raises:
            BackendQuantizationException: If ``dense`` does not have exactly one predecessor.
        """
        preds = list(self._hailo_nn.predecessors(dense))
        if len(preds) != 1:
            raise BackendQuantizationException(
                f"Dense layer '{dense.original_names[0]}' has {len(preds)} predecessors, expected exactly 1",
            )
        return preds[0]
