import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_non_arithmetic_op import BaseNonArithmeticAtomicOp
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasInferenceError


class SliceOp(BaseNonArithmeticAtomicOp):
    """
    Extracts a slice from a tensor along its height, width and features axes.

    The args ``height_slice``, ``width_slice`` and ``features_slice`` are each a "slice object"
    represented by a tuple of size 3: ``(begin, end, stride)``. The batch axis is never sliced.

    It is currently assumed that stride == 1.

    For an input_tensor of shape (batch, 28, 28, 3) with height_slice = [1, 3, 1],
    width_slice = [1, 3, 1] and features_slice = [1, 3, 1], the op computes::

        input_tensor[:,
                     height_slice[0]:height_slice[1]:height_slice[2],
                     width_slice[0]:width_slice[1]:width_slice[2],
                     features_slice[0]:features_slice[1]:features_slice[2]]

    and the result has shape (batch, 2, 2, 2).
    """

    num_inputs = 1
    num_outputs = 1

    def __init__(
        self,
        name: str,
        height_slice: tuple = None,
        width_slice: tuple = None,
        features_slice: tuple = None,
        logger=None,
        fully_native=None,
        **kwargs,
    ):
        """
        Args:
            height_slice: tuple of (begin, end, stride) selecting which slice to take from the height
                dimension. If None, the full slice is taken. An ``end`` of None means "up to the end".
            width_slice: tuple of (begin, end, stride) selecting which slice to take from the width
                dimension. If None, the full slice is taken. An ``end`` of None means "up to the end".
            features_slice: tuple of (begin, end, stride) selecting which slice to take from the features
                dimension. If None, the full slice is taken. An ``end`` of None means "up to the end".

        """
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        # The user-given slices are kept in `_default_*` so they can be re-resolved against
        # every new input shape; `_*_slice` hold the resolved (concrete) slices.
        self._height_slice = self._default_height_slice = height_slice
        self._width_slice = self._default_width_slice = width_slice
        self._features_slice = self._default_features_slice = features_slice

    def get_config(self):
        """
        Returns the configuration of the operation as a dictionary.

        Returns:
            dict: Configuration of the operation.
        """
        config = super().get_config()
        config.update(
            {
                "height_slice": self.height_slice,
                "width_slice": self.width_slice,
                "features_slice": self.features_slice,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        """
        Creates an instance of the operation from the given configuration.

        Keys other than the constructor arguments are applied as attributes, but only when the
        freshly built instance already has an attribute of that name.

        Args:
            config (dict): Configuration of the operation. It is not modified.

        Returns:
            SliceOp: An instance of the operation.
        """
        # Work on a shallow copy so the caller's config dict is not consumed by the pops below.
        config = dict(config)
        valid_kwargs = {
            "name": config.pop("name"),
            "height_slice": config.pop("height_slice", None),
            "width_slice": config.pop("width_slice", None),
            "features_slice": config.pop("features_slice", None),
        }
        instance = cls(**valid_kwargs)

        for key, value in config.items():
            if key in instance.__dict__:
                setattr(instance, key, value)

        return instance

    def call_native(self, inputs, **kwargs):
        """Slices the first input along axes 1, 2 and 3, keeping the batch axis whole."""
        op = inputs[0][
            :,
            self.height_slice[0] : self.height_slice[1] : self.height_slice[2],
            self.width_slice[0] : self.width_slice[1] : self.width_slice[2],
            self.features_slice[0] : self.features_slice[1] : self.features_slice[2],
        ]

        return op

    def call_bit_exact(self, inputs, **kwargs):
        """Slicing does no arithmetic, so the bit-exact path is the native path."""
        return self.call_native(inputs, **kwargs)

    @property
    def height_slice(self):
        return self._height_slice

    @property
    def width_slice(self):
        return self._width_slice

    @property
    def features_slice(self):
        return self._features_slice

    def _build(self, input_shape):
        self._fill_slices(input_shape)

    def _compute_output_shape(self, input_shape):
        """Returns [batch, sliced_height, sliced_width, sliced_features] for the given input shape."""
        self._fill_slices(input_shape)
        height_size = self._calc_size(self._height_slice, input_shape[1])
        width_size = self._calc_size(self._width_slice, input_shape[2])
        features_size = self._calc_size(self._features_slice, input_shape[3])
        shape = [input_shape[0], height_size, width_size, features_size]
        return shape

    def _calc_size(self, index_tuple, max_size):
        """
        Returns the number of elements selected by ``index_tuple`` from an axis of length ``max_size``.

        Delegates to Python slice semantics (``len(range(max_size)[begin:end:stride])``). ``None``
        entries, negative indices and out-of-range bounds are therefore resolved and clamped the
        way sequence slicing resolves them, rather than by hand-rolled index arithmetic.
        """
        return len(range(max_size)[slice(*index_tuple)])

    @staticmethod
    def _resolve_slice(default_slice, dim_size):
        """Turns a user-given slice into a concrete (begin, end, stride) for an axis of size ``dim_size``."""
        if default_slice is None:
            # No slice given: take the whole axis.
            return (0, dim_size, 1)
        if default_slice[1] is None:
            # Open-ended slice: the end is the axis size.
            return (default_slice[0], dim_size, default_slice[2])
        return default_slice

    def _fill_slices(self, input_shape):
        """
        Fills the slice tuples based on the input shape, and checks that they are compatible with it.

        Slices are always re-resolved from the user-given defaults. This lets a repeated call with a
        different input shape recompute full or open-ended slices instead of keeping stale ends.

        Raises:
            AccelerasInferenceError: if any slice end exceeds the corresponding input dimension.
        """
        self._height_slice = self._resolve_slice(self._default_height_slice, input_shape[1])
        self._width_slice = self._resolve_slice(self._default_width_slice, input_shape[2])
        self._features_slice = self._resolve_slice(self._default_features_slice, input_shape[3])
        if (
            self.height_slice[1] > input_shape[1]
            or self.width_slice[1] > input_shape[2]
            or self.features_slice[1] > input_shape[3]
        ):
            raise AccelerasInferenceError("The input shape of layer slice is not compatible with shapes")

    def enforce_encoding(self):
        """Derives the output scale / zero point from the input ones by applying the features slice."""
        # Per-channel scales follow the features slice; a scalar scale is passed through unchanged.
        if self.features_slice is not None and not self.input_scale_is_scalar(0):
            start, end, step = self.features_slice
            self.output_scale = self.input_scales[0][start:end:step]

        else:
            self.output_scale = self.input_scales[0]
        # Same for the zero point: slice it only when it is non-scalar (rank > 0).
        if self.features_slice is not None and len(tf.convert_to_tensor(self.input_zero_point).shape) != 0:
            start, end, step = self.features_slice
            self.output_zero_point = self.input_zero_point[start:end:step]

        else:
            self.output_zero_point = self.input_zero_point

    def backward_encoding(self):
        """
        Derives the input scale / zero point from the output ones.

        For a per-channel output scale, each output channel's scale is scattered back to the input
        channel it was sliced from. Input channels dropped by the slice get a scale of zero.
        """
        output_scale = self.output_scale
        if not self.input_scale_is_scalar(0):
            start, end, step = self.features_slice
            indices = tf.constant(list(range(self.input_shape[-1]))[start:end:step])
            input_scale = tf.zeros(self.input_shape[-1], dtype=output_scale.dtype)
            self.input_scale = tf.tensor_scatter_nd_add(input_scale, tf.expand_dims(indices, -1), output_scale)
        else:
            self.input_scale = output_scale
        self.input_zero_point = self.output_zero_point

    def define_constraints(self, enc):
        """
        Registers encoding constraints between this op's input and output.

        Presumably this ties output_scale to the features slice of input_scale, and makes the
        zero points identical. The exact semantics of ``enc.callback`` / ``enc.identity`` are
        defined outside this file — TODO confirm.
        """
        super().define_constraints(enc)
        enc.callback(
            f"{self.full_name}/output_scale:0",
            f"{self.full_name}/input_scale:0",
            lambda x: x[self.features_slice[0] : self.features_slice[1] : self.features_slice[2]],
            callback_name="slice",
            outs_shape=(self.output_shape[-1],),
        )
        enc.identity(f"{self.full_name}/output_zero_point:0", f"{self.full_name}/input_zero_point:0")
