from typing import List

import torch
import torch.nn as nn

from hailo_model_optimization.saitama.framework.common.saitama_definitions import Encoding
from hailo_model_optimization.saitama.framework.common.utils import init_encoding
from hailo_model_optimization.saitama.framework.forwarder_modules.forwarder_base import ForwarderBase
from hailo_model_optimization.saitama.framework.forwarder_modules.forwarder_slice import ForwarderSliceAxis


class ForwarderSplitter(ForwarderBase):
    """Split a tensor into consecutive pieces along one axis.

    Each entry of ``splits`` becomes one ``ForwarderSliceAxis`` built with the
    running start index and ``start + split`` as bounds (presumably a half-open
    range -- confirm against ``ForwarderSliceAxis``), so outputs come back in
    ``splits`` order.

    With ``groups > 1`` (only allowed for ``axis == 1`` and equal splits), the
    axis is first permuted so that output ``s`` gathers the ``s``-th sub-chunk of
    every group, in group order.
    """

    def __init__(self, axis: int, splits: List[int], groups: int):
        """Build one slice module per split.

        Args:
            axis: Axis to split along.
            splits: Width of each output along ``axis``; every entry must be positive.
            groups: Number of groups; ``1`` means a plain consecutive split.

        Raises:
            ValueError: If any split is not positive, if ``groups < 1``, or if
                ``groups > 1`` and ``splits`` is empty.
            AssertionError: If ``groups != 1`` and not (``axis == 1`` with all
                splits equal).
        """
        super().__init__()
        idx = 0
        self.slices: nn.ModuleList = nn.ModuleList()
        for split in splits:
            if split <= 0:
                raise ValueError("Split value must be positive")
            self.slices.append(ForwarderSliceAxis(axis, idx, idx + split))
            idx += split
        self.groups = groups
        self.axis = axis
        # Explicit raise rather than ``assert`` so the check is not stripped under
        # ``python -O``; AssertionError is kept so callers see the same exception type.
        if not (groups == 1 or (axis == 1 and all(splits[0] == split for split in splits))):
            raise AssertionError("Splitter supports only groups=1 or groups>1 with axis=1 and equal splits")
        # _reorder_groups divides the axis size by both `groups` and the number of
        # slices, so reject values that would otherwise only fail later in forward().
        if groups < 1:
            raise ValueError("Groups value must be positive")
        if groups != 1 and len(self.slices) == 0:
            raise ValueError("Splitter with groups>1 requires at least one split")

    def _reorder_groups(self, x: torch.Tensor, axis: int, num_slices: int):
        """Permute ``x`` along ``axis`` so each contiguous slice takes one sub-chunk per group.

        The axis (size ``C``) is viewed as ``(groups, num_slices, C // (groups * num_slices))``.
        The first two dims are then swapped and the axis is flattened back.
        Example with ``groups=2``, ``num_slices=2``, ``C=4``:
        ``[a0, a1, b0, b1] -> [a0, b0, a1, b1]``.
        ``unflatten`` raises if ``C`` is not divisible by ``groups`` or
        ``C // groups`` is not divisible by ``num_slices``.
        """
        y0 = x.unflatten(axis, (self.groups, x.shape[axis] // self.groups))
        y1 = y0.unflatten(axis + 1, (num_slices, y0.shape[axis + 1] // num_slices))
        y2 = y1.transpose(axis, axis + 1)
        y3 = y2.flatten(axis, axis + 2)
        return y3

    def forward(self, x: torch.Tensor, **kwargs):
        """Return a list with one tensor per entry of ``splits``.

        When ``groups != 1`` the split axis is reordered first (see
        ``_reorder_groups``). ``kwargs`` are accepted but not used here.
        """
        if self.groups != 1:
            x = self._reorder_groups(x, self.axis, len(self.slices))
        y = [slice_(x) for slice_ in self.slices]
        return y

    def forward_encoding(self, encoding: Encoding, **kwargs):
        """Return one encoding per output slice, mirroring ``forward``.

        For grouped splitting, the equalization vector is reordered along dim 0
        the same way ``forward`` reorders the split axis. This assumes it is a
        1-D per-channel vector -- TODO confirm. The scale and zero-point fields
        are passed unchanged to ``init_encoding``.
        """
        # The ``axis == 1`` test is implied by the constructor check; kept as a guard.
        if self.groups != 1 and self.axis == 1:
            num_slices = len(self.slices)
            eq_vector = self._reorder_groups(encoding.equalization_vector, 0, num_slices)
            # NOTE(review): scale/zero-point are not permuted alongside the channels;
            # this is only correct if they are invariant under this reorder -- confirm.
            encoding = init_encoding(
                scale_by_group=encoding.scale_by_group,
                scale_repeats=encoding.scale_repeats,
                zero_point_by_group=encoding.zero_point_by_group,
                zero_point_repeats=encoding.zero_point_repeats,
                equalization_vector=eq_vector,
            )

        return [slice_.forward_encoding(encoding, **kwargs) for slice_ in self.slices]
