import torch

from hailo_model_optimization.saitama.framework.apu_modules.apu_base import APUBase


class APUReduceMax(APUBase):
    """Grouped max-reduction over one or more tensor axes.

    For every axis in ``axes`` the axis is split into ``groups`` equal,
    contiguous chunks and each chunk is reduced to its maximum. The axis
    therefore shrinks from ``size`` to ``groups`` elements, and the tensor
    rank is preserved. With ``groups == 1`` this is a plain max over the axis,
    with the dimension kept at size 1.

    Args:
        axes: Axes to reduce, processed in order. Defaults to ``[1]``.
            Negative values are interpreted relative to the input rank at
            call time.
        groups: Per-axis group counts, looked up as ``groups[axis - 1]``
            (``groups[0]`` belongs to axis 1). Defaults to ``[1, 1, 1]``.
            NOTE(review): axis 0 would map to ``groups[-1]``; presumably axis 0
            is never reduced — confirm with callers.
    """

    def __init__(self, axes=None, groups=None):
        super().__init__()
        self.axes = axes if axes is not None else [1]
        self.groups = groups if groups is not None else [1, 1, 1]

    def forward(self, x: torch.Tensor, **kwargs):
        """Apply the grouped max-reduction to ``x``.

        Args:
            x: Input tensor. Each reduced axis size must be divisible by its
                group count, otherwise ``unflatten`` raises.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            A tensor with the same rank as ``x``, where each reduced axis has
            size equal to its group count.
        """
        for axis in self.axes:
            # The group_size dim is inserted at ``axis + 1``, which is only
            # right for a non-negative index. For axis=-1 it would be 0 and
            # reduce the wrong dimension, so normalise against the current rank
            # (the rank is unchanged by each iteration).
            if axis < 0:
                axis += x.dim()
            num_groups = self.groups[axis - 1]
            group_size = x.shape[axis] // num_groups
            # (..., size, ...) -> (..., num_groups, group_size, ...)
            x = x.unflatten(axis, (num_groups, group_size))
            # Max within each group. Dropping the group_size dim (keepdim=False)
            # restores the original rank, the same as keepdim=True + squeeze.
            x = torch.max(x, dim=axis + 1).values
        return x
