from typing import List, Optional

import torch
import torch.nn as nn

from hailo_model_optimization.acceleras.utils.acceleras_definitions import ActivationType
from hailo_model_optimization.saitama.framework.apu_modules.apu_activation import APUActivation
from hailo_model_optimization.saitama.framework.common.saitama_definitions import PrecisionConfigT
from hailo_model_optimization.saitama.framework.fused_modules.fused_base import FusedBase
from hailo_model_optimization.saitama.framework.mac_modules.mac_ew_add import MACEWSub
from hailo_model_optimization.saitama.framework.mac_modules.mac_ew_mult import MACEWMult
from hailo_model_optimization.saitama.framework.mac_modules.mac_reduce import MACReduceSum


class MACSoftmax(FusedBase):
    """Softmax decomposed into MAC (multiply-accumulate) and APU (activation) sub-modules.

    The forward pass computes::

        m   = max(x, dim=1)        # plain torch op
        d   = x - m                # mac[0]: MACEWSub
        e   = exp(d)               # apu[0]: APUActivation(EXP)
        s   = reduce_sum(e)        # mac[1]: MACReduceSum
        r   = inv(s)               # apu[1]: APUActivation(INV_POS)
        p   = e * r                # mac[2]: MACEWMult
        out = linear(p)            # apu[2]: APUActivation(LINEAR), rescale to output scale

    Args:
        channels: Number of channels, forwarded to every sub-module.
        precision_config: ``None`` or a two-element sequence ``(mac_cfg, acu_cfg)``.
            ``mac_cfg`` is passed to all three MAC modules and ``acu_cfg`` to all
            three APU activations. When ``None``, every sub-module gets ``None``.
        dtype: Forwarded to every sub-module.
        device: Forwarded to every sub-module.
    """

    def __init__(
        self,
        channels: int,
        precision_config: Optional[List[PrecisionConfigT]] = None,
        dtype=None,
        device=None,
    ):
        super().__init__()
        if precision_config is None:
            mac_cfg = acu_cfg = None
        else:
            # Exactly two entries are expected: MAC config first, then APU config.
            mac_cfg, acu_cfg = precision_config

        d_d = {"dtype": dtype, "device": device}

        # Order matters: forward() indexes these by position.
        self.mac = nn.ModuleList(
            [
                MACEWSub(channels, precision_config=mac_cfg, **d_d),  # x - max
                MACReduceSum(channels, precision_config=mac_cfg, **d_d),  # sum(exp)
                MACEWMult(channels, precision_config=mac_cfg, **d_d),  # exp * inv(sum)
            ]
        )

        self.apu = nn.ModuleList(
            [
                APUActivation(channels, activation=ActivationType.EXP, precision_config=acu_cfg, **d_d),
                APUActivation(channels, activation=ActivationType.INV_POS, precision_config=acu_cfg, **d_d),
                APUActivation(channels, activation=ActivationType.LINEAR, precision_config=acu_cfg, **d_d),
            ]
        )

    def forward(self, x, **kwargs):
        """Apply the decomposed softmax to ``x``.

        Args:
            x: Input tensor. The max is taken over dim 1 (presumably the channel
                axis; TODO confirm that it matches the axis MACReduceSum reduces over).
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The output of the final LINEAR APU activation.
        """
        # Subtract the max before exp (standard softmax stabilization).
        # ``torch.max`` with ``dim`` returns a (values, indices) named tuple;
        # only ``.values`` is a tensor usable as the subtraction operand.
        max_val = torch.max(x, dim=1, keepdim=True).values
        post_sub = self.mac[0](x, max_val)  # x - max
        exp_vals = self.apu[0](post_sub)  # exp(x - max)
        # TODO: MACReduceSum - need to decide how to supply the -1 value (may not
        # be worth it); also check the -128 case.
        sum_val = self.mac[1](exp_vals)  # sum(exp_vals)
        inv_vals = self.apu[1](sum_val)  # inv(sum(exp_vals))
        softmax = self.mac[2](exp_vals, inv_vals)  # exp_vals * inv(sum(exp_vals))
        # Final LINEAR activation rescales to the output scale (the original note
        # mentions it also applies shifts).
        return self.apu[2](softmax)
