from typing import List, Union

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_non_arithmetic_op import BaseNonArithmeticAtomicOp
from hailo_model_optimization.acceleras.utils.acceleras_definitions import DepthToSpaceType
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasImplementationError


class DepthToSpaceOp(BaseNonArithmeticAtomicOp):
    """
    Rearranges data from depth into blocks of spatial data. This is the reverse transformation of SpaceToDepth.
    More specifically, this op outputs a copy of the input tensor where values from the depth dimension are moved in
    spatial blocks to the height and width dimensions.
    The attr block_sizes indicates how the data is moved from the depth dimension to the height (block_sizes[0])
    and width dimensions (block_sizes[1]).

    Chunks of data of size block_sizes[0] * block_sizes[1] from depth are rearranged into
    non-overlapping blocks of size block_sizes[0] x block_sizes[1].
    The width of the output tensor is input_width * block_sizes[1], the height is input_height * block_sizes[0],
    and the depth is input_depth / (block_sizes[0] * block_sizes[1]).
    The Y, X coordinates within each block of the output image are determined by the high order component of the input
    channel index in DCR mode, and by the low order component in CRD mode.
    The depth of the input tensor must be divisible by block_sizes[0] * block_sizes[1].

    When groups > 1, the channels are split into `groups` equal contiguous chunks; each chunk is rearranged
    independently and the results are concatenated along the channel axis.

    For example, an input with shape [1, 1, 1, 4], block_sizes = (2, 2) will output [1, 2, 2, 1]
    For example, an input with shape [1, 1, 1, 4], block_sizes = (1, 2) will output [1, 1, 2, 2]

    Args:
        name: name of the op.
        mode: channel layout convention, converted to DepthToSpaceType (dcr or crd).
        block_sizes: (block_h, block_w) - how to rearrange the data from depth into blocks of spatial data.
        custom_order (Optional): list with the order in which the interleaving should be done.
            NOTE(review): stored on the instance but not read by any method of this class.
        groups: number of independent channel groups (values below 1 are treated as 1).
        logger, fully_native, kwargs: forwarded to the base class.

    """

    num_inputs = 1
    num_outputs = 1

    def __init__(
        self,
        name: str,
        mode: Union[DepthToSpaceType, str],
        block_sizes=(2, 2),
        custom_order: List[int] = None,
        groups=1,
        logger=None,
        fully_native=None,
        **kwargs,
    ):
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        self._mode = DepthToSpaceType(mode)
        self.custom_order = custom_order
        self.block_sizes = block_sizes
        self.groups = groups
        # NOTE(review): not used by this class's methods; kept since external code may reference it.
        self.concat = tf.keras.layers.Concatenate(axis=-1)

    def call_native(self, inputs, **kwargs):
        """
        Apply depth-to-space to inputs[0] (NHWC), group by group.

        The channel axis is split into `groups` contiguous slices of equal size. Each slice is rearranged
        independently and the results are concatenated back along the channel axis.
        """
        inp = inputs[0]

        groups = max(self.groups, 1)
        group_size = inp.shape[-1] // groups
        result = []
        for g in range(groups):
            group_op = self._depth_to_space_action(inp[..., g * group_size : (g + 1) * group_size])
            result.append(group_op)
        return tf.concat(result, axis=-1)

    def _depth_to_space_action(self, inputs):
        """
        Rearrange a single NHWC tensor from depth into spatial blocks.

        The operation follows the ONNX DepthToSpace definition (DCR and CRD modes), adapted to NHWC layout:
        https://github.com/onnx/onnx/blob/main/docs/Changelog.md#DepthToSpace-13
        """
        block_size_h = self.block_sizes[0]
        block_size_w = self.block_sizes[1]
        height, width, channels = inputs.shape[1:]
        out_h = int(height * block_size_h)
        out_w = int(width * block_size_w)
        out_c = int(channels // (block_size_h * block_size_w))

        if self._mode == DepthToSpaceType.dcr:
            # DCR: channel index = (i * block_size_w + j) * out_c + k  (block position is the high-order part)
            op = tf.reshape(inputs, (-1, height, width, block_size_h, block_size_w, out_c))
            op = tf.transpose(a=op, perm=(0, 1, 3, 2, 4, 5))
        elif self._mode == DepthToSpaceType.crd:
            # CRD: channel index = k * (block_size_h * block_size_w) + i * block_size_w + j  (block position is low-order)
            op = tf.reshape(inputs, (-1, height, width, out_c, block_size_h, block_size_w))
            op = tf.transpose(a=op, perm=(0, 1, 4, 2, 5, 3))
        else:
            raise ValueError(f"Unexpected mode {self._mode} in op {self.full_name}")
        # In both modes, after the transpose the shape is (-1, height, block_size_h, width, block_size_w, out_c).
        # The transpose moves block_size_h next to height and block_size_w next to width. Reshape keeps the
        # row-major order of the data, so merging adjacent axes gives
        # (-1, height * block_size_h, width * block_size_w, out_c).
        op = tf.reshape(op, (-1, out_h, out_w, out_c))
        return op

    def _compute_output_shape(self, input_shape):
        """Return [batch, H * block_h, W * block_w, C // (block_h * block_w)] for an NHWC input_shape."""
        block_size_h = self.block_sizes[0]
        block_size_w = self.block_sizes[1]
        bs, height, width, channels = input_shape
        out_h = int(height * block_size_h)
        out_w = int(width * block_size_w)
        out_c = int(channels // (block_size_h * block_size_w))
        return [bs, out_h, out_w, out_c]

    def call_hw_sim(self, inputs, **kwargs):
        """Delegate to call_native (the op only moves data); kwargs are not forwarded."""
        return self.call_native(inputs)

    def enforce_encoding(self):
        """
        Derive output_scale / output_zero_point from the input encoding.

        A vector input scale is only supported when all entries are equal; the output scale is then a vector of
        that value, with one entry per output channel.

        Raises:
            AccelerasImplementationError: if the input scale vector is not uniform.
        """
        if len(self.input_scales[0].shape) > 0:
            # verify that all scales are the same
            if not np.all(self.input_scales[0] == self.input_scales[0][0]):
                raise AccelerasImplementationError(
                    "Cannot run enforce encoding when the scales of depth to space is a vector where not all coordinates are the same",
                )
            factor = self.block_sizes[0] * self.block_sizes[1]
            # Integer division: the channel count is required to be divisible by the block factor.
            out_c_size = len(self.input_scales[0]) // factor
            self.output_scale = self.input_scales[0][0] * tf.ones(out_c_size)
        else:
            self.output_scale = self.input_scales[0]
        self.output_zero_point = self.input_zero_points[0]

    def backward_encoding(self):
        """
        Infer the per-channel input scale from the per-channel output scale.

        Every input channel is copied unchanged to exactly one output channel, so its scale is the scale of that
        output channel. Within a group of `group_size` input channels (local index l), the target output channel is
            DCR: l % out_c_per_group
            CRD: l // (block_h * block_w)
        offset by the first output channel of the group (group_index * out_c_per_group).
        For DCR with a single group this equals tiling output_scale block_h * block_w times.

        Raises:
            ValueError: if the mode is neither DCR nor CRD.
        """
        factor = self.block_sizes[0] * self.block_sizes[1]
        in_channels = int(self.input_shape[-1])
        groups = max(self.groups, 1)  # same group handling as call_native
        group_size = in_channels // groups
        out_c_per_group = group_size // factor

        channel = np.arange(in_channels)
        group_index = channel // group_size
        local_index = channel % group_size
        if self._mode == DepthToSpaceType.dcr:
            # Block position is the high-order part of the local channel index.
            index_in_group = local_index % out_c_per_group
        elif self._mode == DepthToSpaceType.crd:
            # Block position is the low-order part of the local channel index.
            index_in_group = local_index // factor
        else:
            raise ValueError(f"Unexpected mode {self._mode} in op {self.full_name}")
        output_channel = group_index * out_c_per_group + index_in_group
        self.input_scales[0] = tf.gather(self.output_scale, output_channel)

    def define_encodings(self, flow):
        """Register encodings via the base class, then mark input_scale as scalar."""
        super().define_encodings(flow)
        # NOTE(review): presumably forces a single shared input scale value; this is what the
        # constraints below rely on when they broadcast x[0].
        flow.get_encoding(f"{self.full_name}/input_scale:0").scalar = True

    def define_constraints(self, enc):
        """Tie output scale/zero-point to the input ones (the op does not change values)."""
        super().define_constraints(enc)

        # Compute scales: both directions broadcast the first entry to the target channel count.
        enc.callback(
            f"{self.full_name}/output_scale:0",
            f"{self.full_name}/input_scale:0",
            lambda x: tf.ones((self.output_shape[-1],)) * x[0],
            outs_shape=(self.output_shape[-1],),
        )
        enc.callback(
            f"{self.full_name}/input_scale:0",
            f"{self.full_name}/output_scale:0",
            lambda x: tf.ones((self.input_shape[-1],)) * x[0],
            outs_shape=(self.input_shape[-1],),
        )

        # Compute output_zero_point
        enc.identity(f"{self.full_name}/output_zero_point:0", f"{self.full_name}/input_zero_point:0")
