import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from hailo_model_optimization.acceleras.utils.acceleras_definitions import ActivationType
from hailo_model_optimization.saitama.translators.translator_registry import Register

# Registry used by activation_factory to look up the module class for an ActivationType.
NATIVE_ACTIVATION_REGISTRY = Register(nn.Module)

# Activations that torch.nn already ships are registered unchanged, in this order.
for _activation_type, _module_cls in (
    (ActivationType.RELU, nn.ReLU),
    (ActivationType.RELU6, nn.ReLU6),
    (ActivationType.SIGMOID, nn.Sigmoid),
    (ActivationType.TANH, nn.Tanh),
    (ActivationType.LINEAR, nn.Identity),
    (ActivationType.SILU, nn.SiLU),
    (ActivationType.GELU, nn.GELU),
):
    NATIVE_ACTIVATION_REGISTRY.register(_activation_type)(_module_cls)
# Do not leave the loop variables behind as module attributes.
del _activation_type, _module_cls


def hailo_reciprocal(x, epsilon=1e-10):
    """Return the element-wise reciprocal of ``x`` after nudging it away from zero.

    Each element is shifted by ``epsilon`` in the direction of its own sign
    before the reciprocal is taken. Elements that are exactly zero have sign 0,
    receive no shift, and therefore still produce an infinite result.

    Args:
        x: Input tensor.
        epsilon: Magnitude of the sign-aligned offset added to ``x``.

    Returns:
        Tensor holding ``1 / (x + sign(x) * epsilon)``.
    """
    signed_offset = torch.sign(x) * epsilon
    return (x + signed_offset).reciprocal()


@NATIVE_ACTIVATION_REGISTRY(ActivationType.EXP)
class Exp(nn.Module):
    """Element-wise exponential activation."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``e ** x`` for every element of ``x``."""
        return x.exp()


@NATIVE_ACTIVATION_REGISTRY(ActivationType.BIASED_DELTA)
class BiasedDelta(nn.Module):
    """Scale a "non-zero" indicator of the input by ``activation_delta_bias``.

    The bias is a single-element, non-trainable parameter that is allocated
    uninitialised here (presumably filled in later from a state dict).
    """

    def __init__(self, dtype=None, device=None):
        super(BiasedDelta, self).__init__()
        bias_storage = torch.empty(1, dtype=dtype, device=device)
        self.activation_delta_bias = nn.Parameter(bias_storage, requires_grad=False)

    def forward(self, x):
        """Return ``activation_delta_bias * sign(|x|)``."""
        magnitude = torch.clamp(torch.abs(x), min=0.0)
        # sign(|x|) is 0 where x == 0 and 1 for any other finite value.
        nonzero_indicator = torch.sign(magnitude)
        return self.activation_delta_bias * nonzero_indicator


@NATIVE_ACTIVATION_REGISTRY(ActivationType.DELTA)
class Delta(nn.Module):
    """Indicator of zero: 1 where the input equals zero, 0 for other finite values."""

    def forward(self, x):
        """Return ``1 - sign(|x|)`` element-wise."""
        magnitude = torch.clamp(torch.abs(x), min=0.0)
        nonzero_indicator = torch.sign(magnitude)
        return 1 - nonzero_indicator


@NATIVE_ACTIVATION_REGISTRY(ActivationType.INV_POS)
class InvPos(nn.Module):
    """Reciprocal activation applied to the input scaled by ``inverse_act_factor``.

    The factor is a plain Python float (not a parameter) and is carried through
    ``state_dict`` via the extra-state hooks below.
    """

    def __init__(self, inverse_act_factor=1.0):
        super().__init__()
        self.inverse_act_factor = inverse_act_factor

    def forward(self, x):
        """Return ``hailo_reciprocal(x * inverse_act_factor)``."""
        return hailo_reciprocal(x * self.inverse_act_factor)

    def get_extra_state(self):
        """Expose the scale factor so it is saved alongside the module state."""
        extra_state = {"inverse_act_factor": self.inverse_act_factor}
        return extra_state

    def set_extra_state(self, state):
        """Restore the scale factor from ``state``, coercing it to ``float``."""
        restored_factor = state["inverse_act_factor"]
        self.inverse_act_factor = float(restored_factor)


@NATIVE_ACTIVATION_REGISTRY(ActivationType.MINUS_INV_POS)
class MinusInvPos(nn.Module):
    """Reciprocal activation applied to the negated input."""

    def forward(self, x):
        """Return ``hailo_reciprocal(-x)``."""
        negated = -x
        return hailo_reciprocal(negated)


@NATIVE_ACTIVATION_REGISTRY(ActivationType.RELU1)
class ReLU1(nn.Module):
    """ReLU capped at 1, built from ``relu6`` by scaling in and out by 6."""

    def forward(self, x):
        """Return ``relu6(6 * x) / 6``."""
        stretched = x * 6.0
        return F.relu6(stretched) / 6.0


@NATIVE_ACTIVATION_REGISTRY(ActivationType.LESS)
class Less(nn.Module):
    """Element-wise ``x < activation_less_values`` returned in the input dtype.

    The comparison value is a single-element, non-trainable parameter that is
    allocated uninitialised here.
    """

    def __init__(self, dtype=None, device=None):
        super().__init__()
        reference = torch.empty(1, dtype=dtype, device=device)
        self.activation_less_values = nn.Parameter(reference, requires_grad=False)

    def forward(self, x):
        """Return 1 where ``x`` is strictly below the reference, else 0."""
        is_below = torch.lt(x, self.activation_less_values)
        return is_below.to(x.dtype)


@NATIVE_ACTIVATION_REGISTRY(ActivationType.GREATER)
class Greater(nn.Module):
    """Element-wise ``x > activation_greater_values`` returned in the input dtype.

    The comparison value is a single-element, non-trainable parameter that is
    allocated uninitialised here.
    """

    def __init__(self, dtype=None, device=None):
        super().__init__()
        reference = torch.empty(1, dtype=dtype, device=device)
        self.activation_greater_values = nn.Parameter(reference, requires_grad=False)

    def forward(self, x):
        """Return 1 where ``x`` is strictly above the reference, else 0."""
        is_above = torch.gt(x, self.activation_greater_values)
        return is_above.to(x.dtype)


@NATIVE_ACTIVATION_REGISTRY(ActivationType.HARDSWISH)
class NativeHardSwishActivation(nn.Module):
    """Hard-swish activation: ``x * relu6(x + 3) / 6``."""

    def forward(self, x):
        """Apply hard-swish element-wise."""
        gate = F.relu6(x + 3)
        # Multiply first, then divide, to keep the original evaluation order.
        return (x * gate) / 6


@NATIVE_ACTIVATION_REGISTRY(ActivationType.SWISH)
class NativeSwishActivation(nn.Module):
    """Swish activation ``x * sigmoid(swish_beta * x)``.

    ``swish_beta`` is a single-element, non-trainable parameter that is
    allocated uninitialised here.
    """

    def __init__(self, dtype=None, device=None):
        super().__init__()
        beta_storage = torch.empty(1, dtype=dtype, device=device)
        self.swish_beta = nn.Parameter(beta_storage, requires_grad=False)

    def forward(self, x):
        """Apply swish element-wise."""
        gate = torch.sigmoid(self.swish_beta * x)
        return x * gate


@NATIVE_ACTIVATION_REGISTRY(ActivationType.HARDSIGMOID)
class NativeHardSigmoidActivation(nn.Module):
    """Hard-sigmoid: ``clamp(hardsigmoid_alpha * x + hardsigmoid_beta, 0, 1)``.

    Both coefficients are single-element, non-trainable parameters that are
    allocated uninitialised here.
    """

    def __init__(self, dtype=None, device=None):
        super().__init__()
        factory_kwargs = {"dtype": dtype, "device": device}
        self.hardsigmoid_alpha = nn.Parameter(torch.empty(1, **factory_kwargs), requires_grad=False)
        self.hardsigmoid_beta = nn.Parameter(torch.empty(1, **factory_kwargs), requires_grad=False)

    def forward(self, x):
        """Apply the affine map and clamp the result to ``[0, 1]``."""
        affine = self.hardsigmoid_alpha * x + self.hardsigmoid_beta
        return affine.clamp(min=0.0, max=1.0)


@NATIVE_ACTIVATION_REGISTRY(ActivationType.CLIP)
class NativeClipActivation(nn.Module):
    """Clamp the input to the range ``[clip_min, clip_max]``.

    Both bounds are single-element, non-trainable parameters that are
    allocated uninitialised here.
    """

    def __init__(self, dtype=None, device=None):
        super().__init__()
        factory_kwargs = {"dtype": dtype, "device": device}
        self.clip_min = nn.Parameter(torch.empty(1, **factory_kwargs), requires_grad=False)
        self.clip_max = nn.Parameter(torch.empty(1, **factory_kwargs), requires_grad=False)

    def forward(self, x):
        """Return ``x`` clamped between the two bound parameters."""
        lower, upper = self.clip_min, self.clip_max
        return torch.clamp(x, min=lower, max=upper)


@NATIVE_ACTIVATION_REGISTRY(ActivationType.LEAKY)
class LeakyReLU(nn.Module):
    """Leaky ReLU whose negative slope ``leaky_alpha`` is a plain Python float.

    The slope starts at 0.0 and is carried through ``state_dict`` via the
    extra-state hooks. ``dtype`` and ``device`` are accepted but unused.
    """

    def __init__(self, dtype=None, device=None):
        super().__init__()
        self.leaky_alpha = 0.0

    def forward(self, x):
        """Apply ``leaky_relu`` with the stored negative slope."""
        slope = self.leaky_alpha
        return F.leaky_relu(x, negative_slope=slope)

    def get_extra_state(self):
        """Expose the slope so it is saved alongside the module state."""
        extra_state = {"leaky_alpha": self.leaky_alpha}
        return extra_state

    def set_extra_state(self, state):
        """Restore the slope from ``state``, coercing it to ``float``."""
        restored_slope = state["leaky_alpha"]
        self.leaky_alpha = float(restored_slope)


@NATIVE_ACTIVATION_REGISTRY(ActivationType.INV_SQRT)
class InvSqrt(nn.Module):
    """Inverse square root of the input scaled by ``inverse_act_factor``.

    The factor is a plain Python float carried through ``state_dict`` via the
    extra-state hooks. ``dtype`` and ``device`` are accepted but unused.
    """

    def __init__(self, inverse_act_factor=1.0, dtype=None, device=None):
        super().__init__()
        self.inverse_act_factor = inverse_act_factor

    def forward(self, x):
        """Return ``hailo_reciprocal(sqrt(max(inverse_act_factor * x, 1e-6)))``."""
        scaled = self.inverse_act_factor * x
        # Flooring at 1e-6 keeps the square root away from zero and negatives.
        floored = torch.clamp(scaled, min=1e-6)
        return hailo_reciprocal(torch.sqrt(floored))

    def get_extra_state(self):
        """Expose the scale factor so it is saved alongside the module state."""
        extra_state = {"inverse_act_factor": self.inverse_act_factor}
        return extra_state

    def set_extra_state(self, state):
        """Restore the scale factor from ``state``, coercing it to ``float``."""
        restored_factor = state["inverse_act_factor"]
        self.inverse_act_factor = float(restored_factor)


@NATIVE_ACTIVATION_REGISTRY(ActivationType.PRELU)
class PReLU(nn.Module):
    """Parametric ReLU: identity for positive inputs, ``prelu_slope * x`` for negative ones.

    ``prelu_slope`` is a single-element, non-trainable parameter that is
    allocated uninitialised here.
    """

    def __init__(self, dtype=None, device=None):
        super().__init__()
        slope_storage = torch.empty(1, dtype=dtype, device=device)
        self.prelu_slope = nn.Parameter(slope_storage, requires_grad=False)

    def forward(self, x):
        """Return ``relu(x) - prelu_slope * relu(-x)``."""
        positive_part = F.relu(x)
        negative_magnitude = F.relu(-x)
        scaled_negative = -self.prelu_slope * negative_magnitude
        return positive_part + scaled_negative


@NATIVE_ACTIVATION_REGISTRY(ActivationType.THRESHOLD)
class Threshold(nn.Module):
    """Pass values above ``activation_threshold`` through and replace the rest with 0.

    The threshold is a plain Python float (initially 1.0) carried through
    ``state_dict`` via the extra-state hooks. ``dtype`` and ``device`` are
    accepted but unused.
    """

    def __init__(self, dtype=None, device=None):
        super(Threshold, self).__init__()
        self.activation_threshold = 1.0

    def forward(self, x):
        """Return ``x`` where ``x > activation_threshold`` and 0 elsewhere.

        Args:
            x: Input tensor.

        Returns:
            Tensor produced by ``F.threshold`` with a replacement value of 0.
        """
        # F.threshold takes the replacement value as a plain float; passing the
        # literal avoids allocating a throw-away tensor on every forward call.
        return F.threshold(x, self.activation_threshold, 0.0)

    def get_extra_state(self):
        """Expose the threshold so it is saved alongside the module state."""
        return {"activation_threshold": self.activation_threshold}

    def set_extra_state(self, state):
        """Restore the threshold from ``state``, coercing it to ``float``."""
        self.activation_threshold = float(state["activation_threshold"])


@NATIVE_ACTIVATION_REGISTRY(ActivationType.MISH)
class Mish(nn.Module):
    """Mish activation: ``x * tanh(softplus(x))``."""

    def forward(self, x):
        """Apply Mish element-wise."""
        gate = torch.tanh(F.softplus(x))
        return x * gate


@NATIVE_ACTIVATION_REGISTRY(ActivationType.POW)
class Pow(nn.Module):
    """Raise the input to the power ``pow_exponent``.

    ``pow_exponent`` is a single-element, non-trainable parameter that is
    allocated uninitialised here.
    """

    def __init__(self, dtype=None, device=None):
        super().__init__()
        exponent_storage = torch.empty(1, dtype=dtype, device=device)
        self.pow_exponent = nn.Parameter(exponent_storage, requires_grad=False)

    def forward(self, x):
        """Return ``x ** pow_exponent`` element-wise."""
        exponent = self.pow_exponent
        return x.pow(exponent)


@NATIVE_ACTIVATION_REGISTRY(ActivationType.HDR_COMPRESSION)
class HDRCompression(nn.Module):
    """Four-segment piecewise-linear compression with knees at 2**14, 2**15 and 2**18.

    Segments (the two branches agree at every knee):
        x < 2**14           -> x / 2
        2**14 <= x < 2**15  -> x / 4 + 4096
        2**15 <= x < 2**18  -> x / 8 + 8192
        x >= 2**18          -> x / 32 + 32768
    Elements for which none of the comparisons holds keep the initial value 0.
    """

    def forward(self, x):
        """Apply the compression curve element-wise."""
        knee_low, knee_mid, knee_high = 2**14, 2**15, 2**18
        segments = (
            (x < knee_low, x / 2),
            ((x >= knee_low) & (x < knee_mid), x / 4 + 4096),
            ((x >= knee_mid) & (x < knee_high), x / 8 + 8192),
            (x >= knee_high, x / 32 + 32768),
        )
        compressed = torch.zeros_like(x)
        for in_segment, segment_value in segments:
            compressed = torch.where(in_segment, segment_value, compressed)
        return compressed


@NATIVE_ACTIVATION_REGISTRY(ActivationType.RELU_POSITIVE_SQUARE)
class ReLUPositiveSquare(nn.Module):
    """Square of the rectified input: ``relu(x) ** 2``."""

    def forward(self, x):
        """Rectify ``x`` and square the result element-wise."""
        rectified = F.relu(x)
        return rectified.square()


@NATIVE_ACTIVATION_REGISTRY(ActivationType.PWL)
class PWL(nn.Module):
    """Piecewise-linear activation ``y = slope_i * x + offset_i``.

    The ``thresholds`` split the real line into ``len(thresholds) + 1``
    segments: ``x < thresholds[0]`` is segment 0,
    ``thresholds[i] <= x < thresholds[i + 1]`` is segment ``i + 1`` and
    ``x >= thresholds[-1]`` is the last one. ``slopes`` and ``offsets`` supply
    one value per segment, or a scalar that is broadcast to all segments.

    Args:
        thresholds: 1-D sequence/array/tensor of segment boundaries.
        offsets: Per-segment offsets, or a scalar.
        slopes: Per-segment slopes, or a scalar.
    """

    def __init__(self, thresholds, offsets, slopes):
        super(PWL, self).__init__()
        # Buffers (rather than plain attributes) follow the module through
        # .to()/.cuda()/.cpu(), so forward() never mixes devices.
        # persistent=False keeps them out of state_dict, exactly as before.
        self.register_buffer("thresholds", self._as_tensor(thresholds), persistent=False)
        self.register_buffer("offsets", self._as_tensor(offsets), persistent=False)
        self.register_buffer("slopes", self._as_tensor(slopes), persistent=False)

    @staticmethod
    def _as_tensor(values):
        """Return an independent tensor copy of ``values``."""
        if isinstance(values, torch.Tensor):
            # torch.tensor(tensor) emits a UserWarning; clone explicitly instead.
            return values.detach().clone()
        return torch.tensor(values)

    def forward(self, x):
        """Evaluate the piecewise-linear function element-wise on ``x``."""
        # Trailing axis of size 1 so x broadcasts against the threshold vector.
        inputs_expanded = x.unsqueeze(-1)
        # One-hot (along the last axis) selector of the segment each element is in.
        mask = torch.cat(
            [
                (inputs_expanded < self.thresholds[0]).to(torch.float16),
                ((inputs_expanded >= self.thresholds[:-1]) & (inputs_expanded < self.thresholds[1:])).to(torch.float16),
                (inputs_expanded >= self.thresholds[-1]).to(torch.float16),
            ],
            dim=-1,
        )
        # Summing the masked coefficients picks out the active segment's value.
        slopes = torch.sum(mask * self.slopes, dim=-1)
        offsets = torch.sum(mask * self.offsets, dim=-1)
        return x * slopes + offsets


@NATIVE_ACTIVATION_REGISTRY(ActivationType.EXP_DECOMPOSE)
class ExpDecompose(nn.Module):
    """Map the input to ``6 * 2 ** floor(log2(|x|))``, optionally via a step table.

    Without ``mask`` the closed-form expression is evaluated directly. With
    ``mask`` a zero-slope ``PWL`` (a step function) is built whose thresholds
    are ``mask[1:-1]`` and whose levels are three times the mirrored upper
    half of the mask.

    Args:
        mask: Optional 1-D NumPy array of breakpoints (assumed ordered and
            sign-symmetric, judging by the mirroring below; confirm with the
            caller). ``None`` selects the closed form.
    """

    def __init__(self, mask=None):
        super(ExpDecompose, self).__init__()

        if mask is not None:
            # Element-wise max of the mask and its negated reverse, upper half only.
            larger_mask = np.maximum(mask, -mask[::-1])[len(mask) // 2 :]
            threshold = mask[1:-1]
            # The pieces are NumPy arrays, so they must be joined with
            # np.concatenate: torch.cat only accepts tensors and raised
            # TypeError here. The result is also a fresh, positively-strided
            # array that torch.tensor() can consume.
            offsets = 3 * np.concatenate([larger_mask[::-1], larger_mask[1:]], axis=-1)
            slopes = 0.0
            self.pwl = PWL(thresholds=threshold, offsets=offsets, slopes=slopes)
        else:
            self.pwl = lambda x: 6.0 * (2.0 ** torch.floor(torch.log2(torch.abs(x))))

    def forward(self, x):
        """Apply the configured mapping (step table or closed form) to ``x``."""
        return self.pwl(x)


@NATIVE_ACTIVATION_REGISTRY(ActivationType.SHIFT)
class Shift(nn.Module):
    """Map the input to ``2 * x / 2 ** floor(log2(|x|)) - 3``, optionally via a ``PWL`` table.

    Without ``mask`` the closed-form expression is evaluated directly. With
    ``mask`` a ``PWL`` is built with a constant offset of -3, per-segment
    slopes ``1 / larger_mask`` and thresholds ``4 * larger_mask[:-1]`` pulled
    down by a small margin (``larger_mask[-1] / 2**15``).
    """

    def __init__(self, mask=None):
        super().__init__()
        if mask is None:

            def closed_form(x):
                exponent = torch.floor(torch.log2(torch.abs(x)))
                return 2.0 * x / (2.0**exponent) - 3.0

            self.pwl = closed_form
            return

        # Element-wise max of the mask and its negated reverse, upper half only.
        upper_half = np.maximum(mask, -mask[::-1])[len(mask) // 2 :]
        margin = upper_half[-1] / 2**15
        self.pwl = PWL(
            thresholds=4 * upper_half[:-1] - margin,
            offsets=-3.0,
            slopes=1.0 / upper_half,
        )

    def forward(self, x):
        """Apply the configured mapping (``PWL`` table or closed form) to ``x``."""
        return self.pwl(x)


def activation_factory(activation: ActivationType, **kwargs):
    """Instantiate the module registered for ``activation``.

    Args:
        activation: Key to look up in ``NATIVE_ACTIVATION_REGISTRY``.
        **kwargs: Forwarded unchanged to the registered class's constructor.

    Returns:
        A new instance of the registered class.

    Raises:
        NotImplementedError: If ``activation`` is not in the registry.
    """
    if activation not in NATIVE_ACTIVATION_REGISTRY:
        raise NotImplementedError(f"Activation type {activation} is not supported")
    module_cls = NATIVE_ACTIVATION_REGISTRY[activation]
    return module_cls(**kwargs)
