from typing import List

import torch
import torch.nn as nn

from hailo_model_optimization.saitama.framework.common.saitama_definitions import Encoding
from hailo_model_optimization.saitama.framework.common.utils import init_encoding
from hailo_model_optimization.saitama.framework.forwarder_modules.forwarder_base import ForwarderBase


class ForwarderSliceAxis(ForwarderBase):
    """Forwarder that slices a tensor along one axis over the range [start, end).

    The slice is implemented with ``Tensor.narrow``, so the result keeps
    ``end - start`` elements along ``axis``, beginning at index ``start``.
    """

    def __init__(self, axis, start, end):
        super().__init__()
        # Axis to slice and the half-open [start, end) index range on it.
        self.axis = axis
        self.start = start
        self.end = end

    def _slice(self, x: torch.Tensor, axis) -> torch.Tensor:
        """Return ``x`` narrowed along ``axis`` to this module's [start, end) range."""
        return x.narrow(axis, self.start, self.end - self.start)

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """Slice ``x`` along the configured axis; extra kwargs are ignored."""
        return self._slice(x, self.axis)

    def forward_encoding(self, encoding: Encoding, verify_encoding=False, **kwargs) -> Encoding:
        """Propagate a quantization encoding through this slice.

        Only ``axis == 1`` touches the encoding (presumably the channel axis —
        TODO confirm against the tensor layout used by callers). For every other
        axis the incoming encoding object is returned unchanged.

        For ``axis == 1`` a new encoding is built with ``init_encoding``:
        the original ``scale_by_group`` and ``zero_point_by_group`` are passed
        through with a repeat count of ``end - start``, and the equalization
        vector is narrowed along its first dimension to the same [start, end).

        Args:
            encoding: incoming encoding of the tensor being sliced.
            verify_encoding: when True, assert that the scale and zero-point
                tensors each have exactly one entry along dim 0 before reuse.
        """
        # Slicing any axis other than 1 leaves the encoding untouched.
        if self.axis != 1:
            return encoding

        if verify_encoding:
            # The group tensors are reused as-is (not sliced), which is only
            # consistent when they hold a single entry along dim 0.
            assert encoding.scale_by_group.size(0) == 1
            assert encoding.zero_point_by_group.size(0) == 1

        sliced_width = self.end - self.start
        return init_encoding(
            encoding.scale_by_group,
            sliced_width,
            encoding.zero_point_by_group,
            sliced_width,
            self._slice(encoding.equalization_vector, 0),
        )

    def extra_repr(self):
        """Describe the slice parameters in the module's printed representation."""
        return f"axis={self.axis}, start={self.start}, end={self.end}"


class ForwarderSlice(ForwarderBase):
    """Forwarder that applies a sequence of single-axis slices in order.

    Each ``(axis, start, end)`` triple becomes one ``ForwarderSliceAxis`` child.
    Both the tensor and its encoding are threaded through the children in the
    order the triples were given.
    """

    def __init__(self, axes, starts, ends):
        super().__init__()
        # Registered as an nn.ModuleList so the children become submodules.
        self.slices: List[ForwarderSliceAxis] = nn.ModuleList()
        assert len(axes) == len(starts) == len(ends)
        self.slices.extend(
            ForwarderSliceAxis(axis, begin, stop)
            for axis, begin, stop in zip(axes, starts, ends)
        )

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """Apply every child slice to ``x`` in sequence; extra kwargs are ignored."""
        result = x
        for axis_slicer in self.slices:
            result = axis_slicer(result)
        return result

    def forward_encoding(self, encoding: Encoding, **kwargs) -> Encoding:
        """Thread ``encoding`` through each child's ``forward_encoding``.

        All keyword arguments (e.g. ``verify_encoding``) are forwarded to
        every child unchanged.
        """
        current = encoding
        for axis_slicer in self.slices:
            current = axis_slicer.forward_encoding(current, **kwargs)
        return current
