from typing import List, Tuple

import torch
from einops import rearrange

from hailo_model_optimization.saitama.framework.common.saitama_definitions import Encoding
from hailo_model_optimization.saitama.framework.common.utils import init_encoding
from hailo_model_optimization.saitama.framework.forwarder_modules.forwarder_base import ForwarderBase


class ForwarderConcat(ForwarderBase):
    """Concatenate tensors along ``axis``, optionally interleaving them group by group.

    With a single group (the default ``group_sizes=(1,)``) this is a plain
    ``torch.cat`` along ``axis``.

    With several groups, the concat axis of every input is cut into
    ``sum(group_sizes)`` equal parts. Consecutive parts are bundled into chunks
    of ``group_sizes[i]`` parts each, and the output is laid out as
    ``[in0_chunk0, in1_chunk0, ..., in0_chunk1, in1_chunk1, ...]``.
    Each input's concat-axis length must therefore be divisible by
    ``sum(group_sizes)``.

    Args:
        axis: Dimension to concatenate along. Grouped concat requires ``axis=1``.
        group_sizes: Number of equal parts that make up each output chunk.
    """

    def __init__(self, axis=1, group_sizes: Tuple[int, ...] = (1,)):
        super().__init__()
        self.axis = axis
        self.group_sizes = group_sizes
        # Total number of equal-sized parts each input is cut into.
        self.groups = sum(self.group_sizes)
        assert self.groups == 1 or axis == 1, "Concat with groups supports only axis=1"

    def _handle_concat(self, inputs: List[torch.Tensor], axis: int) -> torch.Tensor:
        """Concatenate ``inputs`` along ``axis``, interleaving per group when ``self.groups > 1``.

        The grouped path works for any tensor rank. ``forward`` uses it on
        activations along ``self.axis``; ``forward_encoding`` uses it on
        per-channel vectors along axis 0.
        """
        if self.groups == 1:
            return torch.cat(inputs, dim=axis)

        # Normalize a negative axis against the ORIGINAL rank; the reshape below adds one dim.
        dim = axis if axis >= 0 else axis + inputs[0].dim()

        per_input_chunks = []
        for inp in inputs:
            # (..., groups * c, ...) -> (..., groups, c, ...), with the group index outermost.
            grouped = inp.reshape(*inp.shape[:dim], self.groups, -1, *inp.shape[dim + 1 :])
            # Bundle consecutive groups into len(group_sizes) chunks.
            per_input_chunks.append(torch.split(grouped, list(self.group_sizes), dim=dim))

        # zip(*...) yields, for each chunk index, that chunk from every input in order.
        # Each chunk is flattened back to (..., n_groups_in_chunk * c, ...) before the concat.
        return torch.cat(
            [chunk.flatten(dim, dim + 1) for chunks in zip(*per_input_chunks) for chunk in chunks],
            dim=dim,
        )

    def forward(self, *inputs: torch.Tensor):
        """Concatenate the positional tensor inputs along ``self.axis``."""
        return self._handle_concat(inputs, self.axis)

    @staticmethod
    def _verify_shared_encoding(encs) -> None:
        """Assert that every encoding in ``encs`` matches the first one field by field."""
        ref = encs[0]
        tensor_fields = (
            "scale_by_group",
            "zero_point_by_group",
            "scale_by_channel",
            "zero_point_by_channel",
            "equalization_vector",
        )
        for enc in encs[1:]:
            for field in tensor_fields:
                ref_val, val = getattr(ref, field), getattr(enc, field)
                # Check shapes first so a mismatch fails the assertion instead of raising
                # a broadcasting error inside allclose.
                assert ref_val.shape == val.shape and torch.allclose(ref_val, val), field
            assert ref.scale_repeats == enc.scale_repeats
            assert ref.zero_point_repeats == enc.zero_point_repeats

    def forward_encoding(self, *encs: Encoding, verify_encoding=False, **kwargs):
        """Combine the input encodings into the encoding of the concatenated output.

        When concatenating along a non-channel axis (``axis != 1``), the channels
        are not mixed, so the first input's encoding is returned as-is. With
        ``verify_encoding`` set, all inputs are first asserted to share it.

        When concatenating along channels (``axis == 1``), the per-channel vectors
        are concatenated with the same (possibly grouped) layout as the activations.
        They are then rebased onto the first channel's scale.
        """
        if self.axis != 1:
            if verify_encoding:
                self._verify_shared_encoding(encs)
            return encs[0]

        if verify_encoding:
            assert all(enc.scale_by_group.size(0) == 1 for enc in encs)

        # Per-channel scale with each input's equalization divided out.
        group_scale_by_channel = self._handle_concat(
            [enc.scale_by_channel / enc.equalization_vector for enc in encs], axis=0
        )
        zero_point_by_channel = self._handle_concat([enc.zero_point_by_channel for enc in encs], axis=0)
        equalization_vector = self._handle_concat([enc.equalization_vector for enc in encs], axis=0)
        # Rebase onto the first channel's scale: new_eq = eq / (gs[0] / gs), so
        # gs[0] * new_eq == eq * gs. Done out-of-place to avoid mutating the concat result.
        # NOTE(review): presumably init_encoding rebuilds per-channel scale as
        # scale_by_group * equalization_vector — confirm.
        equalization_vector = equalization_vector / (group_scale_by_channel[0] / group_scale_by_channel)

        encoding = init_encoding(
            scale_by_group=group_scale_by_channel[:1],
            scale_repeats=sum(enc.scale_repeats for enc in encs),
            zero_point_by_group=zero_point_by_channel,
            zero_point_repeats=1,
            equalization_vector=equalization_vector,
        )
        return encoding

    def extra_repr(self):
        return f"axis={self.axis}, group_sizes={self.group_sizes}"
