from typing import Tuple

import numpy as np
import torch
import torch.nn as nn

from hailo_model_optimization.acceleras.utils.acceleras_definitions import FeatureMultiplierType
from hailo_model_optimization.saitama.framework.common.commom_funtions import CommonFunctions, Reshaping
from hailo_model_optimization.saitama.framework.common.saitama_definitions import DimsInfo
from hailo_model_optimization.saitama.framework.fused_modules.fused_base import FusedBase


class EWMult(nn.Module):
    """Element-wise product of two inputs.

    Before multiplying, each operand is passed through
    ``CommonFunctions.apply_repeat_interleave`` with its own entry of
    ``input_repeats`` (index 0 for ``x``, index 1 for ``y``).
    """

    def __init__(self, input_repeats: Tuple[DimsInfo, DimsInfo] = None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if input_repeats is None:
            # Fall back to a default-constructed DimsInfo per operand
            # (presumably "no repetition" — confirm against DimsInfo).
            input_repeats = (DimsInfo(), DimsInfo())
        self.input_repeats = input_repeats

    def forward(self, x, y):
        lhs_repeat = self.input_repeats[0]
        rhs_repeat = self.input_repeats[1]
        lhs = CommonFunctions.apply_repeat_interleave(x, lhs_repeat)
        rhs = CommonFunctions.apply_repeat_interleave(y, rhs_repeat)
        return lhs * rhs


class EWSub(nn.Module):
    """Element-wise difference ``x - y`` of two inputs.

    Each operand is first expanded with ``CommonFunctions.apply_repeat_interleave``
    using its matching entry of ``input_repeats`` (0 -> ``x``, 1 -> ``y``).
    """

    def __init__(
        self,
        input_repeats: Tuple[DimsInfo, DimsInfo] = None,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # A default-constructed DimsInfo pair is used when no repeats are given.
        self.input_repeats = (DimsInfo(), DimsInfo()) if input_repeats is None else input_repeats

    def forward(self, x, y):
        repeats = self.input_repeats
        minuend = CommonFunctions.apply_repeat_interleave(x, repeats[0])
        subtrahend = CommonFunctions.apply_repeat_interleave(y, repeats[1])
        return minuend - subtrahend


def apply_reduce_op(x, op, axes, groups):
    """Reduce ``x`` within groups along each axis listed in ``axes``.

    For every ``axis``, that dimension is split into ``groups[axis - 1]`` groups
    of ``x.shape[axis] // groups[axis - 1]`` consecutive elements. ``op`` then
    reduces the elements *inside* each group. The group dimension survives, so
    after the step ``axis`` has size ``groups[axis - 1]`` and the tensor rank is
    unchanged. That keeps later entries of ``axes`` pointing at the same dims.

    Args:
        x: Input tensor.
        op: Reduction callable accepting ``axis=`` and ``keepdim=`` (e.g.
            ``torch.sum``, ``torch.mean``, ``torch.max``). Ops that return a
            ``(values, indices)`` pair, such as ``torch.max``, contribute only
            their values.
        axes: Axis indices to reduce, processed in order. They are assumed to be
            positive: the ``axis - 1`` / ``axis + 1`` arithmetic below does not
            handle negative axes.
        groups: Number of groups per axis, indexed by ``axis - 1`` (so
            ``groups[0]`` applies to axis 1).

    Returns:
        The reduced tensor.
    """
    for axis in axes:
        num_groups = groups[axis - 1]
        x = x.unflatten(axis, (num_groups, x.shape[axis] // num_groups))
        # Always reduce the within-group dimension (axis + 1), for every op.
        # Reducing ``axis`` would collapse the group dimension instead, which
        # turns a single-group sum/mean into a no-op.
        reduced = op(x, axis=axis + 1, keepdim=False)
        if not isinstance(reduced, torch.Tensor):
            # torch.max and similar return a (values, indices) named tuple.
            reduced = reduced[0]
        x = reduced
    return x


class ReduceMax(nn.Module):
    """Grouped max-reduction over ``axes``, delegated to ``apply_reduce_op``.

    ``groups`` gives the number of groups per axis, indexed by ``axis - 1``.
    """

    def __init__(self, axes=None, groups=None):
        super().__init__()
        # Stored under the singular name ``axis``; empty/None falls back to [1].
        self.axis = axes or [1]
        self.groups = groups or [1, 1, 1]

    def forward(self, x):
        reduce_axes = self.axis
        group_counts = self.groups
        return apply_reduce_op(x, torch.max, reduce_axes, group_counts)


class ReduceSum(nn.Module):
    """Grouped sum-reduction over ``axes`` (via ``apply_reduce_op``), scaled by sign(kernel)."""

    def __init__(self, axes=None, groups=None):
        super().__init__()
        self.axes = axes or [1]
        self.groups = groups or [1, 1, 1]
        # Scalar parameter initialised to 1.0; forward only uses its sign.
        self.kernel = nn.Parameter(torch.tensor(1.0))

    def forward(self, x):
        scale = torch.sign(self.kernel)
        summed = apply_reduce_op(x, torch.sum, self.axes, self.groups)
        return scale * summed


class ReduceMean(nn.Module):
    """Grouped mean-reduction over ``axes`` (via ``apply_reduce_op``), scaled by sign(kernel)."""

    def __init__(self, axes=None, groups=None):
        super().__init__()
        self.axes = axes or [1]
        self.groups = groups or [1, 1, 1]
        # Scalar parameter initialised to 1.0; forward only uses its sign.
        self.kernel = nn.Parameter(torch.tensor(1.0))

    def forward(self, x):
        scale = torch.sign(self.kernel)
        averaged = apply_reduce_op(x, torch.mean, self.axes, self.groups)
        return scale * averaged


class FeatureMultiplier(nn.Module):
    """Feature multiplier layer; only ``FeatureMultiplierType.square`` is supported.

    The input is squared element-wise. When ``reduce_sum_groups`` differs from
    the size of dim 1, dim 1 is split into ``reduce_sum_groups`` groups of
    consecutive channels, and the squared values are summed within each group.
    """

    # Class-level annotation only; this class never assigns or reads ``weight``.
    weight: torch.Tensor

    def __init__(self, feature_multiplier_type, reduce_sum_groups):
        super().__init__()
        self.feature_multiplier_type = feature_multiplier_type
        self.reduce_sum_groups = reduce_sum_groups

    def forward(self, x):
        """Square ``x`` and optionally sum the squares per channel group.

        Raises:
            ValueError: if ``feature_multiplier_type`` is not
                ``FeatureMultiplierType.square``.
        """
        if self.feature_multiplier_type != FeatureMultiplierType.square:
            raise ValueError(f"Feature multiplier type {self.feature_multiplier_type} is not supported in saitama")
        squared = x**2
        groups = self.reduce_sum_groups
        if groups == squared.shape[1]:
            # One channel per group: the per-group sum is the identity.
            return squared
        grouped = squared.unflatten(1, (groups, squared.shape[1] // groups))
        # Sum the already-squared values. Squaring again here would yield a sum of x**4.
        return torch.sum(grouped, axis=2, keepdim=False)


class Reshape(nn.Module):
    """Reshape the input to the fixed target ``shape`` (``-1`` entries are inferred)."""

    def __init__(self, shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        # ``reshape`` returns a view whenever ``view`` would succeed and copies
        # otherwise. Non-contiguous inputs (e.g. after ``permute``) are therefore
        # handled instead of raising a RuntimeError.
        return x.reshape(*self.shape)

    def extra_repr(self):
        return f"shape={self.shape}"


class GroupedSoftmax(nn.Module):
    """Apply softmax independently to each of ``groups`` channel groups.

    Dim 1 is split into ``groups`` blocks of ``C // groups`` consecutive
    channels. Softmax along ``axis`` is evaluated on each block separately;
    each block keeps the input's layout except that dim 1 is ``C // groups``.
    The results are concatenated back along dim 1.
    """

    def __init__(self, axis=-1, groups=1):
        super().__init__()
        self.axis = axis
        self.groups = groups

    def forward(self, x: torch.Tensor):
        channels_per_group = x.shape[1] // self.groups
        grouped = x.view(x.shape[0], self.groups, channels_per_group, *x.shape[2:])
        # unbind(1) yields one (N, C // groups, ...) slice per group.
        normalized = [torch.softmax(block, dim=self.axis) for block in grouped.unbind(1)]
        return torch.cat(normalized, dim=1)

    def extra_repr(self):
        return f"axis={self.axis}, groups={self.groups}"


class ResizeNN(nn.Module):
    """Nearest-neighbour upsampling by integer factors using ``repeat_interleave``.

    Each factor may be a scalar or a sequence; a sequence is multiplied into a
    single integer factor. ``forward`` repeats dim 1 by ``f``, dim 2 by ``h``
    and dim 3 by ``w``.
    """

    def __init__(self, h=1, w=1, f=1):
        super().__init__()
        # np.prod, not np.product: the latter was deprecated in NumPy 1.25 and
        # removed in NumPy 2.0.
        self.resize_h = int(np.prod(h))
        self.resize_w = int(np.prod(w))
        self.resize_f = int(np.prod(f))

    def forward(self, x: torch.Tensor):
        x = x.repeat_interleave(self.resize_f, dim=1)
        x = x.repeat_interleave(self.resize_h, dim=2)
        x = x.repeat_interleave(self.resize_w, dim=3)
        return x

    def extra_repr(self):
        return f"h={self.resize_h}, w={self.resize_w}, f={self.resize_f}"


class MatMul(nn.Module):
    """Matrix multiply two inputs, with grouping/windowing reshapes applied via ``Reshaping``.

    Both operands go through ``Reshaping.reshape_matmul_input``; only ``y``
    honours ``transposed``. They are then tiled with ``Reshaping.tile_inputs``
    and multiplied with ``torch.matmul``. The product is reshaped back with
    ``Reshaping.reshape_matmul_output``.
    """

    def __init__(self, groups, windows: DimsInfo, input_tiles, transposed):
        super().__init__()
        self.groups = groups
        self.windows = windows
        # Stored under the singular name ``input_tile``; index 0 tiles x, index 1 tiles y.
        self.input_tile = input_tiles
        self.transposed = transposed

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        groups, windows = self.groups, self.windows
        lhs = Reshaping.reshape_matmul_input(x, groups, windows, False)
        rhs = Reshaping.reshape_matmul_input(y, groups, windows, self.transposed)
        lhs, rhs = Reshaping.tile_inputs(lhs, self.input_tile[0], rhs, self.input_tile[1])
        product = torch.matmul(lhs, rhs)
        return Reshaping.reshape_matmul_output(product, groups, windows)


class EWAdd(nn.Module):
    """Element-wise sum ``x + y`` of two inputs.

    Each operand is first expanded with ``CommonFunctions.apply_repeat_interleave``
    using its matching entry of ``input_repeats`` (0 -> ``x``, 1 -> ``y``).
    """

    def __init__(self, input_repeats: Tuple[DimsInfo, DimsInfo] = None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if input_repeats is None:
            input_repeats = (DimsInfo(), DimsInfo())
        self.input_repeats = input_repeats

    def forward(self, x, y):
        first_repeat, second_repeat = self.input_repeats[0], self.input_repeats[1]
        first = CommonFunctions.apply_repeat_interleave(x, first_repeat)
        second = CommonFunctions.apply_repeat_interleave(y, second_repeat)
        return first + second


class ConstInput(nn.Module):
    """Learnable tensor (zero-initialised) returned tiled, with a leading dim of size 1.

    ``forward`` repeats ``value`` by ``tile.as_chw_tuple()`` and prepends a
    batch-like dimension via ``unsqueeze(0)``.
    """

    def __init__(self, shape, tile: DimsInfo, dtype=None, device=None):
        super().__init__()
        self.tile = tile
        initial = torch.zeros(shape, dtype=dtype, device=device)
        self.value = nn.Parameter(initial)

    def forward(self):
        repeat_counts = self.tile.as_chw_tuple()
        tiled = self.value.repeat(repeat_counts)
        return tiled.unsqueeze(0)


class FusedSoftmax(FusedBase):
    """Softmax along a fixed ``axis``; extra keyword arguments to ``forward`` are ignored."""

    def __init__(self, axis):
        super().__init__()
        self.axis = axis
        self.softmax = nn.Softmax(dim=axis)

    def forward(self, x, **kwargs):
        softmax_fn = self.softmax
        return softmax_fn(x)


class FusedLayerNormalization(FusedBase):
    """LayerNorm over dim 1 of a 4-D input.

    ``nn.LayerNorm`` normalises trailing dims, so dim 1 is moved last,
    normalised, then moved back. Extra keyword arguments to ``forward`` are
    ignored.
    """

    def __init__(self, channels):
        super().__init__()
        self.layer_normalization = torch.nn.LayerNorm(channels)

    def forward(self, x, **kwargs):
        channels_last = x.permute(0, 2, 3, 1)
        normalized = self.layer_normalization(channels_last)
        return normalized.permute(0, 3, 1, 2)


class FusedRMSNormalization(FusedBase):
    """RMSNorm over dim 1 of a 4-D input.

    Dim 1 is moved last for ``nn.RMSNorm``, normalised, then moved back. Extra
    keyword arguments to ``forward`` are ignored.
    """

    def __init__(self, channels):
        super().__init__()
        # Kept under the attribute name ``layer_normalization``.
        self.layer_normalization = nn.RMSNorm([channels])

    def forward(self, x, **kwargs):
        channels_last = x.permute(0, 2, 3, 1)
        normalized = self.layer_normalization(channels_last)
        return normalized.permute(0, 3, 1, 2)


class FusedGroupedNormalization(FusedBase):
    """``nn.GroupNorm`` with ``groups`` groups over ``channels`` channels.

    Extra keyword arguments to ``forward`` are ignored.
    """

    def __init__(self, groups, channels):
        super().__init__()
        self.groups = groups
        self.group_normalization = nn.GroupNorm(groups, channels)

    def forward(self, x, **kwargs):
        norm = self.group_normalization
        return norm(x)


class FusedConvAndAdd(nn.Conv2d):
    """``nn.Conv2d`` whose ``forward`` adds a second tensor to the convolution output.

    All constructor arguments are forwarded to ``nn.Conv2d``. Any extra
    positional ``args`` follow ``kernel_size``, exactly as before.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
        *args,
        **kwargs,
    ):
        conv_options = dict(
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )
        super().__init__(in_channels, out_channels, kernel_size, *args, **conv_options, **kwargs)

    def forward(self, x, y):
        conv_out = super().forward(x)
        return conv_out + y
