import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_non_arithmetic_op import BaseNonArithmeticAtomicOp


class FeaturePermuteOp(BaseNonArithmeticAtomicOp):
    """
    Permute the order of the features on the last axis of the input.

    Implemented using `tf.gather` on the features' axis, i.e.
    ``output[..., i] = input[..., feature_order[i]]``.
    feature_order has to be explicitly set before calling this layer.
    feature_order can be set only after the layer has been built, because the
    setter validates it against ``self.input_shape``.
    """

    num_inputs = 1
    num_outputs = 1

    def __init__(self, name: str, logger=None, fully_native=None, **kwargs):
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        # Permutation gathered along the last axis in call_native (None until set).
        self._feature_order = None
        # Inverse permutation: _inverse_order[_feature_order[i]] == i.
        # Used to map output encodings back onto the input features.
        self._inverse_order = None

    def _compute_output_shape(self, input_shape):
        # A permutation reorders features but never changes the shape.
        return input_shape

    @property
    def feature_order(self):
        """
        Get the new feature order of the channel
        Returns:
            np.int32 array with the new feature order, or None if not set yet
        """
        return self._feature_order

    @feature_order.setter
    def feature_order(self, new_order):
        """
        Set feature order for the permute
        Args:
            new_order: np array, list or tuple with the new features' indices order

        Raises:
            ValueError: if new_order's length differs from the number of input
                features, or if it is not a permutation of range(num_features).
        """
        # Normalize first so list, tuple, ndarray (and array-likes) are validated alike.
        order = np.asarray(new_order)
        num_features = self.input_shape[-1]
        if len(order) != num_features:
            raise ValueError("Feature order must be the same length as the layer's features")
        # Same length + same element set => order is a permutation (no duplicates).
        if set(order.tolist()) != set(range(num_features)):
            raise ValueError("Feature order must contain all features")
        self._feature_order = np.array(order, np.int32)
        self._generate_reverse_order()

    def call_native(self, inputs, **kwargs):
        # Reorder features on the last axis according to the configured permutation.
        return tf.gather(inputs[0], self._feature_order, axis=-1)

    def export_independent_params(self):
        # May export None when the order was never set; import handles that case.
        return {
            "feature_order": self._feature_order,
        }

    def import_independent_params(self, params):
        """
        Restore the permutation from params produced by export_independent_params.

        Args:
            params: dict with a "feature_order" entry (array-like of ints, or None).
        """
        feature_order = params["feature_order"]
        if feature_order is None:
            # Nothing to restore: the exporting op never had its order set.
            self._feature_order = None
            self._inverse_order = None
            return
        # Match the dtype the setter produces, regardless of how params were stored.
        self._feature_order = np.asarray(feature_order, dtype=np.int32)
        self._generate_reverse_order()

    def _generate_reverse_order(self):
        # Scatter positions into the permuted slots to build the inverse permutation.
        self._inverse_order = np.empty_like(self._feature_order)
        self._inverse_order[self._feature_order] = np.arange(len(self._feature_order))

    @staticmethod
    def _is_scalar(value):
        # Rank-0 scales are shared by all features; tf.gather cannot index them.
        return len(tf.convert_to_tensor(value).shape) == 0

    def enforce_encoding(self):
        """Propagate the input encoding forward through the permutation."""
        input_scale = self.input_scales[0]
        if self._is_scalar(input_scale):
            # A per-tensor scale is invariant under a feature permutation.
            self.output_scale = input_scale
        else:
            self.output_scale = tf.gather(input_scale, self.feature_order)
        # NOTE(review): zero point is passed through unpermuted, as in
        # backward_encoding; confirm zero points are never per-feature here.
        self.output_zero_point = self.input_zero_points[0]

    def backward_encoding(self):
        """Propagate the output encoding back to the input via the inverse permutation."""
        if self._is_scalar(self.output_scale):
            self.input_scales[0] = self.output_scale
        else:
            self.input_scales[0] = tf.gather(self.output_scale, self._inverse_order)
        self.input_zero_points[0] = self.output_zero_point
