from typing import Tuple

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.bias_add_op import AddBiasDeconvOp
from hailo_model_optimization.acceleras.atomic_ops.conv_stripped_op import ConvStrippedOp
from hailo_model_optimization.acceleras.atomic_ops.depth_to_space_op import DepthToSpaceOp
from hailo_model_optimization.acceleras.atomic_ops.passthru_op import PassthruOp
from hailo_model_optimization.acceleras.atomic_ops.slice_op import SliceOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_conv import BaseHailoConv
from hailo_model_optimization.acceleras.hailo_layers.layer_flow import LayerFlow
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerWeightsClippingConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    DataPath,
    EquivClassification,
    LayerHandlerType,
    LayerType,
    OptimizationTarget,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasImplementationError


class HailoDeconv(BaseHailoConv):
    """
    Represents `Deconv`

    Instead of up-sampling the input and then applying a convolution layer
    we transform the Deconv kernel in to use series of convolution layers, then
    we run it as a normal convolutional layer, after wards we interpolate the results
    to match the result of a normal deconv layer.


    Implementing deconv via conv + depth2space,
      where conv uses an appropriately "depthified" kernel and some pre/post padding nuance.

    Args:
            filters : Number of Deconv filters
        kernel_size: Shape of the kernel
        strides:  Stride of the deconv
        activation : Activation Function

    References:
       Is the deconvolution layer the same as a convolutional layer?
         https://arxiv.org/ftp/arxiv/papers/1609/1609.07009.pdf

    """

    # PrecisionMode and BiasMode are the same as BaseHailoConv
    # TODO: make sure it's correct
    SUPPORTED_QUANTIZATION_GROUPS = False
    _hn_type = LayerType.DECONV

    def __init__(
        self,
        name: str,
        filters,
        kernel_size: Tuple = (4, 4),
        strides: Tuple = (2, 2),
        activation="linear",
        groups=1,
        logger=None,
        **kwargs,
    ):
        """
        Build the atomic ops that emulate a deconv: conv -> bias add -> depth-to-space -> slice.

        Args:
            name: Layer name; also used as the prefix of every atomic op name (``{name}/conv_op`` ...).
            filters: Number of deconv output filters. The inner conv op is created with
                ``filters * strides[0] * strides[1]`` filters.
            kernel_size: Deconv kernel shape; only the first two entries are used.
            strides: Deconv strides; only the first two entries are used.
            activation: Activation, forwarded to the base class initializer.
            groups: Number of groups, forwarded to the conv, depth-to-space and bias ops.
            logger: Logger forwarded to every op and to the base class.
            **kwargs: Forwarded to the base class initializer.

        """
        # Scalar `stride` / `k_size` keep only the first entry; they are what the legacy
        # bias re-ordering uses. The array versions keep both spatial entries.
        self.stride = strides[0]
        self.strides = np.array(strides[0:2])
        self.k_size = kernel_size[0]
        self.k_sizes = np.array(kernel_size[0:2])
        # Number of kernel taps per output phase, i.e. ceil(kernel / stride).
        self.rate = np.ceil(self.k_size / self.stride).astype("int")
        self.rates = np.ceil(self.k_sizes / self.strides).astype("int")
        self.original_filters = filters

        #  Setting the conv Atomic Op
        # The conv runs with stride 1 on a kernel shrunk by the stride, and produces one
        # block of output channels per (row, column) phase of the deconv.
        kernel_size = list(np.ceil(np.array(kernel_size) / np.array(strides)).astype("int"))
        filters = int(filters * strides[0] * strides[1])
        conv_op = ConvStrippedOp(
            f"{name}/conv_op",
            filters=filters,
            kernel_size=kernel_size,
            strides=(1, 1),
            groups=groups,
            padding="DECONV",
            is_depthwise=False,
            logger=logger,
            trainable=False,
        )

        # Row-major list of all (row, column) phases; used when (un)depthifying the kernel.
        self.stack_order = [(row, colum) for row in range(strides[0]) for colum in range(strides[1])]

        #  Setting DepthToSpaceOp
        # With strides (1, 1) there is nothing to interleave, so a pass-through op is used.
        if tuple(strides) == (1, 1):
            self.depth_to_space_op = PassthruOp(f"{name}/depth_to_space_op", logger=logger)
        else:
            self.depth_to_space_op = DepthToSpaceOp(
                f"{name}/depth_to_space_op",
                mode="dcr",
                block_sizes=(strides[0], strides[1]),
                groups=groups,
                logger=logger,
            )

        # Earlier formulation, kept for reference:
        # padding = self.strides // self.rates
        # start_slice = padding
        # Rows/columns cropped after depth-to-space: at least one from the end, the rest from the start.
        end_slice = np.maximum(self.strides // 2, 1)
        start_slice = self.strides - end_slice
        # I don't fully understand why it works. required for stride [1, 1] with kernel 2
        # end_slice = (self.k_sizes % 2 == 1) * padding + (self.k_sizes % 2 == 0) * np.maximum(padding, 1)
        # if start_slice == 0 and end_slice == 0:
        #     self.slice_op = PassthruOp(f'{name}/slice_op', logger=logger)
        # else:
        # TODO: handle case of end slice zero
        # A dimension is cropped only when its rate is even; an odd rate leaves it untouched (None).
        height_slice = None if self.rates[0] % 2 == 1 else (start_slice[0], -end_slice[0], 1)
        width_slice = None if self.rates[1] % 2 == 1 else (start_slice[1], -end_slice[1], 1)
        self.slice_op = SliceOp(
            f"{name}/slice_op",
            height_slice=height_slice,
            width_slice=width_slice,
            logger=logger,
        )

        # The base initializer runs only after the deconv-specific ops above were created.
        # NOTE(review): this ordering looks deliberate - confirm before moving the call.
        super().__init__(name, conv_op, activation=activation, logger=logger, **kwargs)

        #  Setting Bias Op
        # Assigned after super().__init__, so it replaces whatever bias op the base class may have set.
        self.bias_add_op = AddBiasDeconvOp(
            f"{name}/bias_add_op",
            strides[0] * strides[1],
            groups=groups,
            bias_initializer=None,
            trainable=False,
            is_correctable=False,
            logger=logger,
        )
        # Rebuild the flow so that it contains the deconv-specific ops.
        self._layer_flow = self._build_flow()

        self.encoding_const = False

    def _layer_dependent_hw_params_modifications(self, hw_params: dict):
        hw_params["conv_bias"] = hw_params["bias"].copy()
        hw_params["conv_kernel"] = hw_params["kernel"].copy()
        if "bias_q_int8_vec_a" in hw_params:
            bias_a = self._change_numeric_bias(hw_params["bias_q_int8_vec_a"]).astype(np.int8)
            hw_params["bias_q_int8_vec_a"] = bias_a
        if "bias_q_int8_vec_b" in hw_params:
            bias_b = self._change_numeric_bias(hw_params["bias_q_int8_vec_b"]).astype(np.int8)
            hw_params["bias_q_int8_vec_b"] = bias_b

        hw_params["bias"] = self._change_numeric_bias(hw_params["bias"]).astype(hw_params["bias"].dtype)
        hw_params["kernel"] = self._undepthify_kernel(hw_params["kernel"]).astype(hw_params["kernel"].dtype)
        return hw_params

    def get_numeric_kernel_np(self):
        kernel_depthified = self.conv_op.final_numeric_kernel.numpy()
        return self._undepthify_kernel(kernel_depthified)

    def get_bias_np(self):
        bias = self.bias_add_op.short_bias.numpy()
        return bias

    def _change_numeric_bias(self, bias):
        """
        HW wants bias channel wise, meaning every kernels does its dephepy
        Args:
            bias: Numeric bias from the op
        """
        # TODO: consider moving this logic to AddBiasDeconvOp...
        if self.k_sizes[0] != self.k_sizes[1]:  # _change_numeric is not relevant for asymertic deconv
            return bias
        legacy_stack_order = self._get_deconv_stack_order(self.k_size, self.stride)
        if legacy_stack_order is not None:
            new_order = [row * self.stride + col for row, col in legacy_stack_order]
            biass = np.vstack(np.split(bias, self.stride**2)).T
            mapping_matrix = np.zeros((self.stride**2, self.stride**2))
            order = np.zeros((self.stride**2, self.stride**2))
            rows = np.arange(self.stride**2)

            # This matrix takes the Reverse we apply on the kernel
            mapping_matrix[rows, rows[::-1]] = 1

            # This Matrix order then on the correct order given by decon Stack order
            order[np.array(new_order), rows] = 1
            total = biass @ mapping_matrix @ order

            bias_out = np.reshape(total, -1)
            return bias_out
        else:
            return bias

    def _export_weights(self):
        conv_weights = self.conv_op.export_weights()
        kernel = self._undepthify_kernel(conv_weights["kernel"])
        padding_const_value = conv_weights["padding_const_value"]
        bias = self.bias_add_op.short_bias.numpy()
        activation_params = self._export_activation()
        dict_params = {"kernel": kernel, "bias": bias, "padding_const_value": padding_const_value}
        dict_params.update(activation_params)
        return dict_params

    def create_hw_params(
        self, weights_clipping: LayerWeightsClippingConfig, optimization_target: OptimizationTarget, hw_shifts=None
    ):
        """
        Create the HW (quantization) params of the layer's ops. WIP.

        Implementing basic scalar case.

        NOTE: the create_hw_params() methods of atomic_ops usually create "candidates" for
              the actual numerization params, only finalizing "independent" params,
               to later be used in finalization as performed in "infer_encodings".

              The comments below try to carefully specify what and where is finalized.

        Args:
            weights_clipping: Weights clipping configuration, forwarded to the conv op.
            optimization_target: Optimization target, forwarded to the conv and activation ops.
            hw_shifts: Optional HW shifts, forwarded to the conv op.

        Raises:
            AccelerasImplementationError: If the activation op has more than one quantization group.

        """
        if self.act_op.quantization_groups_num > 1:
            raise AccelerasImplementationError(
                f"For layer {self.full_name} we don't support qunatization with quantization groups yet",
            )

        # TODO to think about here - we generally wanted to only use scale&zp "candidates" and not
        #  limvals at this point (after they been consumed for all layers&ops in model.create_io_encoding_candidates)
        #
        self._enforce_output_encoding()
        pre_act_stats = self.get_preact_stats()[0]

        # Largest absolute pre-activation value, taken element-wise from the stats' min and max.
        max_final_pre_acc_by_channel = np.maximum(np.abs(pre_act_stats.min), np.abs(pre_act_stats.max))
        # Push that value backwards through depth-to-space by (ab)using its scale propagation,
        # to get it on the conv-op side of the depth-to-space op.
        self.depth_to_space_op.output_scale = max_final_pre_acc_by_channel
        self.depth_to_space_op.backward_encoding()
        max_final_accumulator_by_channel = self.depth_to_space_op.input_scales[0]
        kernel_scale_matrix_component = self.get_kernel_scale_matrix_component()
        self.conv_op.create_hw_params(
            max_final_accumulator_by_channel,
            weights_clipping,
            optimization_target,
            kernel_scale_matrix_component=kernel_scale_matrix_component,
            hw_shifts=hw_shifts,
        )
        # The bias op shares the conv op's pre-accumulator shift.
        self.bias_add_op.pre_acc_shift = self.conv_op.pre_acc_shift

        # From accumulator scale candidate, create the "ideal" output factor (*finalized*).
        pre_depth_to_space = self.conv_op.accumulator_scale_candidate
        # Run the candidate through the native depth-to-space as a 1x1 "image", then reduce
        # over the spatial axes (1, 2) with max.
        encapsulated = [np.reshape(pre_depth_to_space, (1, 1, 1, -1))]
        fixed_accumulator = np.squeeze(np.max(self.depth_to_space_op.call_native(encapsulated), axis=(1, 2)))
        self.act_op.create_hw_params(fixed_accumulator, optimization_target)

        # This MOSTLY finalizes the "independent" params, so rest of job can be done by infer_encodings():
        # The "ideal" output factor leads to the "numeric" (M/E decomposed, @slope=1) output factor,
        # and then the *finalized* (aka, "mantissa-nudged") accumulator scale, followed by kernel scales (L&R),
        # and then all intermediate zero-points, and then final APU parameters.
        # An exception is bias (& elwa, see subclasses) decompositions
        self._create_hw_params_finalize()
        self._has_hw_params = True

    def get_kernel_scale_matrix_component(self):
        """
        Compute the kernel scale from the conv input scales and the layer output scale.

        The layer output scale is first propagated backwards through the depth-to-space
        op and the result is handed to ``conv_op.calc_kernel_scale``.
        """
        d2s_op = self.depth_to_space_op
        d2s_op.output_scale = self.output_scale
        d2s_op.backward_encoding()
        conv_side_output_scale = d2s_op.input_scales[0]
        return self.conv_op.calc_kernel_scale(self.conv_op.input_scales, conv_side_output_scale)

    @classmethod
    def get_default_params(cls):
        defaults = {
            "strides": [1, 2, 2, 1],
            "dilations": [1, 1, 1, 1],
            "padding": "DECONV",
            "elementwise_add": False,
            "groups": 1,
            "activation": "linear",
        }
        return defaults

    def get_kernel_np(self):
        kernel_depthified = self.conv_op.kernel.numpy()
        return self._undepthify_kernel(kernel_depthified)

    def _undepthify_kernel(self, kernel_depthified, kernel_is_transposed=True):
        """
        This method interleave small Kernels into a full Kernel.

                                              [[ a, b, a, b],
         [a , a]  [b, b]  [c, c]  [d, d]      [ c, d, c, d],
         [a , a], [b, b], [c, c], [d, d]  =>  [ a, b, a, b],
                                              [ c, d, c, d]]

        Args:
            kernel_depthified: Kernel on the hardware
            kernel_is_transposed: If the Kernel is Transpose

        Returns:
            Numpy array that has the Kernel values

        """
        # Create a placeholder for the output Kernel
        inp_shape = kernel_depthified.shape
        output_shape = (
            inp_shape[0] * self.strides[0],
            inp_shape[1] * self.strides[1],
            inp_shape[2],
            inp_shape[3] // (self.strides[0] * self.strides[1]),
        )
        kernel_undepthify = np.zeros(shape=output_shape, dtype=kernel_depthified.dtype)

        groups = max(self.conv_op.groups, 1)
        group_size_in = kernel_depthified.shape[-1] // groups
        group_size_out = kernel_undepthify.shape[-1] // groups
        for g in range(groups):
            # Rebuilding the Kernel
            splits = np.split(
                kernel_depthified[..., g * group_size_in : (g + 1) * group_size_in],
                indices_or_sections=(self.strides[0] * self.strides[1]),
                axis=3,
            )
            for split, (col, row) in zip(splits, reversed(self.stack_order)):
                kernel_undepthify[
                    col :: self.strides[0],
                    row :: self.strides[1],
                    ...,
                    g * group_size_out : (g + 1) * group_size_out,
                ] = split

        # Post process
        if not kernel_is_transposed:
            kernel_undepthify = kernel_undepthify.transpose((0, 1, 3, 2))
        kernel_undepthify = kernel_undepthify[::-1, ::-1]
        return kernel_undepthify

    def _depthify_kernel(self, kernel, kernel_is_transposed: bool = True):
        """
        This method samples with strides a full Kernel to get
        smaller kernels.

        [[ a, b, a, b],
         [ c, d, c, d],   =>  [a , a]  [b, b]  [c, c]  [d, d]
         [ a, b, a, b],       [a , a], [b, b], [c, c], [d, d]
         [ c, d, c, d]

        Args:
            kernel: Original Kernel
            kernel_is_transposed: If the Kernel is Transpose

        Returns:
            Numpy array that has the Kernel values

        """
        if not kernel_is_transposed:
            kernel = kernel.transpose(0, 1, 3, 2)

        pad_total = np.ceil(kernel.shape[:2] / self.strides).astype(int) * self.strides - kernel.shape[:2]
        pad_end = pad_total // 2
        pad_start = pad_total - pad_end
        kernel_pad = np.pad(kernel, ((pad_start[0], pad_end[0]), (pad_start[1], pad_end[1]), (0, 0), (0, 0)))
        k_inv = kernel_pad[::-1, ::-1]

        kernel_groups = []
        groups = max(self.conv_op.groups, 1)
        group_size = kernel.shape[-1] // groups

        for i in range(groups):
            kernels = [
                k_inv[row :: self.strides[0], column :: self.strides[1], ..., group_size * i : group_size * (i + 1)]
                for row, column in self.stack_order
            ]
            padded_kernels = []
            for k in kernels:
                padded_kernels.insert(0, k)
            kernel_group_depthified = np.concatenate(padded_kernels, axis=-1)
            kernel_groups.append(kernel_group_depthified)
        return np.concatenate(kernel_groups, axis=-1)

    def import_native_kernel(self, kernel, kernel_is_transposed=True):
        """
        Load a native deconv kernel into the conv op.

        The kernel is depthified, cast to float32 and imported as the conv op weights.

        Args:
          kernel: Native (full) deconv kernel.
          kernel_is_transposed:  SDK's NPZ hold the kernel in transposed form by default,
                                 use False if kernel is given exactly as used by conv2d_transpose.
        """
        depthified = self._depthify_kernel(kernel, kernel_is_transposed)
        conv_weights = tf.cast(depthified, tf.float32)
        self.conv_op.import_weights(conv_weights)

    def _get_activation_input_shape(self, input_shape):
        _acc_shape = self._get_bias_input_shape(input_shape)
        block_size_h = self.depth_to_space_op.block_sizes[0]
        block_size_w = self.depth_to_space_op.block_sizes[1]
        _acc_shape[-1] /= block_size_h * block_size_w
        _acc_shape[-2] *= block_size_w
        _acc_shape[-3] *= block_size_h
        return _acc_shape

    def _accumulator_scale_from_apu(self):
        """
        Resolve the accumulator scale from the activation op and propagate it backwards.

        The accumulator scale is fully defined by the output and APU params: it is resolved
        in the activation op, stored (with at least one dimension) in ``self.acc_scale``,
        pushed backwards through depth-to-space, and then used as the conv output scale and
        as both the input and output scale of the bias op.
        """
        self.act_op.get_accumulator_scale()
        apu_input_scale = self.act_op.input_scales[0]
        # A scalar scale is promoted to a 1-element vector.
        self.acc_scale = tf.expand_dims(apu_input_scale, 0) if apu_input_scale.shape == () else apu_input_scale

        d2s_op = self.depth_to_space_op
        d2s_op.output_scale = self.acc_scale
        d2s_op.backward_encoding()
        conv_side_acc_scale = d2s_op.input_scales[0]

        self.conv_op.output_scale = conv_side_acc_scale
        bias_op = self.bias_add_op
        bias_op.input_scales[0] = conv_side_acc_scale
        bias_op.output_scale = conv_side_acc_scale

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """
        Create the deconv layer from an HN element.

        The element's ``params`` (if present) override the defaults returned by
        ``get_default_params``; ``kernel_shape`` has no default and must be present.

        Args:
            lname: Layer name
            hn_element: HN element that represents the deconv
            logger: Logger for the layer

        Returns
            The constructed layer, after ``finalize_from_hn`` was applied to it.

        """
        hn_params = dict(cls.get_default_params())
        hn_params.update(hn_element.get("params", {}))

        # Only the two middle entries of the 4-entry strides are used.
        spatial_strides = tuple(hn_params["strides"][1:3])
        kernel_shape = hn_params["kernel_shape"]
        spatial_kernel = kernel_shape[0:2]
        cls._validate_elwa(hn_params["elementwise_add"])

        deconv_layer = cls(
            name=lname,
            filters=kernel_shape[-1],
            kernel_size=spatial_kernel,
            strides=spatial_strides,
            activation=hn_params["activation"],
            groups=hn_params["groups"],
            logger=logger,
        )
        deconv_layer.finalize_from_hn(hn_element)
        return deconv_layer

    @staticmethod
    def _get_deconv_stack_order(kernel_shape, stride):
        if (kernel_shape, stride) in [(4, 2), (3, 2), (2, 2)]:
            stack_order = [(0, 0), (0, 1), (1, 0), (1, 1)]

        elif (kernel_shape, stride) == (4, 4):
            stack_order = [
                (0, 0),
                (0, 2),
                (2, 0),
                (2, 2),
                (0, 1),
                (0, 3),
                (2, 1),
                (2, 3),
                (1, 0),
                (1, 2),
                (3, 0),
                (3, 2),
                (1, 1),
                (1, 3),
                (3, 1),
                (3, 3),
            ]

        elif (kernel_shape, stride) == (8, 4):
            stack_order = [
                (1, 1),
                (1, 3),
                (3, 1),
                (3, 3),
                (1, 0),
                (1, 2),
                (3, 0),
                (3, 2),
                (0, 1),
                (0, 3),
                (2, 1),
                (2, 3),
                (0, 0),
                (0, 2),
                (2, 0),
                (2, 2),
            ]

        elif (kernel_shape, stride) == (16, 8):
            stack_order = [
                (3, 3),
                (3, 7),
                (7, 3),
                (7, 7),
                (3, 1),
                (3, 5),
                (7, 1),
                (7, 5),
                (1, 3),
                (1, 7),
                (5, 3),
                (5, 7),
                (1, 1),
                (1, 5),
                (5, 1),
                (5, 5),
                (3, 2),
                (3, 6),
                (7, 2),
                (7, 6),
                (3, 0),
                (3, 4),
                (7, 0),
                (7, 4),
                (1, 2),
                (1, 6),
                (5, 2),
                (5, 6),
                (1, 0),
                (1, 4),
                (5, 0),
                (5, 4),
                (2, 3),
                (2, 7),
                (6, 3),
                (6, 7),
                (2, 1),
                (2, 5),
                (6, 1),
                (6, 5),
                (0, 3),
                (0, 7),
                (4, 3),
                (4, 7),
                (0, 1),
                (0, 5),
                (4, 1),
                (4, 5),
                (2, 2),
                (2, 6),
                (6, 2),
                (6, 6),
                (2, 0),
                (2, 4),
                (6, 0),
                (6, 4),
                (0, 2),
                (0, 6),
                (4, 2),
                (4, 6),
                (0, 0),
                (0, 4),
                (4, 0),
                (4, 4),
            ]
        elif (kernel_shape, stride) in [(1, 1), (2, 1), (3, 1)]:
            stack_order = [(0, 0)]
        else:
            return None
        return stack_order

    def get_equalization_handler_type(self, predecessor_index=None):
        """Report the layer as unsupported (and not a source) for equalization; the argument is ignored."""
        handler_type = LayerHandlerType.unsupported
        return EquivClassification(handler_type, is_source=False)

    def _build_flow(self) -> LayerFlow:
        """
        Build the layer flow: input -> conv -> bias add -> depth-to-space -> slice -> activation -> output.

        Returns:
            A new LayerFlow with one input, one output and the six ops chained linearly.

        """
        flow = LayerFlow()
        flow_input = flow.add_input()
        flow_output = flow.add_output()

        op_chain = [
            self.conv_op,
            self.bias_add_op,
            self.depth_to_space_op,
            self.slice_op,
            self.act_op,
            self.output_op,
        ]
        for op in op_chain:
            flow.add_node(op)

        # One data path per consecutive pair of the chain, including the flow input and output.
        endpoints = [flow_input, *op_chain, flow_output]
        data_paths = (
            DataPath.LAYER_IN,
            DataPath.ACCUMULATOR,
            DataPath.ACCUMULATOR,
            DataPath.ACCUMULATOR,
            DataPath.ACCUMULATOR,
            DataPath.LAYER_OUT,
            DataPath.LAYER_OUT,
        )
        for source, target, data_path in zip(endpoints, endpoints[1:], data_paths):
            flow.add_edge(source, target, data_path)
        return flow

    def enforce_internal_encoding(self, training=False, **kwargs):
        """
        Run the base class encoding enforcement, then align the slice op with the activation input.

        The slice op gets the activation op's input scale and zero point as its output
        encoding, and its input encoding is set equal to its output encoding.

        Returns:
            Whatever the base class implementation returned.

        """
        base_result = super().enforce_internal_encoding(training=training, **kwargs)
        slicer = self.slice_op
        slicer.output_scale = self.act_op.input_scales[0]
        slicer.output_zero_point = self.act_op.input_zero_points[0]
        slicer.input_scales[0] = slicer.output_scale
        slicer.input_zero_points[0] = slicer.output_zero_point
        return base_result

    @property
    def bit_exact_supported(self):
        """Bit-exactness is never supported with groups != 1; otherwise defer to the base class."""
        return False if self.groups != 1 else super().bit_exact_supported

    def _supported_quantization_groups_hw(self, quantization_groups, arch):
        """Return False unconditionally: both arguments are ignored."""
        return False
