from typing import Tuple

import torch
from einops import rearrange, repeat
from torch import Tensor

from hailo_model_optimization.saitama.framework.common.saitama_definitions import DimsInfo


class StaticOnlyMeta(type):
    """Metaclass that turns a class into a namespace of static helpers only.

    When the class body has been evaluated, every non-dunder attribute whose value is
    callable must be wrapped in ``staticmethod``; the first one that is not triggers a
    ``TypeError`` and the class is never created. Dunder names (``__doc__``,
    ``__module__``, ``__init__``, ...) are exempt.
    """

    def __init__(cls, name, bases, namespace):
        # NOTE(review): only values for which callable() is true are inspected, so
        # non-callable descriptors (e.g. property, and classmethod objects, which do
        # not define __call__) are accepted — confirm this is intended.
        offender = next(
            (
                member_name
                for member_name, member in namespace.items()
                if not member_name.startswith("__")
                and callable(member)
                and not isinstance(member, staticmethod)
            ),
            None,
        )
        if offender is not None:
            raise TypeError(f"Method '{offender}' in class '{name}' must be declared as a staticmethod.")
        super().__init__(name, bases, namespace)


class Reshaping(metaclass=StaticOnlyMeta):
    """Static einops-based reshapes used around a grouped / windowed matmul."""

    @staticmethod
    def reshape_matmul_input(
        inp: Tensor,
        groups: int,
        window: DimsInfo,
        transpose: bool = False,
    ) -> Tensor:
        """Fold groups and window splits of a rank-4 input into a rank-5 matmul operand.

        In einops notation the input is ``b (g wc c) h (ww w)``:
        - dim 1 is split into ``groups`` x ``window.channels`` x ``c``;
        - dim 3 is split into ``window.width`` x ``w``.

        The ``g``, ``wc`` and ``ww`` factors are merged into a new dim 1. The remaining
        ``c`` and ``w`` become the two trailing dims.

        Args:
            inp: Rank-4 tensor. Dim 1 must be divisible by ``groups * window.channels``
                and dim 3 by ``window.width``; otherwise einops raises.
            groups: Number of groups (``g``).
            window: Supplies ``window.channels`` (``wc``) and ``window.width`` (``ww``).
            transpose: If True, the trailing dims are ordered ``c w`` instead of ``w c``.

        Returns:
            Rank-5 tensor ``b (g wc ww) h w c``, or ``b (g wc ww) h c w`` when
            ``transpose`` is True.
        """
        # The two layouts differ only in the order of the last two axes.
        trailing = "c w" if transpose else "w c"
        return rearrange(
            inp,
            f"b (g wc c) h (ww w) -> b (g wc ww) h {trailing}",
            g=groups,
            wc=window.channels,
            ww=window.width,
        )

    @staticmethod
    def tile_inputs(
        inp_0: Tensor,
        tile_0: DimsInfo,
        inp_1: Tensor,
        tile_1: DimsInfo,
    ) -> Tuple[Tensor, Tensor]:
        """Repeat the group axis (dim 1) of both rank-5 matmul operands.

        Both inputs are in einops layout ``b g h w c``.
        - ``inp_0`` is repeated *interleaved*: each group is copied ``tile_0.channels``
          times next to itself, so groups ``[A, B]`` become ``[A, A, B, B]``.
        - ``inp_1`` is repeated *tiled*: the whole group axis is copied
          ``tile_1.channels`` times, so groups ``[X, Y]`` become ``[X, Y, X, Y]``.

        With ``tile_0.channels == groups(inp_1)`` and ``tile_1.channels == groups(inp_0)``,
        position ``k`` of the two outputs pairs every group of ``inp_0`` with every group
        of ``inp_1``.

        Returns:
            ``(tiled inp_0, tiled inp_1)``.
        """
        # Each operand is repeated by its own DimsInfo. The original code applied the
        # tile_1 repeat to inp_0 a second time and returned inp_1 untouched.
        inp_0 = repeat(inp_0, "b g h w c -> b (g r1) h w c", r1=tile_0.channels)
        inp_1 = repeat(inp_1, "b g h w c -> b (r2 g) h w c", r2=tile_1.channels)

        return inp_0, inp_1

    @staticmethod
    def reshape_matmul_output(x: Tensor, groups: int, window: DimsInfo) -> Tensor:
        """Unfold a rank-5 matmul result back to rank 4.

        This mirrors the non-transposed layout of ``reshape_matmul_input``:
        ``b (g wc ww) h w c -> b (g wc c) h (ww w)``.
        """
        return rearrange(x, "b (g wc ww) h w c -> b (g wc c) h (ww w)", g=groups, wc=window.channels, ww=window.width)

    @staticmethod
    def reshape_matmul_scale(scale: Tensor, num_groups: int, transpose: bool = False) -> Tensor:
        """Reshape a 1-D per-channel scale ``(g c)`` into a rank-5 broadcastable view.

        Returns ``1 g 1 1 c``, or ``1 g 1 c 1`` when ``transpose`` is True. In the
        transposed case the channel factor sits on dim 3, matching the ``c w``
        trailing order. Dim 0 of ``scale`` must be divisible by ``num_groups``.
        """
        if not transpose:
            scale = rearrange(scale, "(g c) -> 1 g 1 1 c", g=num_groups)
        else:
            scale = rearrange(scale, "(g c) -> 1 g 1 c 1", g=num_groups)
        return scale


class CommonFunctions(metaclass=StaticOnlyMeta):
    """Static repeat helpers for rank-4 tensors in einops layout ``b c h w``."""

    @staticmethod
    def apply_repeat_interleave(x: Tensor, repeats: DimsInfo) -> Tensor:
        """Repeat each element in place along the channel, height and width axes.

        Each entry is duplicated ``repeats.channels`` / ``repeats.height`` /
        ``repeats.width`` times immediately next to itself. For example, channels
        ``[a, b]`` with ``repeats.channels == 2`` become ``[a, a, b, b]``. This is
        ``torch.repeat_interleave`` applied to dims 1, 2 and 3.

        Returns:
            Tensor of shape ``(b, c * repeats.channels, h * repeats.height, w * repeats.width)``.
        """
        # ``(c r1)``: the original index is the outer factor, so copies are adjacent.
        return repeat(x, "b c h w -> b (c r1) (h r2) (w r3)", r1=repeats.channels, r2=repeats.height, r3=repeats.width)

    @staticmethod
    def apply_repeat_tile(x: Tensor, repeats: DimsInfo) -> Tensor:
        """Tile whole axes along the channel, height and width axes.

        Each axis is copied end-to-end ``repeats.channels`` / ``repeats.height`` /
        ``repeats.width`` times. For example, channels ``[a, b]`` with
        ``repeats.channels == 2`` become ``[a, b, a, b]``. This matches
        ``x.repeat(1, repeats.channels, repeats.height, repeats.width)``.

        Returns:
            Tensor of shape ``(b, c * repeats.channels, h * repeats.height, w * repeats.width)``.
        """
        # ``(r1 c)``: the repeat index is the outer factor, so whole axes are copied.
        return repeat(x, "b c h w -> b (r1 c) (r2 h) (r3 w)", r1=repeats.channels, r2=repeats.height, r3=repeats.width)


class Metrics(metaclass=StaticOnlyMeta):
    """Static signal-quality metrics for comparing tensors."""

    @staticmethod
    def compute_snr(reference: Tensor, reconstructed: Tensor) -> Tensor:
        """
        Calculate the SNR (in dB) between a reference tensor and a reconstructed tensor.

        SNR = 10 * log10(mean(reference ** 2) / mean((reference - reconstructed) ** 2))

        Both tensors must have the same shape. This is not enforced: torch broadcasting
        rules apply to the subtraction.

        Integer and bool inputs are promoted to a floating dtype before any arithmetic.
        Without this, unsigned subtraction could wrap around (e.g. uint8), bool
        subtraction raises, and ``torch.mean`` rejects integer dtypes. The promoted
        dtype is the common floating result type of the two inputs, or torch's default
        floating dtype if neither is floating. Floating and complex inputs are used
        unchanged.

        Returns:
            A 0-dim tensor. It is +inf when the reconstruction is exact and the reference
            is non-zero, and nan when both the signal and noise power are zero.
        """
        # Choose a floating dtype to compute in; floating/complex inputs keep their own.
        compute_dtype = torch.result_type(reference, reconstructed)
        if not (compute_dtype.is_floating_point or compute_dtype.is_complex):
            compute_dtype = torch.get_default_dtype()
        if not (reference.is_floating_point() or reference.is_complex()):
            reference = reference.to(compute_dtype)
        if not (reconstructed.is_floating_point() or reconstructed.is_complex()):
            reconstructed = reconstructed.to(compute_dtype)

        # Calculate the difference (i.e., the noise)
        noise = reference - reconstructed

        # Compute the average power of the reference signal and the noise.
        signal_power = torch.mean(reference**2)
        noise_power = torch.mean(noise**2)

        # Calculate SNR in decibels.
        snr_db = 10 * torch.log10(signal_power / noise_power)
        return snr_db
