from hailo_model_optimization.saitama.framework.common.fake_quant import StaticFakeQuant
from hailo_model_optimization.saitama.framework.common.saitama_definitions import (
    Encoding,
    IOPrecisionConfig,
)
from hailo_model_optimization.saitama.framework.common.saitama_module import SaitamaModule
from hailo_model_optimization.saitama.framework.common.utils import init_encoding, qtype_to_range


class SaitamaIO(SaitamaModule):
    """Common base for the model's boundary (input / output) nodes.

    Every instance owns one ``StaticFakeQuant`` stored in ``output_quantizer``.
    It is built from the ``output_qtype`` of the supplied ``IOPrecisionConfig``
    and from ``out_channels``. The quantizer is configured with
    ``num_groups=(out_channels, 1)``, presumably one scale group per channel
    and a single zero-point group (confirm against ``StaticFakeQuant``). It
    operates along ``axis=1``, presumably the channel axis (TODO confirm).

    Subclasses are expected to set ``is_independent_scale`` as a class
    attribute. It is read when the quantizer is built; this base class gives it
    no value.
    """

    output_quantizer: StaticFakeQuant
    # Whether the scale should be its own parameter (and can therefore potentially be trained).
    is_independent_scale: bool

    def __init__(self, out_channels, precision_config: IOPrecisionConfig, device=None, dtype=None, **kwargs):
        """Create the node, record its precision and build its output quantizer.

        Args:
            out_channels: Number of output channels handled by this node.
            precision_config: Provides ``output_qtype`` for the quantizer.
            device: Forwarded to ``StaticFakeQuant``.
            dtype: Forwarded to ``StaticFakeQuant``.
            **kwargs: Forwarded to ``SaitamaModule.__init__``.
        """
        super().__init__(**kwargs)
        self.out_channels = out_channels
        self.init_precision(precision_config)
        self.initialize_output_quantizer(device=device, dtype=dtype)
        # Nodes start in the non-encoded mode; toggle with ``set_encoded``.
        self._is_encoded = False

    def init_precision(self, precision_config: IOPrecisionConfig):
        """Copy the output quantization type from ``precision_config``."""
        qtype = precision_config.output_qtype
        self.output_qtype = qtype

    def initialize_output_quantizer(self, device=None, dtype=None):
        """Build ``output_quantizer`` from ``output_qtype`` and ``out_channels``.

        Args:
            device: Forwarded to ``StaticFakeQuant``.
            dtype: Forwarded to ``StaticFakeQuant``.
        """
        channels = self.out_channels
        low, high = qtype_to_range(self.output_qtype)
        # Presumably (scale groups, zero-point groups): per-channel scale, one shared zero point.
        groups = (channels, 1)
        # Only the scale may be an independent encoding; the zero point never is.
        independence = (self.is_independent_scale, False)
        self.output_quantizer = StaticFakeQuant(
            quant_min=low,
            quant_max=high,
            channels=channels,
            num_groups=groups,
            axis=1,
            is_independent_encoding=independence,
            dtype=dtype,
            device=device,
        )

    def set_encoded(self, is_encoded: bool):
        """Switch the node between encoded and plain mode (consulted by subclasses' ``forward``)."""
        self._is_encoded = is_encoded


class SaitamaInput(SaitamaIO):
    """Entry node of the model: the place where data enters the quantized graph.

    Because it is the source of the model, its scale is independent and can be
    trained (``is_independent_scale = True``).
    """

    is_independent_scale = True  # Source of the model, can be trained

    def forward(self, x, **kwargs):
        """Fake-quantize the incoming tensor.

        In encoded mode, ``x`` is first passed through the quantizer's
        ``_encode`` and the encoded value is then quantized.

        Args:
            x: Input tensor.
            **kwargs: Ignored.

        Returns:
            The output of ``output_quantizer`` applied to ``x`` (or to its encoding).
        """
        quantizer = self.output_quantizer
        quantizer_input = quantizer._encode(x) if self._is_encoded else x
        return quantizer(quantizer_input)

    def forward_encoding(self, **kwargs):
        """Describe this node's output encoding via ``init_encoding``.

        A single shared scale (the minimum of the quantizer's scales) is
        repeated over ``out_channels``. The per-group differences are kept in an
        equalization vector, ``scale / shared_scale``.

        Args:
            **kwargs: Ignored.

        Returns:
            Whatever ``init_encoding`` builds from these values.
        """
        quantizer = self.output_quantizer
        shared_scale = quantizer.scale.min()
        equalization = quantizer.scale / shared_scale
        return init_encoding(
            scale_by_group=shared_scale.view(1),
            scale_repeats=self.out_channels,
            zero_point_by_group=quantizer.zero_point,
            zero_point_repeats=quantizer.channels_per_group_zero_point,
            equalization_vector=equalization,
        )


class SaitamaOutput(SaitamaIO):
    """Final node of the model.

    Its scale is not independent: it is affected by the predecessor nodes, so
    ``is_independent_scale`` is False and the encoding is received from
    upstream through ``forward_encoding``.
    """

    is_independent_scale = False  # Final node, scale is affected by predecessors

    def forward(self, x, **kwargs):
        """Fake-quantize ``x``; in encoded mode, also decode the quantized value.

        Args:
            x: Tensor produced by the predecessor node.
            **kwargs: Ignored.

        Returns:
            ``output_quantizer(x)``, passed through the quantizer's ``_decode``
            when the node is in encoded mode.
        """
        result = self.output_quantizer(x)
        if self._is_encoded:
            # Decode the quantized value rather than the raw input, so the
            # quantizer output is not thrown away in encoded mode (mirrors the
            # input node, which encodes first and then quantizes).
            result = self.output_quantizer._decode(result)
        return result

    def forward_encoding(self, encoding: Encoding, **kwargs):
        """Hand the upstream ``encoding`` to the output quantizer.

        Args:
            encoding: Encoding coming from the predecessor nodes.
            **kwargs: Forwarded to ``output_quantizer.forward_encoding``.

        Returns:
            None; the quantizer's return value is not propagated.
        """
        self.output_quantizer.forward_encoding(encoding, **kwargs)
