import torch

from hailo_model_optimization.saitama.framework.common.fake_quant import QuantEqWeight
from hailo_model_optimization.saitama.framework.common.saitama_definitions import (
    DimsInfo,
    Encoding,
    IOPrecisionConfig,
)
from hailo_model_optimization.saitama.framework.common.saitama_module import SaitamaModule
from hailo_model_optimization.saitama.framework.common.utils import init_encoding, qtype_to_range


class SaitamaConstInput(SaitamaModule):
    """Constant-input module whose value is held as a quantized ``QuantEqWeight``.

    The constant is created zero-filled at construction time (presumably
    overwritten later, e.g. when weights are loaded — confirm). Each forward
    pass tiles it according to ``tile`` and prepends a size-1 leading dimension.

    Args:
        shape: Shape of the stored constant. Axis 0 is used as the channel axis
            of the quantized weight.
        tile: Tiling description. ``as_chw_tuple()`` is passed to
            ``Tensor.repeat`` in ``forward``. ``channels`` is used to expand the
            encoding in ``forward_encoding``.
        precision_config: Its ``output_qtype`` determines the quantization
            range via ``qtype_to_range``.
        dtype: Optional dtype for the zero-initialized constant.
        device: Optional device for the zero-initialized constant.
    """

    def __init__(self, shape, tile: DimsInfo, precision_config: IOPrecisionConfig, dtype=None, device=None):
        super().__init__()
        initial = torch.zeros(shape, dtype=dtype, device=device)
        # Single quantization group laid out along the first axis.
        axis = 0
        self.tile = tile
        low, high = qtype_to_range(precision_config.output_qtype)
        self.value = QuantEqWeight(
            quant_min=low,
            quant_max=high,
            value=initial,
            num_groups=1,
            channels=initial.shape[axis],
            axis=axis,
            is_independent_encoding=False,
        )

    def forward(self, **kwargs):
        """Return the stored constant tiled by ``tile`` with a leading dim of size 1.

        ``kwargs`` are accepted for interface compatibility and ignored.
        """
        weight = self.value.get_weight()
        # NOTE(review): Tensor.repeat needs at least as many repeat factors as
        # the weight has dims; this assumes the stored constant is CHW — confirm.
        return weight.repeat(self.tile.as_chw_tuple()).unsqueeze(0)

    def forward_encoding(self, **kwargs) -> Encoding:
        """Return the encoding of the constant, expanded to cover channel tiling.

        If ``tile.channels`` is 1, the weight's encoding is returned unchanged.
        Otherwise the scale and zero-point repeat counts are multiplied by the
        channel tile factor, and the equalization vector is tiled by that factor.
        The result is built with ``init_encoding``.
        """
        base = self.value.forward_encoding(**kwargs)
        factor = self.tile.channels
        if factor == 1:
            return base
        # Channels are tiled whole, so each per-group repeat count grows by
        # `factor` and the per-channel equalization vector is repeated end-to-end.
        return init_encoding(
            base.scale_by_group,
            base.scale_repeats * factor,
            base.zero_point_by_group,
            base.zero_point_repeats * factor,
            base.equalization_vector.repeat(factor),
        )
