from abc import abstractmethod

import torch
import torch.nn as nn

from hailo_model_optimization.acceleras.utils.acceleras_definitions import BiasMode
from hailo_model_optimization.saitama.framework.common.fake_quant import (
    AccumulatorFakeQuant,
    QuantBias,
    QuantBiasDecomposed,
)
from hailo_model_optimization.saitama.framework.common.saitama_definitions import (
    MACPrecisionConfig,
    QType,
)
from hailo_model_optimization.saitama.framework.common.saitama_module import SaitamaModule
from hailo_model_optimization.saitama.framework.common.utils import (
    qtype_to_range,
)


class MACBase(SaitamaModule):
    """
    Abstract base class for MAC (Multiply-Accumulate) units modeled on the Hailo chip architecture.

    The shared MAC logic consists of two main components:
    1. ``bias`` - a quantized bias whose implementation depends on ``bias_mode``:
        - ``BiasMode.single_scale_decomposition`` -> ``QuantBiasDecomposed``
        - ``BiasMode.double_scale_initialization`` -> ``QuantBias``
       The bias can be disabled by constructing the unit with ``bias=False``; ``self.bias`` is then ``None``.
    2. ``accumulator_quantizer`` - an ``AccumulatorFakeQuant`` applied to the result of ``forward_mac``,
       with its range derived from ``accumulator_qtype``. It is intended to model the accumulator
       wraparound.

    NOTE(review): ``forward`` in this base class does not add ``self.bias``; presumably subclasses
    apply it inside ``forward_mac`` - TODO confirm.

    When subclassing MACBase, implement:
    - initialization functions for any additional encodings or weights specific to the MAC unit,
      and call them from ``__init__``.
    - import functions for those additional encodings / weights, registered in ``__init__`` via
      ``register_weight_import_method`` or ``register_encoding_import_method``.
    - ``forward_mac``, the forward logic of the MAC unit's data path. ``forward`` passes its result
      through ``accumulator_quantizer``.
    - ``forward_encoding`` (optional), relevant once the online encoding logic is implemented.
    """

    # Set by the `initialize_*` methods.
    # `bias` is None when constructed with bias=False, and a QuantBiasDecomposed in single-scale mode.
    bias: QuantBias
    accumulator_quantizer: AccumulatorFakeQuant
    # NOTE(review): not assigned anywhere in this base class; presumably set by subclasses - TODO confirm.
    mult_residue: torch.Tensor

    # Set in `init_precision` (which also sets `quantization_groups` from the precision config).
    bias_mode: BiasMode
    input_qtype: QType
    weight_qtype: QType
    accumulator_qtype: QType

    # Set in `__init__`.
    out_channels: int
    # Single-element trainable parameter, zero-initialized in `__init__`.
    mac_shift: nn.Parameter

    # region initialization

    def __init__(
        self,
        out_channels,
        bias: bool = True,
        precision_config: MACPrecisionConfig = None,
        device=None,
        dtype=None,
    ):
        """
        Args:
            out_channels: Number of output channels; sizes the accumulator quantizer and (by default) the bias.
            bias: Whether to create a quantized bias. When False, ``self.bias`` is set to ``None``.
            precision_config: Required MAC precision configuration (bias mode and qtypes).
            device: Device for created tensors / parameters.
            dtype: dtype for created tensors / parameters.

        Raises:
            ValueError: If ``precision_config`` is None.
        """
        super().__init__()
        if precision_config is None:
            raise ValueError("Precision configuration is required for MAC unit")

        self.out_channels = out_channels
        factory_kwargs = {"device": device, "dtype": dtype}
        self.mac_shift = nn.Parameter(torch.zeros(1, **factory_kwargs))
        # Precision must be set first: the quantizer and bias ranges are derived from accumulator_qtype.
        self.init_precision(precision_config)
        self.initialize_accumulator_quantizer(**factory_kwargs)
        if bias:
            self.initialize_bias(**factory_kwargs)
        else:
            self.bias = None

    def init_precision(self, precision_config: MACPrecisionConfig):
        """Copy bias mode, qtypes and quantization groups from ``precision_config`` onto the module."""
        self.bias_mode = BiasMode(precision_config.bias_mode)
        self.input_qtype = precision_config.input_qtype
        self.weight_qtype = precision_config.weight_qtype
        self.accumulator_qtype = precision_config.accumulator_qtype
        self.quantization_groups = precision_config.quantization_groups

    def initialize_accumulator_quantizer(self, device=None, dtype=None):
        """Create the per-output-channel ``AccumulatorFakeQuant`` ranged by ``accumulator_qtype``."""
        output_channels = self.out_channels
        quant_min, quant_max = qtype_to_range(self.accumulator_qtype)
        self.accumulator_quantizer = AccumulatorFakeQuant(
            quant_min=quant_min,
            quant_max=quant_max,
            channels=output_channels,
            num_groups=output_channels,
            axis=1,
            is_independent_encoding=False,
            dtype=dtype,
            device=device,
        )

    def initialize_bias(self, device=None, dtype=None):
        """Create ``self.bias`` with shape ``get_bias_shape()`` and the accumulator's quantization range."""
        bias_shape = self.get_bias_shape()
        quant_min, quant_max = qtype_to_range(self.accumulator_qtype)
        self.bias = self.create_bias_weight(quant_min, quant_max, bias_shape, device=device, dtype=dtype)

    def create_bias_weight(self, quant_min, quant_max, bias_shape, device=None, dtype=None):
        """
        Build a zero-initialized quantized bias matching ``self.bias_mode``.

        Returns:
            ``QuantBiasDecomposed`` for single-scale decomposition, ``QuantBias`` for
            double-scale initialization.

        Raises:
            ValueError: If ``self.bias_mode`` is neither of the supported modes.
        """
        if self.bias_mode == BiasMode.single_scale_decomposition:
            bias_cls = QuantBiasDecomposed
        elif self.bias_mode == BiasMode.double_scale_initialization:
            bias_cls = QuantBias
        else:
            raise ValueError(f"Unsupported bias mode: {self.bias_mode}")
        return bias_cls(
            quant_min=quant_min,
            quant_max=quant_max,
            value=torch.zeros(bias_shape, device=device, dtype=dtype),
            num_groups=bias_shape,
            channels=bias_shape,
            axis=0,
            is_independent_encoding=False,
        )

    def get_bias_shape(self):
        """Return the bias shape; defaults to one value per output channel."""
        # NOTE: Override in derived classes if the hw doesn't implement bias as a vector.
        return self.out_channels

    # endregion initialization

    # region forward logic

    @abstractmethod
    def forward_mac(self, *args, **kwargs):
        """
        Compute the raw MAC data-path result; must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in the base class. ``nn.Module`` does not enforce
                ``@abstractmethod`` by itself (unless ``SaitamaModule`` adds ``ABCMeta``), so a missing
                override would otherwise silently feed ``None`` into ``accumulator_quantizer``.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement forward_mac")

    def forward(self, *args, **kwargs):
        """Run ``forward_mac`` and fake-quantize its result with ``accumulator_quantizer``."""
        mac_res = self.forward_mac(*args, **kwargs)
        return self.accumulator_quantizer(mac_res)

    # endregion forward logic

    # Design notes - configuration common to all MAC units:
    # 1. MAC shift
    # 2. precision
    #   a. input bits
    #   b. weight bits
    #   c. accumulator bits (usually derived from input bits, although they might differ with a decomposition layer)
    #   d. precision definitions should be relevant only for clip and wraparound
    # 3. technically bias is always supported, but it is not always used
    #   a. as an initialized value ("double scale" initialization)
    #   b. as a mult product (single scale decomposition)
    #   c. for inference it can always be initialized to a value; for training it might be trickier
