from typing import Union

import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_non_arithmetic_op import BaseNonArithmeticAtomicOp
from hailo_model_optimization.acceleras.utils.acceleras_definitions import SpaceToDepthType


class SpaceToDepthOp(BaseNonArithmeticAtomicOp):
    """
    Rearranges blocks of spatial data into depth.

    More specifically, this op outputs a copy of the input tensor where values from the height and width
    dimensions are moved to the depth dimension.
    The attr block_sizes indicates how the data is moved from the height (block_sizes[0])
    and width (block_sizes[1]) dimensions to the depth dimension.

    Non-overlapping blocks of size block_sizes[0] * block_sizes[1] are rearranged into depth at each location.
    The depth of the output tensor is block_sizes[0] * block_sizes[1] * input_depth.
    The input tensor's height and width must be divisible by block_sizes[0] and block_sizes[1] respectively.

    The order of the output channels depends on space_to_depth_type (bh / bw are the row / column offsets
    inside a block, c is the input channel; the leftmost term is the high order component):

    * CLASSIC_DCR: (bh, bw, c) - the Y, X coordinates within each block are the high order component.
    * CLASSIC_CRD: (c, bh, bw) - the input channel is the high order component.
    * FOCUS:       (bw, bh, c)
    * SERIAL:      per input channel regrouping built from strided slices, see call_native.
    * any other value: concatenation of four strided slices of the input (offsets 0 and 1 only).

    For example, an input with shape [1, 6, 6, 20] and block_sizes = (2, 2) will output [1, 3, 3, 80].
    """

    num_inputs = 1
    num_outputs = 1

    def __init__(
        self,
        name: str,
        block_sizes: tuple = (2, 2),
        space_to_depth_type: Union[str, SpaceToDepthType] = "classic_dcr",
        spatial_flatten_output: bool = False,
        logger=None,
        fully_native=None,
        **kwargs,
    ):
        """
        Args:
            name: name of the op, forwarded to the base class.
            block_sizes: (block_height, block_width) - how to rearrange the spatial data into depth.
            space_to_depth_type: selects the output channel ordering (see the class docstring).
            spatial_flatten_output: when True the output spatial dims are collapsed to
                (1, out_h * out_w) instead of (out_h, out_w).
            logger: forwarded to the base class.
            fully_native: forwarded to the base class.
            **kwargs: forwarded to the base class.

        """
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        self.block_sizes = block_sizes
        # NOTE(review): the value is stored as given and later compared with SpaceToDepthType members using
        # `==`; whether a plain string (such as the default) matches depends on how the enum is declared
        # - confirm against SpaceToDepthType.
        self.space_to_depth_type = space_to_depth_type
        self.spatial_flatten_output = spatial_flatten_output

    def _space_to_depth_dims(self, height, width, channels):
        """
        Returns (out_h, out_w, out_c) for an input of the given height, width and channels.

        Shared by call_native and _compute_output_shape so both always agree.
        Floor division is used so the sizes are computed with integer arithmetic only.
        """
        block_size_h = self.block_sizes[0]
        block_size_w = self.block_sizes[1]
        out_h = int(height // block_size_h)
        out_w = int(width // block_size_w)
        out_c = int(channels * block_size_h * block_size_w)
        return out_h, out_w, out_c

    def call_native(self, inputs, **kwargs):
        """
        Applies the space to depth rearrangement on inputs[0].

        The input is unpacked as (batch, height, width, channels) and its static height, width and
        channels are used to build the reshape targets.

        Returns:
            The rearranged tensor. For the SERIAL, CLASSIC_DCR, CLASSIC_CRD and FOCUS types the shape is
            (batch, out_h, out_w, out_c), or (batch, 1, out_h * out_w, out_c) when spatial_flatten_output
            is set - matching _compute_output_shape.

        """
        inp = inputs[0]

        block_size_h = self.block_sizes[0]
        block_size_w = self.block_sizes[1]
        _, height, width, channels = inp.shape.as_list()
        out_h, out_w, out_c = self._space_to_depth_dims(height, width, channels)

        # Final shape for every reshape based variant; kept in sync with _compute_output_shape.
        if self.spatial_flatten_output:
            out_shape = (-1, 1, out_h * out_w, out_c)
        else:
            out_shape = (-1, out_h, out_w, out_c)
        # View of the input that exposes the in-block offsets as separate axes:
        # (batch, out_h, bh, out_w, bw, channels)
        blocked_shape = (-1, out_h, block_size_h, out_w, block_size_w, channels)

        if self.space_to_depth_type == SpaceToDepthType.SERIAL:
            d2s_reshaped = tf.reshape(
                inp,
                (-1, out_h, block_size_w * out_w, block_size_h, channels),
                name="reshape_input",
            )
            d2s_splits = tf.split(d2s_reshaped, channels, axis=4, name="channels_splits")
            slice_shape = (-1, out_h, 1, block_size_h * block_size_w)
            d2s_splits_stack = []
            for split in d2s_splits:
                # One slice per output column: take every out_w-th entry of axis 2 starting at i.
                # The slices are gathered and concatenated once instead of growing a tensor pairwise.
                column_slices = [
                    tf.reshape(
                        split[:, :, i::out_w, :, :],
                        slice_shape,
                        name="reshape_slice" if i == 0 else None,
                    )
                    for i in range(out_w)
                ]
                d2s_splits_stack.append(tf.concat(column_slices, axis=2))
            d2s_stack = tf.concat(d2s_splits_stack, axis=3, name="channels_stack")
            op = tf.reshape(d2s_stack, out_shape)
        elif self.space_to_depth_type == SpaceToDepthType.CLASSIC_DCR:
            op = tf.reshape(inp, blocked_shape)
            # -> (batch, out_h, out_w, bh, bw, channels)
            op = tf.transpose(a=op, perm=(0, 1, 3, 2, 4, 5))
            op = tf.reshape(op, out_shape)
        elif self.space_to_depth_type == SpaceToDepthType.CLASSIC_CRD:
            op = tf.reshape(inp, blocked_shape)
            # -> (batch, out_h, out_w, channels, bh, bw)
            op = tf.transpose(a=op, perm=(0, 1, 3, 5, 2, 4))
            op = tf.reshape(op, out_shape)
        elif self.space_to_depth_type == SpaceToDepthType.FOCUS:
            op = tf.reshape(inp, blocked_shape)
            # -> (batch, out_h, out_w, bw, bh, channels)
            op = tf.transpose(a=op, perm=(0, 1, 3, 4, 2, 5))
            op = tf.reshape(op, out_shape)
        else:
            # Fallback for any other type: concatenate four strided slices along the channels.
            # Only the in-block offsets 0 and 1 are taken, so larger blocks are not fully covered;
            # spatial_flatten_output is not applied here (the output shape comes from the base class).
            # NOTE(review): for 2x2 blocks this yields the same channel order as the FOCUS branch.
            op = tf.concat(
                [
                    inp[:, ::block_size_h, ::block_size_w, :],
                    inp[:, 1::block_size_h, ::block_size_w, :],
                    inp[:, ::block_size_h, 1::block_size_w, :],
                    inp[:, 1::block_size_h, 1::block_size_w, :],
                ],
                axis=3,
            )
        return op

    def call_hw_sim(self, inputs, **kwargs):
        """The op only moves data around, so the hw simulation reuses the native implementation."""
        return self.call_native(inputs, **kwargs)

    def _compute_output_shape(self, input_shape):
        """
        Returns the output shape for an input shape of (batch, height, width, channels).

        For the SERIAL, CLASSIC_DCR, CLASSIC_CRD and FOCUS types the result is
        [batch, out_h, out_w, out_c], or [batch, 1, out_h * out_w, out_c] when spatial_flatten_output
        is set. Any other type is delegated to the base class.
        """
        _, height, width, channels = list(input_shape)
        out_h, out_w, out_c = self._space_to_depth_dims(height, width, channels)
        if self.space_to_depth_type in [
            SpaceToDepthType.SERIAL,
            SpaceToDepthType.CLASSIC_DCR,
            SpaceToDepthType.CLASSIC_CRD,
            SpaceToDepthType.FOCUS,
        ]:
            if self.spatial_flatten_output:
                return [input_shape[0], 1, out_h * out_w, out_c]
            return [input_shape[0], out_h, out_w, out_c]
        return super()._compute_output_shape(input_shape)

    def enforce_encoding(self):
        """
        Derives the output scale and zero point from the single input.

        A rank 0 input scale is passed through unchanged. Otherwise every entry of the input scale is
        repeated block_sizes[0] * block_sizes[1] times so that the output scale has one entry per
        output channel. The zero point is passed through unchanged.
        """
        if len(self.input_scales[0].shape) == 0:
            self.output_scale = self.input_scales[0]
        else:
            repeat_num = self.block_sizes[0] * self.block_sizes[1]
            repeats = repeat_num * tf.ones(len(self.input_scales[0]), dtype=tf.int32)
            # NOTE(review): tf.repeat emits the copies of each entry consecutively (s0, s0, ..., s1, ...).
            # That is the channel-major order; confirm it is intended for the types whose output
            # channel index is not channel-major when the input scale is not uniform.
            self.output_scale = tf.repeat(self.input_scales[0], repeats, axis=0)

        self.output_zero_point = self.input_zero_points[0]

    def define_constraints(self, enc):
        """
        Registers the encoding constraints of the op on enc.

        On top of the base class constraints, the output scale is declared as tf.repeat of the input
        scale (one repeat count of block_sizes[0] * block_sizes[1] per input channel) and the output
        zero point is declared identical to the input zero point.
        """
        super().define_constraints(enc)

        # Compute scales
        repeat_num = self.block_sizes[0] * self.block_sizes[1]
        repeats = repeat_num * tf.ones(self.input_shape[-1], dtype=tf.int32)
        enc.callback(
            f"{self.full_name}/output_scale:0",
            f"{self.full_name}/input_scale:0",
            tf.repeat,
            callback_name="tf.repeat",
            outs_shape=(self.output_shape[-1],),
            repeats=repeats,
            axis=0,
        )

        # Compute output_zero_point
        enc.identity(f"{self.full_name}/output_zero_point:0", f"{self.full_name}/input_zero_point:0")
