from abc import ABC

import torch.nn as nn

from hailo_model_optimization.saitama.framework.common.fake_quant import BaseQuant, QuantWeight
from hailo_model_optimization.saitama.framework.common.saitama_definitions import Encoding, LayerOutput


class SaitamaModule(nn.Module, ABC):
    """
    Base class for Saitama modules.

    It provides convenience methods that walk this module's submodule tree and toggle the
    fake-quantization units found there (data and weight quantizers). It also provides
    methods that freeze quantized weights and switch quantizers into "encoded" forward mode.
    Each method delegates to the module-level helper function of the same name.

    Subclasses are expected to override:
    - ``forward``: forwards data through the module (only if the module supports data at all).
    - ``forward_encoding`` (optional): forwards encodings through the module.

    The default implementations of both raise ``NotImplementedError``. Note that no method is
    declared ``@abstractmethod``, so the ``ABC`` base does not prevent direct instantiation.
    """

    def enable_weight_quantization(self, enable=True):
        """Enable (or disable, with ``enable=False``) every ``QuantWeight`` below this module."""
        enable_weight_quantization(self, enable)

    def enable_data_quantization(self, enable=True):
        """Enable (or disable) every non-weight ``BaseQuant`` below this module."""
        enable_data_quantization(self, enable)

    def enable_quantization(self, enable=True):
        """Enable (or disable) both data and weight quantizers below this module."""
        enable_quantization(self, enable)

    def freeze_quant_weights(self):
        """Call ``freeze()`` on every ``QuantWeight`` reachable from this module."""
        freeze_quant_weights(self)

    def enable_encoded_forward(self, value=True):
        """
        Set encoded-forward mode to ``value`` on quantizers below this module.

        If this module has a ``set_encoded`` method, the mode is also set on the module itself.
        """
        enable_encoded_forward(self, value)

    def forward_encoding(self, *encs: Encoding, **kwargs) -> Encoding:
        """Propagate input encodings through the module; unsupported unless overridden."""
        raise NotImplementedError(f"Encoding is not supported for {self.__class__.__name__}")

    def forward(self, *args, **kwargs) -> LayerOutput:
        """Run data through the module; unsupported unless overridden."""
        raise NotImplementedError(f"Data is not supported for {self.__class__.__name__}")


def freeze_quant_weights(module: nn.Module):
    """
    Freeze every ``QuantWeight`` found in ``module``'s subtree.

    If ``module`` is itself a ``QuantWeight``, it is frozen and its children are not visited.
    Otherwise, the search recurses into each direct child.
    """
    if not isinstance(module, QuantWeight):
        for submodule in module.children():
            freeze_quant_weights(submodule)
        return
    module.freeze()


def enable_data_quantization(module: nn.Module, value: bool):
    """
    Set quantization on/off for every data quantizer below ``module``.

    A "data quantizer" here is any ``BaseQuant`` that is not a ``QuantWeight``. Such submodules
    get ``enable_quantization(value)`` and are not descended into. Every other submodule,
    including ``QuantWeight`` ones, is searched recursively. ``module`` itself is not toggled.
    """
    for submodule in module.children():
        is_data_quantizer = isinstance(submodule, BaseQuant) and not isinstance(submodule, QuantWeight)
        if not is_data_quantizer:
            enable_data_quantization(submodule, value)
            continue
        submodule.enable_quantization(value)


def enable_weight_quantization(module: nn.Module, value: bool):
    """
    Set quantization on/off for every ``QuantWeight`` below ``module``.

    Matching submodules get ``enable_quantization(value)`` and are not descended into.
    All other submodules are searched recursively. ``module`` itself is not toggled.
    """
    for submodule in module.children():
        if not isinstance(submodule, QuantWeight):
            enable_weight_quantization(submodule, value)
        else:
            submodule.enable_quantization(value)


def enable_quantization(module: nn.Module, value: bool):
    """Set quantization on/off for both data and weight quantizers below ``module``."""
    # Data quantizers are toggled first, then weight quantizers.
    for toggle in (enable_data_quantization, enable_weight_quantization):
        toggle(module, value)


def enable_encoded_forward(module: nn.Module, value: bool):
    """
    Set encoded-forward mode to ``value`` throughout ``module``'s subtree.

    ``BaseQuant`` children get ``set_encoded(value)`` and are not descended into. Other children
    are handled recursively by this same function. Afterwards, ``module`` itself gets
    ``set_encoded(value)`` if it defines that attribute.
    """
    for submodule in module.children():
        if not isinstance(submodule, BaseQuant):
            enable_encoded_forward(submodule, value)
            continue
        submodule.set_encoded(value)
    # Finally flag the module itself, when it exposes a set_encoded hook.
    if hasattr(module, "set_encoded"):
        module.set_encoded(value)
