import collections
from itertools import repeat
from typing import Tuple

import torch
import torch.nn as nn

from hailo_model_optimization.acceleras.atomic_ops._misc_internals import get_tf_same_padding
from hailo_model_optimization.saitama.framework.common.saitama_definitions import Encoding, QType


def qtype_to_range(qtype: QType) -> Tuple[int, int]:
    """Return the inclusive ``(min, max)`` integer range representable by ``qtype``.

    Signed types use two's-complement bounds ``[-2**(bits-1), 2**(bits-1) - 1]``;
    unsigned types use ``[0, 2**bits - 1]``.
    """
    if not qtype.signed:
        return 0, 2**qtype.bits - 1
    half_span = 2 ** (qtype.bits - 1)
    return -half_span, half_span - 1


def _reverse_repeat_tuple(t, n):
    r"""Reverse the order of `t` and repeat each element for `n` times.

    This can be used to translate padding arg used by Conv and Pooling modules
    to the ones used by `F.pad`.
    """
    return tuple(x for x in reversed(t) for _ in range(n))


def _ntuple(n, name="parse"):
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return tuple(x)
        return tuple(repeat(x, n))

    parse.__name__ = name
    return parse


# Expands a scalar into a 2-tuple; iterables are converted to tuples unchanged.
_pair = _ntuple(n=2, name="_pair")


def parse_explicit_padding(padding, kernel_size, stride, dilation, spatial_shape):
    """Translate a string padding mode into explicit per-side padding amounts.

    Args:
        padding: One of ``"valid"``, ``"same"`` or ``"same_tensorflow"``.
        kernel_size: Kernel size per spatial dim (a sequence).
        stride: Stride per spatial dim (an iterable).
        dilation: Dilation per spatial dim. It is iterated in the stride-1
            "same" branch and passed through to ``get_tf_same_padding``
            otherwise.
        spatial_shape: Input spatial shape. It is only used, and then required,
            when ``padding == "same_tensorflow"`` or when ``padding == "same"``
            with any stride != 1.

    Returns:
        ``(padding, padding_mode, explicit_padding)``.
        ``explicit_padding`` is a flat list of begin/end pairs with the last
        spatial dim first, i.e. the ``F.pad`` layout. When every pad is zero,
        ``padding`` is normalized to ``"valid"`` and ``padding_mode`` to
        ``"zeros"``. Otherwise ``padding`` is unchanged, and ``padding_mode``
        is ``"zeros"`` for stride-1 "same" or ``"constant"`` for the
        TF-style branch.

    Raises:
        ValueError: If ``padding`` is not a supported string, or if
            ``spatial_shape`` is ``None`` but needed.
    """
    valid_paddings = ("valid", "same", "same_tensorflow")
    # Tuple membership (rather than a set) also rejects unhashable inputs with a clear ValueError.
    if padding not in valid_paddings:
        # Literal braces ({{ }}) keep the option list in a fixed order; a set literal inside the
        # f-string would print in hash-randomized order.
        raise ValueError(
            f"Invalid padding string {padding!r}, should be one of {{'valid', 'same', 'same_tensorflow'}}"
        )
    explicit_padding = [0, 0] * len(kernel_size)
    padding_mode = "zeros"
    if padding == "same" and all(s == 1 for s in stride):
        # Same split as torch's padding="same": total = d * (k - 1), with any odd
        # remainder on the end side. Dims are written in reverse (last spatial dim first).
        for d, k, i in zip(dilation, kernel_size, range(len(kernel_size) - 1, -1, -1)):
            total_padding = d * (k - 1)
            left_pad = total_padding // 2
            explicit_padding[2 * i] = left_pad
            explicit_padding[2 * i + 1] = total_padding - left_pad
    elif padding in {"same", "same_tensorflow"}:
        # Compute padding for strided conv the same way acceleras does.
        # NOTE: this needs the static spatial shape, so dynamic shapes are not supported.
        #       It could be computed per call, but the spatial shape does not change during inference.
        if spatial_shape is None:
            raise ValueError(f"spatial_shape is required to compute {padding!r} padding with stride {stride!r}")
        # Unpacks four values, so this branch assumes 2-D (H, W) convs.
        pad_end_h, pad_beg_h, pad_end_w, pad_beg_w = get_tf_same_padding(
            dilation,
            *spatial_shape,
            *kernel_size,
            *stride,
        )
        if padding == "same":
            explicit_padding = [pad_beg_w, pad_end_w, pad_beg_h, pad_end_h]
        elif padding == "same_tensorflow":
            explicit_padding = [pad_end_w, pad_beg_w, pad_end_h, pad_beg_h]
        # NOTE: torch's conv cannot pad the two sides of a dim by different amounts,
        #       so report explicit constant (zero) padding instead.
        padding_mode = "constant"
    if all(p == 0 for p in explicit_padding):
        padding = "valid"
        padding_mode = "zeros"
    return padding, padding_mode, explicit_padding


def init_encoding(
    scale_by_group: torch.Tensor,
    scale_repeats: int,
    zero_point_by_group: torch.Tensor,
    zero_point_repeats: int,
    equalization_vector: torch.Tensor = None,
) -> Encoding:
    """Assemble an ``Encoding`` from per-group scales and zero points.

    Group values are expanded to per-channel values with ``torch.repeat_interleave``.

    Args:
        scale_by_group: Scale for each group.
        scale_repeats: How many times each group scale is repeated when expanding to channels.
        zero_point_by_group: Zero point for each group.
        zero_point_repeats: How many times each group zero point is repeated when expanding.
        equalization_vector: Optional element-wise multiplier for the expanded scale. When
            ``None``, a tensor of ones shaped like the expanded scale is used (and stored).

    Returns:
        An ``Encoding`` holding the inputs, the expanded per-channel zero point, the expanded
        per-channel scale multiplied by the equalization vector, and ``factor_by_group``
        (each group scale divided by the largest group scale).
    """
    expand = torch.repeat_interleave
    channel_scale = expand(scale_by_group, scale_repeats)
    channel_zero_point = expand(zero_point_by_group, zero_point_repeats)
    # Relative size of each group scale; NOTE: no guard, so a zero max scale divides by zero.
    group_factor = scale_by_group / scale_by_group.max()

    if equalization_vector is None:
        equalization_vector = torch.ones_like(channel_scale)

    return Encoding(
        scale_by_group=scale_by_group,
        scale_repeats=scale_repeats,
        zero_point_by_group=zero_point_by_group,
        zero_point_repeats=zero_point_repeats,
        equalization_vector=equalization_vector,
        factor_by_group=group_factor,
        scale_by_channel=channel_scale * equalization_vector,
        zero_point_by_channel=channel_zero_point,
    )


class PaddingUtils:
    """Helpers that apply constant padding given pads in ``F.pad`` order.

    Padding tuples are ``(left, right, top, bottom)``: the first pair pads the
    last dimension and the second pair pads the one before it.
    """

    @staticmethod
    def pad_with_online_value(
        inp: torch.Tensor,
        pad_value: torch.Tensor,
        reversed_padding_repeated_twice: Tuple[int, int, int, int],
    ):
        """Pad a 4-D ``inp`` with values held in a tensor.

        ``pad_value`` is reshaped to ``(1, -1, 1, 1)`` and broadcast along dim 1.
        It may therefore hold a single value or one value per entry of dim 1
        (presumably channels of an NCHW input; confirm with callers).
        Dim 3 (left/right) is padded first, then dim 2 (top/bottom), so the
        corner cells also receive the pad value.
        """
        left, right, top, bottom = reversed_padding_repeated_twice
        fill = pad_value.view(1, -1, 1, 1)

        def slab(like, height, width):
            # Adding the broadcastable fill to zeros yields a block of pad values;
            # batch/dim-1 sizes, device and dtype are taken from the current tensor.
            blank = torch.zeros((like.size(0), like.size(1), height, width), device=like.device, dtype=like.dtype)
            return fill + blank

        if max(left, right) > 0:
            rows = inp.size(2)
            inp = torch.cat([slab(inp, rows, left), inp, slab(inp, rows, right)], dim=3)
        if max(top, bottom) > 0:
            cols = inp.size(3)
            inp = torch.cat([slab(inp, top, cols), inp, slab(inp, bottom, cols)], dim=2)
        return inp

    @staticmethod
    def pad_with_offline_value(
        inp: torch.Tensor,
        pad_value: torch.Tensor,
        reversed_padding_repeated_twice: Tuple[int, int, int, int],
    ):
        """Pad ``inp`` with a constant known ahead of time, using ``F.pad`` in constant mode."""
        return nn.functional.pad(inp, reversed_padding_repeated_twice, mode="constant", value=pad_value)

    @classmethod
    def manual_padding(
        cls,
        inp: torch.Tensor,
        pad_value: torch.Tensor,
        reversed_padding_repeated_twice: Tuple[int, int, int, int],
    ):
        """Pad ``inp`` with ``pad_value``, choosing the strategy by the pad value's type.

        Tensor pad values go through :meth:`pad_with_online_value`. Anything
        else (e.g. a Python number) goes through :meth:`pad_with_offline_value`.
        """
        pad_fn = cls.pad_with_online_value if isinstance(pad_value, torch.Tensor) else cls.pad_with_offline_value
        return pad_fn(inp, pad_value, reversed_padding_repeated_twice)
