from typing import List

import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_non_arithmetic_op import BaseNonArithmeticAtomicOp
from hailo_model_optimization.acceleras.utils.acceleras_definitions import DEFAULT_CONCAT_AXIS, ConcatAxis
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import (
    AccelerasImplementationError,
    AccelerasNumerizationError,
    AccelerasValueError,
)


class ConcatOp(BaseNonArithmeticAtomicOp):
    """
    Produces an AtomicOp that contains one concat operation.
    The purpose of this Op is to simulate the operation of the concat layer.

    Args:
        name: name of the op.
        concat_elements: number of inputs being concatenated (exposed as ``num_inputs``).
        vector_zp: when True and the input scales are vectors, the forward encoding builds a per-channel
            output zero point by tiling each input's zero point over its scale vector.
        axis: defaults to DEFAULT_CONCAT_AXIS - an enum describing the axis on which the layer concatenates.
        group_sizes: optional list of relative group sizes. On the features axis every input is split into
            groups in these proportions and the groups are interleaved (all first groups, then all second
            groups, ...). Only the features and spatial_h axes accept group sizes.
        logger: forwarded to the base op.
        fully_native: forwarded to the base op.
        **kwargs: forwarded to the base op.

    Raises:
        ValueError: if ``group_sizes`` is given for an axis other than features / spatial_h.

    """

    num_outputs = 1

    def __init__(
        self,
        name,
        concat_elements,
        vector_zp=False,
        axis: ConcatAxis = DEFAULT_CONCAT_AXIS,
        group_sizes=None,
        logger=None,
        fully_native=None,
        **kwargs,
    ):
        # Set before super().__init__() -- presumably because the base constructor may read num_inputs;
        # keep this ordering.
        self._concat_elements = concat_elements
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        self._axis = axis
        self.vector_zp = vector_zp
        self._group_sizes = group_sizes
        if self._group_sizes is not None and self._axis not in (ConcatAxis.features, ConcatAxis.spatial_h):
            raise ValueError("concat with group sizes is only supported on features or spatial_h axis")

    @property
    def spatial_h_concat(self):
        """True when concatenating along the spatial height axis."""
        return self._axis == ConcatAxis.spatial_h

    @property
    def spatial_w_concat(self):
        """True when concatenating along the spatial width axis."""
        return self._axis == ConcatAxis.spatial_w

    @property
    def num_inputs(self) -> int:
        """Number of concatenated inputs (the ``concat_elements`` constructor argument)."""
        return self._concat_elements

    def _build(self, input_shapes):
        """
        Validate that every input can be split into the requested groups.

        Raises:
            ValueError: if the grouped dimension of some input is not divisible by ``sum(group_sizes)``.

        """
        if self._group_sizes is None:
            return
        group_sizes_sum = sum(self._group_sizes)
        # Features axis groups along index 3; the only other axis accepted in __init__ is spatial_h (index 1).
        dim_index = 3 if self._axis == ConcatAxis.features else 1
        for input_shape in input_shapes:
            dim_to_check = input_shape[dim_index]
            if dim_to_check % group_sizes_sum != 0:
                raise ValueError(
                    f"concat with group size have {dim_to_check} input size which cannot have "
                    f"groups of sizes {self._group_sizes}",
                )

    def _get_concat_axis(self):
        """Return the tensor axis index to concatenate on: 2 for spatial_w, 1 for spatial_h, otherwise 3."""
        if self.spatial_w_concat:
            return 2
        if self.spatial_h_concat:
            return 1
        return 3

    def _compute_output_shape(self, input_shapes):
        """Return the first input's shape with the concat axis replaced by the sum over all inputs."""
        if self.num_inputs == 1:
            # Note: the whole ``input_shapes`` argument is returned here, not ``input_shapes[0]``.
            return input_shapes
        concat_axis = self._get_concat_axis()
        return [
            sum(input_shape[axis] for input_shape in input_shapes) if axis == concat_axis else val
            for axis, val in enumerate(input_shapes[0])
        ]

    def call_native(self, inputs, **kwargs):
        """
        Concatenate ``inputs`` along the configured axis.

        When group sizes are set and the concat is on the features axis (index 3), the inputs are
        interleaved group by group instead of being concatenated back to back.
        """
        axis = self._get_concat_axis()
        # NOTE(review): for spatial_h, group sizes are validated in _build but ignored here (plain concat) --
        # confirm this is intended.
        if self._group_sizes is None or axis != 3:
            return tf.concat(inputs, axis)

        # Group concat: slice every input into groups in the ratio given by group_sizes, then concatenate all
        # first slices, then all second slices, and so on.
        # For example, two inputs with 6 and 3 features and group sizes [1, 2] give:
        # tf.concat([inp1[:, :, :, :2], inp2[:, :, :, :1], inp1[:, :, :, 2:], inp2[:, :, :, 1:]], axis=3)
        total_size = sum(self._group_sizes)
        concat_inputs_by_group = []
        prev_group_sizes_sum = 0
        for group_size in self._group_sizes:
            for inp in inputs:
                features_for_group_of_size_1 = inp.shape[3] // total_size
                start = features_for_group_of_size_1 * prev_group_sizes_sum
                end = start + features_for_group_of_size_1 * group_size
                concat_inputs_by_group.append(inp[:, :, :, start:end])
            prev_group_sizes_sum += group_size
        return tf.concat(concat_inputs_by_group, axis)

    def enforce_encoding(self, forward=True, *args, **kwargs):
        """
        Propagate scale / zero-point encodings through the concat.

        Forward: derive output_scale and output_zero_point from the inputs. Backward (forward=False): push
        output_scale back into input_scales. Grouped features-axis concats are delegated to
        ``enforce_encoding_for_asymmetric_groups``.

        Raises:
            AccelerasNumerizationError: (forward) if building the output zero point fails, e.g. when the
                input zero points are not all equal.

        """
        axis = self._get_concat_axis()

        if self._group_sizes is not None and axis == 3:
            self.enforce_encoding_for_asymmetric_groups(forward=forward)
            return

        if forward:
            if self.spatial_h_concat or self.spatial_w_concat or self.input_scale_is_scalar(0):
                # Spatial concat or scalar scale: the first input's scale is used as the output scale.
                self.output_scale = self.input_scales[0]
            else:  # vector scales are laid out along the concatenated channel axis
                self.output_scale = tf.cast(tf.concat(self.input_scales, axis=0), tf.float32)
            try:
                # TODO - will need to be removed in the future
                if self.vector_zp and tf.reshape(self.input_scales[0], [-1, 1]).shape[0] != 1:
                    # Per-channel zero point: repeat each input's zero point over its scale vector.
                    vector_zp = tf.concat(
                        [
                            tf.tile([zp], tf.shape(scale))
                            for scale, zp in zip(self.input_scales, self.input_zero_points)
                        ],
                        axis=0,
                    )
                    self.output_zero_point = vector_zp
                else:
                    # A single shared zero point requires all inputs to agree.
                    with tf.control_dependencies(
                        [
                            tf.debugging.Assert(
                                tf.reduce_all(tf.equal(self.input_zero_points, self.input_zero_points[0])),
                                self.input_zero_points,
                            ),
                        ],
                    ):
                        self.output_zero_point = self.input_zero_points[0]
            except Exception as err:
                raise AccelerasNumerizationError(
                    f"Input zero points to concat op {self.full_name} are not all equal."
                ) from err
        elif self.spatial_h_concat or self.spatial_w_concat or self.output_scale_is_scalar(0):
            for index in range(len(self.input_scales)):
                self.input_scales[index] = self.output_scale
        else:  # vector scales: hand each input its own slice of the output scale
            start_index = 0
            for index in range(len(self.input_scales)):
                out_f = self.input_shapes[index][-1]
                self.input_scales[index] = self.output_scale[start_index : start_index + out_f]
                start_index += out_f

    def enforce_encoding_for_asymmetric_groups(self, forward=True, *args, **kwargs):
        """
        Encoding propagation for a grouped (interleaved) concat on the features axis.

        For forward=True, the input scales (and zero points, broadcast to per-channel vectors when scalar)
        are sliced by group and interleaved with ``concat_by_groups`` to form output_scale /
        output_zero_point.

        For example: with group_sizes = [1, 2] and two inputs whose scales are [0, 1, ..., 14] and
        [0, 1, ..., 8], the forward output is

            output_scale = [ 0.,  1.,  2.,  3.,  4.,  0.,  1.,  2.,  5.,  6.,  7.,  8.,  9., 10.,
                11., 12., 13., 14., 3., 4., 5.,  6.,  7.,  8.]

        The backward direction (forward=False) is currently not implemented and leaves the input scales
        unchanged.

        Args:
            forward (bool, optional): Defaults to True.

        Raises:
            AccelerasValueError: for asymmetric groups, only the concatenation along the feature axis is supported.
            AccelerasImplementationError: vector_zp != False and asymmetric groups is not supported.

        """
        if self.spatial_h_concat or self.spatial_w_concat:
            raise AccelerasValueError(
                f"In composite op {self.full_name}: concat_op does not support asymmetric groups and spatial axis.",
            )
        if self.vector_zp:
            raise AccelerasImplementationError(
                f"In composite op {self.full_name}: concat_op does not support asymmetric groups with vector_zp.",
            )

        if not forward:
            # TODO: implement the inverse of concat_by_groups (output_scale -> input_scales).
            return

        output_scales = self.concat_by_groups(self.input_scales, self._group_sizes)
        self.output_scale = tf.cast(output_scales, self.FLOAT_TYPE_TF)

        # Zero points may differ across inputs: expand scalar zero points to one value per channel so they
        # can be interleaved exactly like the scales.
        zero_points = []
        for index, shape in enumerate(self.input_shapes):
            if len(tf.convert_to_tensor(self.input_zero_points[index]).shape) == 0:
                zero_point = tf.cast(tf.repeat(self.input_zero_points[index], shape[-1]), self.FLOAT_TYPE_TF)
            else:
                zero_point = self.input_zero_points[index]
            zero_points.append(zero_point)
        zero_out = self.concat_by_groups(zero_points, self._group_sizes)
        self.output_zero_point = tf.cast(zero_out, self.FLOAT_TYPE_TF)

    def define_constraints(self, enc):
        """Register encoding constraints: output scale from input scales, and shared zero point."""
        super().define_constraints(enc)

        # Compute output_scale
        if self.spatial_h_concat or self.spatial_w_concat:
            # Spatial concat keeps the channel layout, so every input shares the output scale.
            for i in range(self.num_inputs):
                enc.identity(f"{self.full_name}/output_scale:0", f"{self.full_name}/input_scale:{i}")
        else:
            enc.concat(
                f"{self.full_name}/output_scale:0",
                *(f"{self.full_name}/input_scale:{i}" for i in range(self.num_inputs)),
                group_sizes=self._group_sizes,
            )

        # Compute output_zero_point TODO: generalize the case for vector zp
        for i in range(self.num_inputs):
            enc.identity(f"{self.full_name}/output_zero_point:0", f"{self.full_name}/input_zero_point:{i}")

    @staticmethod
    def concat_by_groups(vectors: List[tf.Tensor], factors: List[int]):
        """
        Interleave 1-D vectors group by group.

        Each vector is cut into ``len(factors)`` contiguous chunks whose sizes are proportional to ``factors``.
        The result is all first chunks (in vector order), then all second chunks, and so on, concatenated
        along axis 0.

        Chunk sizes are rounded independently, so they cover each vector exactly only when its length is a
        multiple of ``sum(factors)``.
        """

        def create_slices(factors: List[int], vector_size: int):
            # (start, end) pairs for each group of one vector.
            denominator = float(sum(factors))
            slices = []
            start = 0
            for factor in factors:
                width = int(round(factor * vector_size / denominator))
                slices.append((start, start + width))
                start += width
            return slices

        slices_per_vector = [create_slices(factors, len(vec)) for vec in vectors]
        values = []
        for group_index in range(len(factors)):
            for vector, vector_slices in zip(vectors, slices_per_vector):
                start, end = vector_slices[group_index]
                values.append(vector[start:end])

        return tf.concat(values, axis=0)
