import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp
from hailo_model_optimization.acceleras.utils.acceleras_definitions import CacheOpMode


class CacheOp(BaseAtomicOp):
    """
    Describe a cache operation, useful to cache the output or use it as input.

    The cached values of every ``cache_id`` live in the ``cache_config`` object that is handed to
    ``call_native`` / ``call_hw_sim``:

    * ``cache_config.cache_mapping[cache_id]`` is a float32 array of shape
      ``[1, 1, cache_size, features]`` that is used as a ring buffer along axis 2.
    * ``cache_config.write_pointer_mapping[cache_id]`` is the position along axis 2 at which the
      next write starts.

    In ``CacheOpMode.READ`` the op returns ``cache_size - prefill_size`` positions of the buffer;
    in any other mode it stores its input (``prefill_size`` positions) in the buffer and returns
    the input unchanged.
    """

    num_inputs = 1
    num_outputs = 1

    def __init__(self, name: str, cache_id: str, mode, logger=None, fully_native=None, **kwargs):
        """
        Args:
            name: op name, forwarded to the base class.
            cache_id: key of this op's entries in ``cache_mapping`` / ``write_pointer_mapping``.
            mode: compared against ``CacheOpMode.READ`` to select read or write behaviour.
            logger: forwarded to the base class.
            fully_native: forwarded to the base class.
            **kwargs: forwarded to the base class.

        """
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        self._cache_id = cache_id
        self._mode = mode

    def _compute_output_shape(self, input_shape):
        """Return ``input_shape`` unchanged."""
        return input_shape

    def create_weight_quant_element(self, **kwargs):
        """Do nothing; this op creates no weight quantization element."""

    def call_hw_sim(self, inputs, cache_config, **kwargs):
        """Delegate to ``call_native``; both flows share the same implementation."""
        return self.call_native(inputs, cache_config, **kwargs)

    def call_native(self, inputs, cache_config, **kwargs):
        """
        Run the cache read/write through ``tf.numpy_function`` and pin the static output shape.

        Args:
            inputs: sequence whose first element is the 4-D input tensor.
            cache_config: object exposing ``cache_size``, ``prefill_size``, ``cache_mapping`` and
                ``write_pointer_mapping``. It is stored on the instance so that ``numpy_call``
                can reach it.
            **kwargs: accepted for interface compatibility, unused here.

        Returns:
            A float32 tensor of shape ``[batch, 1, width, features]``, where ``batch`` and
            ``features`` are taken from ``inputs[0]`` and ``width`` depends on the mode.

        """
        self._cache_config = cache_config
        output = tf.numpy_function(self.numpy_call, [inputs], np.float32)

        # tf.numpy_function does not carry a static shape, so it is restored here:
        # a read yields cache_size - prefill_size positions, any other mode yields prefill_size.
        if self._mode == CacheOpMode.READ:
            width = self._cache_config.cache_size - self._cache_config.prefill_size
        else:
            width = self._cache_config.prefill_size
        return tf.ensure_shape(output, [inputs[0].shape[0], 1, width, inputs[0].shape[-1]])

    def _init_cache_if_missing(self, num_features):
        """
        Create the zero-filled buffer and reset the write pointer if this cache id is unknown.

        Args:
            num_features: size of the last axis of the buffer to allocate.

        Returns:
            True when the buffer was created by this call, False when it already existed.

        """
        config = self._cache_config
        if self._cache_id in config.cache_mapping:
            return False

        config.cache_mapping.update(
            {
                self._cache_id: np.zeros(
                    shape=[1, 1, config.cache_size, num_features],
                    dtype=np.float32,
                ),
            },
        )
        config.write_pointer_mapping[self._cache_id] = 0
        return True

    def numpy_call(self, inputs, **kwargs):
        """
        NumPy implementation of the op; reads ``self._cache_config`` set by ``call_native``.

        Args:
            inputs: sequence whose first element is the 4-D input array.
            **kwargs: accepted for interface compatibility, unused here.

        Returns:
            In READ mode, an array with ``cache_size - prefill_size`` positions along axis 2 taken
            from the cache. Otherwise ``inputs[0]``, after it was written to the cache.

        """
        config = self._cache_config
        cache_id = self._cache_id
        cache_size = config.cache_size
        prefill_size = config.prefill_size

        # The cache is created lazily in both modes, so a write that happens before the first read
        # starts from a zeroed buffer instead of failing on a missing write pointer.
        newly_created = self._init_cache_if_missing(inputs[0].shape[-1])

        if self._mode == CacheOpMode.READ:
            if newly_created:
                # nothing was written yet: hand out a copy of the zeros so that later in-place
                # writes to the cache do not change the returned array
                return config.cache_mapping[cache_id][:, :, prefill_size:, :].copy()

            # returns the last #cache_size - #prefill_size newest values: the block that starts
            # right after the region of the next write and ends at the write pointer
            write_pointer = config.write_pointer_mapping[cache_id]
            start_read_pointer = (write_pointer + prefill_size) % cache_size
            return self.read_from_cache(start_read_pointer, write_pointer)

        # writes the values to the cache
        start_write_index = config.write_pointer_mapping[cache_id]
        end_write_index = (start_write_index + prefill_size) % cache_size
        return self.write_to_cache(start_write_index, end_write_index, inputs)

    def read_from_cache(self, start_index, end_index):
        """
        Return the cached values between two positions of the ring buffer (axis 2).

        Args:
            start_index: first position to read.
            end_index: position one past the last one to read. When it is not greater than
                ``start_index`` the read wraps around the end of the buffer.

        Returns:
            A new array (never a view of the cache) holding the requested positions in order.

        """
        cache = self._cache_config.cache_mapping[self._cache_id]

        if start_index < end_index:
            # this case is when the read block is continuous in cache and the start *reading*
            # pointer is smaller than the end *reading* pointer
            # like in the schema, the hashes are the values to be read
            # [       #############     ]
            # Slicing gives a view of the cache; copy it so a later in-place write cannot alter
            # the returned array (the wrapped case below already returns a fresh array).
            return cache[:, :, start_index:end_index, :].copy()

        # this case is when the read block is not continuous in cache and the end *reading* pointer
        # is not greater than the start *reading* pointer
        # like in the schema, the hashes are the values to be read
        # [#######         #########]
        tail = cache[:, :, start_index:, :]
        head = cache[:, :, :end_index, :]
        return np.concatenate((tail, head), axis=2)

    def write_to_cache(self, start_index, end_index, inputs):
        """
        Store ``inputs[0]`` in the ring buffer and move the write pointer to ``end_index``.

        Args:
            start_index: first position (axis 2) to write.
            end_index: position one past the last one to write. When it is not greater than
                ``start_index`` the write wraps around the end of the buffer.
            inputs: sequence whose first element is the 4-D array to store.

        Returns:
            ``inputs[0]``, unchanged.

        """
        cache = self._cache_config.cache_mapping[self._cache_id]
        values = inputs[0]

        if start_index < end_index:
            cache[:, :, start_index:end_index, :] = values
        else:
            # not continuous write, end_index <= start_index
            # the writing has to be split into two parts
            # first part is from start_index to the end of the cache
            # second part is from the beginning of the cache to end_index
            # the size of the first part is cache_size - start_index
            first_part_size = self._cache_config.cache_size - start_index
            cache[:, :, start_index:, :] = values[:, :, :first_part_size, :]
            cache[:, :, :end_index, :] = values[:, :, first_part_size:, :]

        self._cache_config.write_pointer_mapping[self._cache_id] = end_index

        return values  # return the inputs to keep the computational graph connected

    def export_weights(self):
        """Return an empty dict; this op exports no weights."""
        return {}

    def create_hw_params(self):
        """Do nothing; this op creates no hardware params."""

    def enforce_encoding(self, *args, **kwargs):
        """Apply ``forward_encoding``; the arguments are ignored."""
        self.forward_encoding()

    def forward_encoding(self):
        """Copy the scale and zero point of the first input to the output."""
        self.output_scale = self.input_scales[0]
        self.output_zero_point = self.input_zero_points[0]

    def backward_encoding(self):
        """Assign the output scale and zero point to ``input_scale`` / ``input_zero_point``."""
        self.input_scale = self.output_scale
        self.input_zero_point = self.output_zero_point

    @property
    def bit_exact_supported(self) -> bool:
        """Always True for this op."""
        return True

    def define_constraints(self, enc):
        """
        Register the base-class constraints plus two ``enc.identity`` relations for this op.

        The relations pair the op's output scale with its input scale and the output zero point
        with its input zero point (presumably constraining them to be equal - confirm against
        the ``enc`` implementation).
        """
        super().define_constraints(enc)
        enc.identity(f"{self.full_name}/output_scale:0", f"{self.full_name}/input_scale:0")
        enc.identity(f"{self.full_name}/output_zero_point:0", f"{self.full_name}/input_zero_point:0")
