import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp


class ReorderOp(BaseAtomicOp):
    """
    Reorder (and optionally duplicate or drop) the input channels.

    This op changes the order of the input channels; not every input channel has to be passed to the output.
    For each output channel we can choose to pass one of the input channels or the constant value '1'.
    To pass the value '1' we need its quantized representation, but there is no guarantee that 1 can be passed
    without loss.
    The value '1' is used in the feature multiplier layer to pass a channel through without multiplication; since
    '1' is not guaranteed to be lossless, that output feature is not guaranteed to pass without loss either.

    Attributes:
        recipe: a "map" that tells the op which input feature to output at each output feature index. It is set
        through ``import_weights`` or the ``recipe`` setter. If a recipe value is equal to number_of_input_features,
        the output at that index is the constant '1'.
        For example:
            input_features = 6
            output_features = 5
            recipe = [1, 1, 0, 6, 2]

            So the output will be [feature1, feature1, feature0, 1, feature2]

    """

    num_inputs = 1
    num_outputs = 1

    def __init__(self, name: str, logger=None, fully_native=None, **kwargs):
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        self._recipe = None

    def get_config(self):
        """
        Returns the configuration of the operation as a dictionary.

        Returns:
            dict: Configuration of the operation (the base config plus the ``recipe``).
        """
        config = super().get_config()
        config.update({"recipe": self.recipe})
        return config

    @classmethod
    def from_config(cls, config):
        """
        Creates an instance of the operation from the given configuration.

        The caller's ``config`` is not modified. ``recipe`` (written by ``get_config``) is restored through the
        ``recipe`` setter; any other key is applied only if it matches an attribute already present in the new
        instance's ``__dict__``.

        Args:
            config (dict): Configuration of the operation, e.g. as produced by ``get_config``.

        Returns:
            ReorderOp: An instance of the operation.
        """
        # Work on a copy so the caller's dict is left untouched.
        config = dict(config)
        instance = cls(name=config.pop("name"))

        # ``recipe`` is a property backed by ``_recipe``, so it never appears in ``instance.__dict__`` and the
        # generic loop below would silently drop it; restore it explicitly.
        if "recipe" in config:
            instance.recipe = config.pop("recipe")

        for key, value in config.items():
            if key in instance.__dict__:
                setattr(instance, key, value)

        return instance

    @property
    def recipe(self):
        """Output-channel -> input-channel map; the value ``number_of_input_features`` selects the constant '1'."""
        return self._recipe

    @recipe.setter
    def recipe(self, value):
        self._recipe = value

    def import_weights(self, recipe):
        """Set the reorder recipe; it is the only state this op imports."""
        self._recipe = recipe

    def _compute_output_shape(self, input_shape):
        """The output keeps all leading dims and has one channel per recipe entry."""
        shape = [*input_shape[:-1], len(self.recipe)]
        return shape

    def create_weight_quant_element(self, **kwargs):
        # This op is used for emulation only, hence it doesn't have any parameters.
        # The recipe parameters are passed through the feature_multiplier layer under the name power_table.
        pass

    @staticmethod
    def _pad_last_channel(tensor, value):
        """Append one channel filled with ``value`` at the end of the last axis of ``tensor`` (any rank)."""
        rank = tensor.shape.rank
        if rank is not None:
            # One [before, after] pair per axis: nothing on leading axes, one element after the last axis.
            paddings = [[0, 0]] * (rank - 1) + [[0, 1]]
        else:
            # Rank unknown at trace time: build the same paddings dynamically.
            paddings = tf.concat([tf.zeros([tf.rank(tensor) - 1, 2], dtype=tf.int32), [[0, 1]]], axis=0)
        return tf.pad(tensor, paddings, mode="CONSTANT", constant_values=value)

    def call_hw_sim(self, inputs, **kwargs):
        """
        Since we want to return the quantized '1' when a recipe value equals features_in, we pad the last axis with
        one extra channel holding the quantized '1' and gather the output channels from the padded tensor.
        This assumes that all input channel scales are equal (the '1' is quantized with their mean scale).
        """
        ones_q = self.get_quantized_one()
        padded_input = self._pad_last_channel(inputs[0], ones_q)
        return tf.gather(padded_input, self.recipe, axis=-1)

    def get_quantized_one(self):
        """
        Return the value '1' in the input's quantized domain: ``1 / mean(input_scale) + input_zero_point``,
        passed through ``output_lossy_element``.
        """
        one_scale = tf.reduce_mean(self.input_scales[0])
        ones_q = tf.add(tf.divide(1.0, one_scale), self.input_zero_points[0])
        ones_q_lossy = self.output_lossy_element(ones_q)
        return ones_q_lossy

    def call_native(self, inputs, **kwargs):
        """Float path: append a constant-1 channel and gather the output channels according to the recipe."""
        padded_input = self._pad_last_channel(inputs[0], 1)
        output = tf.gather(padded_input, self.recipe, axis=-1)
        return output

    def export_weights(self):
        """This op has no weights to export."""
        return {}

    def create_hw_params(self, **kwargs):
        """No hardware parameters are needed for this op."""
        pass

    def enforce_encoding(self):
        """
        The enforce encoding supports vector scales as input. We pass each input scale to its
        corresponding output channel. We also pass the '1' scale if it is used.
        """
        if self.input_scale_is_scalar(0):
            self.output_scale = tf.reduce_mean(self.input_scales[0])
        else:
            # The '1' channel (index == number of input channels) takes the first input channel's scale.
            # NOTE(review): get_quantized_one quantizes '1' with the *mean* scale; these agree only under the
            # equal-scales assumption stated in call_hw_sim — confirm this is intended.
            one_scale = tf.expand_dims(self.input_scales[0][0], axis=0)
            padded_input_scales = tf.concat([self.input_scales[0], one_scale], axis=0)
            self.output_scale = tf.gather(padded_input_scales, self.recipe, axis=-1)
        self.output_zero_point = self.input_zero_points[0]

    @property
    def bit_exact_supported(self) -> bool:
        """This layer supports bit exact emulation."""
        return True
