import torch
import torch.nn as nn

from hailo_model_optimization.acceleras.utils.acceleras_definitions import APU_MANTISSA_BITS, ActivationType
from hailo_model_optimization.saitama.framework.apu_modules.apu_activation import APUActivation
from hailo_model_optimization.saitama.framework.common.saitama_definitions import (
    APUPrecisionConfig,
    MACPrecisionConfig,
    QType,
)
from hailo_model_optimization.saitama.framework.mac_modules.mac_conv2d import MACConv2d
from hailo_model_optimization.saitama.translators.base_translator import BaseLayerTranslator
from hailo_model_optimization.saitama.translators.translator_registry import Register


class TorchModuleTranslator(BaseLayerTranslator):
    """Base translator for plain ``torch.nn`` modules.

    Supplies a forward-hook based statistics collector that records, on the
    module itself, the running output range (``min_range`` / ``max_range``)
    and the shape of the most recent output (``output_shape``). Concrete
    translators read these attributes to derive quantization scales.
    """

    @classmethod
    def enable_stats_hook(cls, module: nn.Module):
        """Attach output-range statistics collection to ``module``.

        Both bounds start at zero, so the tracked range always contains 0.

        Args:
            module: Module whose outputs should be observed.

        Returns:
            The handle returned by ``register_forward_hook``; call
            ``.remove()`` on it to detach the hook.
        """
        module.max_range = torch.zeros(1)
        module.min_range = torch.zeros(1)
        return module.register_forward_hook(cls.hook_output_range)

    @staticmethod
    def hook_output_range(module: nn.Module, input: tuple, output: torch.Tensor):
        """Forward hook that widens the module's recorded output range.

        Args:
            module: The module the hook is registered on.
            input: Tuple of positional inputs of the forward call (unused).
            output: The module's output tensor.
        """
        max_ = torch.max(output).detach()
        min_ = torch.min(output).detach()
        # The running stats are created as CPU tensors; align them with the
        # output's device so calibration also works off-CPU.
        module.max_range = torch.maximum(module.max_range.to(max_.device), max_)
        # The lower bound must track the running *minimum* (translators rely on
        # a non-positive min_range, e.g. ``torch.abs(layer.min_range)``).
        module.min_range = torch.minimum(module.min_range.to(min_.device), min_)
        module.output_shape = output.shape


# region Register Factories
# Registry of torch.nn module translators, built on the TorchModuleTranslator
# base; applied as ``@TORCH_MODULES_REGISTRY(<nn module type>)`` on translator classes.
TORCH_MODULES_REGISTRY: Register = Register(TorchModuleTranslator)


@TORCH_MODULES_REGISTRY(nn.Linear)
class LinearTranslator(TorchModuleTranslator):
    """Translate ``nn.Linear`` into a 1x1 ``MACConv2d`` followed by a linear ``APUActivation``.

    The layer must have been calibrated with ``enable_stats_hook`` so that
    ``min_range``, ``max_range`` and ``output_shape`` are available on it.
    """

    @classmethod
    def translate(cls, layer: nn.Linear, dtype=None, device=None):
        """Build the quantized replacement for ``layer``.

        Args:
            layer: The calibrated ``nn.Linear`` to translate.
            dtype: Optional dtype forwarded to the new modules.
            device: Optional device forwarded to the new modules.

        Returns:
            ``nn.Sequential(MACConv2d, APUActivation)`` with a ``weight``
            attribute set from the new kernel's ``get_weight()``.

        Raises:
            AttributeError: If output statistics were never collected.
            ValueError: If the collected output range is empty (max == min),
                which would otherwise yield inf/NaN scales.
        """
        for attr in ("min_range", "max_range", "output_shape"):
            if not hasattr(layer, attr):
                raise AttributeError(
                    f"{type(layer).__name__} has no '{attr}'; enable the stats hook and run calibration first"
                )
        if torch.any(layer.max_range <= layer.min_range):
            raise ValueError(f"{type(layer).__name__} has an empty output range; cannot derive quantization scales")

        has_bias = layer.bias is not None
        input_qtype = QType(8, False)
        weight_qtype = QType(8, True)
        accumulator_qtype = QType(16, True)
        output_dtype = QType(8, False)
        mac_cfg = MACPrecisionConfig(input_qtype, weight_qtype, accumulator_qtype, "double_scale_initialization", 1)
        apu_cfg = APUPrecisionConfig(accumulator_qtype, output_dtype, 1)

        # A linear layer is expressed as a 1x1 convolution.
        new_linear = MACConv2d(
            layer.in_features,
            layer.out_features,
            1,
            bias=True,
            precision_config=mac_cfg,
            dtype=dtype,
            device=device,
        )

        # NOTE(review): the leading positional argument (1) mirrors the call in
        # Conv2dTranslator; it appears to be the activation segment count
        # (the slopes tensors below are (1, 1)) - confirm against APUActivation.
        activation = APUActivation(1, layer.out_features, ActivationType.LINEAR, apu_cfg, dtype=dtype, device=device)

        with torch.no_grad():
            # Symmetric accumulator scale covering the largest observed magnitude.
            maxval = torch.maximum(torch.abs(layer.min_range), layer.max_range)
            accumulator_scale = maxval / 2 ** (accumulator_qtype.bits - accumulator_qtype.signed)
            kernel_scale = layer.weight.abs().max() / 2 ** (weight_qtype.bits - weight_qtype.signed)
            eq_vec_in = torch.ones(layer.in_features)
            eq_vec_out = torch.ones(layer.out_features)

            state_dict = {
                "kernel.mac_shift": torch.ones(1),
                "accumulator_quantizer.scale": torch.repeat_interleave(accumulator_scale, layer.out_features),
                "accumulator_quantizer.zero_point": torch.zeros(layer.out_features),
                # (out, in) -> (out, in, 1, 1) to fit the 1x1 convolution kernel.
                "kernel.weight": layer.weight.view(layer.out_features, layer.in_features, 1, 1),
                "kernel.scale": kernel_scale.view(1),
                "kernel.equalization_vector_in": eq_vec_in,
                "kernel.equalization_vector_out": eq_vec_out,
                "kernel.zero_point": torch.zeros(1),
            }

            # The MAC always has a bias; a bias-less layer gets zeros.
            bias_dict = {
                "bias.weight": layer.bias if has_bias else torch.zeros(layer.out_features),
                "bias.scale": torch.repeat_interleave(accumulator_scale, layer.out_features),
                "bias.zero_point": torch.zeros(layer.out_features),
            }
            state_dict.update(bias_dict)

            # Asymmetric output quantization over the observed [min, max] range.
            output_scale = (layer.max_range - layer.min_range) / 2 ** (output_dtype.bits - output_dtype.signed)
            output_zero_point = layer.min_range / output_scale
            # Rescale ratio accumulator -> output, split as m * 2**e with m in [0.5, 1).
            m, e = torch.frexp(accumulator_scale / output_scale)
            m_bits = APU_MANTISSA_BITS
            e_bias = activation.second_bias + activation.first_bias - m_bits
            act_state_dict = {
                "output_quantizer.scale": torch.repeat_interleave(output_scale, layer.out_features),
                "output_quantizer.zero_point": torch.repeat_interleave(output_zero_point, layer.out_features),
                "slopes.mantissas": (m * 2**m_bits).view(1, 1),
                "slopes.exponents": (torch.abs(e) + e_bias).view(1, 1),
                # detach().clone() copies a tensor without torch.tensor()'s copy-construct warning.
                "offsets.scale": (output_scale / 2**activation.second_bias).detach().clone(),
                "offsets.zero_point": (output_zero_point * 2**activation.second_bias).detach().clone(),
                "offsets.weight": torch.zeros(1),
                "offsets.final_shift": torch.tensor([3]),
                # NOTE: not sure about these keys
                "thresholds.weight": torch.zeros(0),
                "thresholds.scale": accumulator_scale,
                "thresholds.zero_point": torch.zeros(1),
                "slopes.weight": torch.ones(1, 1),
                "slopes.scale": accumulator_scale / output_scale,
                "slopes.zero_point": torch.zeros(1),
                "slopes.q_groups_factor": torch.ones(1),
            }

        new_linear.load_state_dict(state_dict)
        # Linear outputs carry features on the last axis, so quantize along it.
        new_linear.accumulator_quantizer.axis = len(layer.output_shape) - 1
        activation.load_state_dict(act_state_dict)
        activation.output_quantizer.axis = len(layer.output_shape) - 1

        layer = nn.Sequential(new_linear, activation)
        layer.weight = new_linear.kernel.get_weight()

        return layer


@TORCH_MODULES_REGISTRY(nn.Conv2d)
class Conv2dTranslator(TorchModuleTranslator):
    """Translate ``nn.Conv2d`` into a ``MACConv2d`` followed by a linear ``APUActivation``.

    The layer must have been calibrated with ``enable_stats_hook`` so that
    ``min_range`` and ``max_range`` are available on it.
    """

    @classmethod
    def translate(cls, layer: nn.Conv2d, dtype=None, device=None):
        """Build the quantized replacement for ``layer``.

        Args:
            layer: The calibrated ``nn.Conv2d`` to translate.
            dtype: Optional dtype forwarded to the new modules.
            device: Optional device forwarded to the new modules.

        Returns:
            ``nn.Sequential(MACConv2d, APUActivation)`` with a ``weight``
            attribute set from the new kernel's ``get_weight()``.

        Raises:
            AttributeError: If output statistics were never collected.
            ValueError: If the collected output range is empty (max == min),
                which would otherwise yield inf/NaN scales.
        """
        for attr in ("min_range", "max_range"):
            if not hasattr(layer, attr):
                raise AttributeError(
                    f"{type(layer).__name__} has no '{attr}'; enable the stats hook and run calibration first"
                )
        if torch.any(layer.max_range <= layer.min_range):
            raise ValueError(f"{type(layer).__name__} has an empty output range; cannot derive quantization scales")

        has_bias = layer.bias is not None
        input_qtype = QType(8, False)
        weight_qtype = QType(8, True)
        accumulator_qtype = QType(16, True)
        output_dtype = QType(8, False)
        mac_cfg = MACPrecisionConfig(input_qtype, weight_qtype, accumulator_qtype, "double_scale_initialization", 1)
        apu_cfg = APUPrecisionConfig(accumulator_qtype, output_dtype, 1)
        # Mirror the geometry of the original convolution.
        new_conv = MACConv2d(
            layer.in_channels,
            layer.out_channels,
            layer.kernel_size,
            stride=layer.stride,
            padding=layer.padding,
            dilation=layer.dilation,
            groups=layer.groups,
            padding_mode=layer.padding_mode,
            bias=True,
            precision_config=mac_cfg,
            dtype=dtype,
            device=device,
        )

        activation = APUActivation(1, layer.out_channels, ActivationType.LINEAR, apu_cfg, dtype=dtype, device=device)

        with torch.no_grad():
            # Symmetric accumulator scale covering the largest observed magnitude.
            maxval = torch.maximum(torch.abs(layer.min_range), layer.max_range)
            accumulator_scale = maxval / 2 ** (accumulator_qtype.bits - accumulator_qtype.signed)
            kernel_scale = layer.weight.abs().max() / 2 ** (weight_qtype.bits - weight_qtype.signed)
            eq_vec_in = torch.ones(layer.in_channels)
            eq_vec_out = torch.ones(layer.out_channels)

            state_dict = {
                "kernel.mac_shift": torch.ones(1),
                "accumulator_quantizer.scale": torch.repeat_interleave(accumulator_scale, layer.out_channels),
                "accumulator_quantizer.zero_point": torch.zeros(layer.out_channels),
                "kernel.weight": layer.weight,
                "kernel.scale": kernel_scale.view(1),
                "kernel.equalization_vector_in": eq_vec_in,
                "kernel.equalization_vector_out": eq_vec_out,
                "kernel.zero_point": torch.zeros(1),
            }

            # The MAC always has a bias; a bias-less layer gets zeros.
            bias_dict = {
                "bias.weight": layer.bias if has_bias else torch.zeros(layer.out_channels),
                "bias.scale": torch.repeat_interleave(accumulator_scale, layer.out_channels),
                "bias.zero_point": torch.zeros(layer.out_channels),
            }
            state_dict.update(bias_dict)

            # Asymmetric output quantization over the observed [min, max] range.
            output_scale = (layer.max_range - layer.min_range) / 2 ** (output_dtype.bits - output_dtype.signed)
            output_zero_point = layer.min_range / output_scale
            # Rescale ratio accumulator -> output, split as m * 2**e with m in [0.5, 1).
            m, e = torch.frexp(accumulator_scale / output_scale)
            m_bits = APU_MANTISSA_BITS
            e_bias = activation.second_bias + activation.first_bias - m_bits
            act_state_dict = {
                "output_quantizer.scale": torch.repeat_interleave(output_scale, layer.out_channels),
                "output_quantizer.zero_point": torch.repeat_interleave(output_zero_point, layer.out_channels),
                "slopes.mantissas": (m * 2**m_bits).view(1, 1),
                "slopes.exponents": (torch.abs(e) + e_bias).view(1, 1),
                # detach().clone() copies a tensor without torch.tensor()'s copy-construct warning.
                "offsets.scale": (output_scale / 2**activation.second_bias).detach().clone(),
                "offsets.zero_point": (output_zero_point * 2**activation.second_bias).detach().clone(),
                "offsets.weight": torch.zeros(1),
                "offsets.final_shift": torch.tensor([3]),
                # NOTE: not sure about these keys
                "thresholds.weight": torch.zeros(0),
                "thresholds.scale": accumulator_scale,
                "thresholds.zero_point": torch.zeros(1),
                "slopes.weight": torch.ones(1, 1),
                "slopes.scale": accumulator_scale / output_scale,
                "slopes.zero_point": torch.zeros(1),
                "slopes.q_groups_factor": torch.ones(1),
            }
        new_conv.load_state_dict(state_dict)
        activation.load_state_dict(act_state_dict)

        layer = nn.Sequential(new_conv, activation)

        layer.weight = new_conv.kernel.get_weight()

        return layer


# endregion Register Factories

# NOTE: SiLU can't really be translated per module: the same nn.SiLU instance is used at multiple places in the model.
# @TORCH_MODULES_REGISTRY(nn.SiLU)
# class SiLUTranslator(TorchModuleTranslator):
#     @classmethod
#     def translate(cls, layer, *, dtype=None, device=None):
#         input_qtype = QType(8, False)
#         weight_qtype = QType(8, True)
#         accumulator_qtype = QType(16, True)
#         output_dtype = QType(8, False)
#         mac_cfg = MACPrecisionConfig(input_qtype, weight_qtype, accumulator_qtype, "double_scale_initialization", 1)
#         apu_cfg = APUPrecisionConfig(accumulator_qtype, output_dtype, 1)
#         channels = layer.output_shape[1]
#         # new_conv = MACConv2d(
#         #     channels,
#         #     channels,
#         #     1,
#         #     groups=channels,
#         #     bias=True,
#         #     precision_config=mac_cfg,
#         #     dtype=dtype,
#         #     device=device,
#         # )
#         activation = APUActivation(9, channels, ActivationType.SILU, apu_cfg, dtype=dtype, device=device)
#         # if channels == 128 or channels == 1280:
#         #     breakpoint()

#         with torch.no_grad():
#             slopes = torch.tensor([0.0, -0.02893, -0.08676, 0.078891, 0.339243, 0.638031, 0.911426, 1.086621, 1.0])

#             maxval = torch.maximum(torch.abs(layer.min_range), layer.max_range)
#             accumulator_scale = maxval / 2 ** (accumulator_qtype.bits - accumulator_qtype.signed)

#             output_scale = (layer.max_range - layer.min_range) / 2 ** (output_dtype.bits - output_dtype.signed)
#             output_zero_point = layer.min_range / output_scale
#             m, e = torch.frexp(slopes * accumulator_scale / output_scale)
#             m_bits = APU_MANTISSA_BITS
#             e_bias = activation.second_bias + activation.first_bias - m_bits

#             offsets = torch.tensor([0.0, -0.18178, -0.40690, -0.18372, -0.01870, -0.01170, -0.17351, -0.40650, 0.0])
#             thresholds = torch.tensor([-6.0, -3.89304, -1.34733, -0.63382, -0.02343, 0.59188, 1.32987, 3.89463])

#             act_state_dict = {
#                 "output_quantizer.scale": torch.repeat_interleave(output_scale, channels),
#                 "output_quantizer.zero_point": torch.repeat_interleave(output_zero_point, channels),
#                 "slopes.mantissas": (m * 2**m_bits).view(1, -1),
#                 "slopes.exponents": (torch.abs(e) + e_bias).view(1, -1),
#                 "offsets.scale": torch.tensor(output_scale / 2**activation.second_bias),
#                 "offsets.zero_point": torch.tensor(output_zero_point * 2**activation.second_bias),
#                 "offsets.weight": offsets,
#                 # NOTE: not sure about these keys
#                 "thresholds": thresholds,
#                 "slopes.weight": slopes.view(1, -1),
#                 "slopes.scale": accumulator_scale / output_scale,
#                 "slopes.zero_point": torch.zeros(1),
#                 "slopes.q_groups_factor": torch.ones(1),
#             }
#         activation.load_state_dict(act_state_dict)

#         return activation
