import torch

from hailo_model_optimization.saitama.framework.common.saitama_definitions import Encoding
from hailo_model_optimization.saitama.framework.common.saitama_module import SaitamaModule


class FusedBase(SaitamaModule):
    """Common base class for fused Saitama modules.

    Adds no behavior of its own on top of ``SaitamaModule``. It appears to act
    as a shared parent/marker for fused module types such as
    ``SubClusterModule`` (e.g. for ``isinstance`` checks). NOTE(review):
    confirm intended use with callers.
    """


class SubClusterModule(FusedBase):
    """Fused module that runs a MAC stage followed by an APU stage.

    Both the tensor path (``forward``) and the encoding path
    (``forward_encoding``) apply ``mac`` first and pass its result to ``apu``.
    """

    def __init__(self, mac, apu, is_activation_only=False):
        """Store the two stages that make up this sub-cluster.

        Args:
            mac: First stage. It is called with the raw inputs and must also
                provide ``forward_encoding``.
            apu: Second stage. It is called with ``mac``'s output and must
                also provide ``forward_encoding``.
            is_activation_only: Flag stored as-is. This class does not read
                it itself; presumably other code inspects it. TODO confirm.
        """
        super().__init__()
        # Assign mac before apu, matching the original registration order.
        self.mac = mac
        self.apu = apu
        self.is_activation_only = is_activation_only

    def forward(self, *inputs: torch.Tensor) -> torch.Tensor:
        """Apply ``mac`` to ``inputs``, then ``apu`` to the intermediate result."""
        hidden = self.mac(*inputs)
        output = self.apu(hidden)
        return output

    def forward_encoding(self, *encs: Encoding, **kwargs) -> Encoding:
        """Propagate encodings through ``mac`` and then ``apu``.

        The same ``kwargs`` are forwarded unchanged to both stages.
        """
        intermediate_enc = self.mac.forward_encoding(*encs, **kwargs)
        output_enc = self.apu.forward_encoding(intermediate_enc, **kwargs)
        return output_enc
