from typing import Callable, Union

from hailo_model_optimization.acceleras.atomic_ops.softmax_mask_on_mac_op import SoftmaxMaskOnMacOp
from hailo_model_optimization.acceleras.hailo_layers.hailo_element_wise_mult_on_mac import HailoElementwiseMultOnMac
from hailo_model_optimization.acceleras.utils.acceleras_definitions import ActivationType, LayerType


class HailoSoftmaxMaskOnMac(HailoElementwiseMultOnMac):
    """
    This layer represents a masking operation for a softmax layer.

    Input 0 is the input to the softmax layer, and input 1 is a binary mask.
    It is an elementwise-multiplication layer whose multiply op is replaced by
    ``SoftmaxMaskOnMacOp``, and it is serialized as an ``ELEMENTWISE_MULT``
    layer flagged with ``is_softmax_mask=True``.
    """

    # Serialized as a regular elementwise-mult layer; the softmax-mask variant is
    # distinguished only by the ``is_softmax_mask`` param written in ``to_hn``.
    _hn_type = LayerType.ELEMENTWISE_MULT

    def __init__(
        self,
        name: str,
        activation: Union[str, Callable, ActivationType] = ActivationType.LINEAR,
        logger=None,
        **kwargs,
    ):
        """
        Build the layer and swap in the softmax-mask multiply op.

        Args:
            name: layer name; also used as the prefix of the inner multiply op's name.
            activation: activation forwarded to ``HailoElementwiseMultOnMac``.
            logger: optional logger, forwarded to the parent and to the multiply op.
            **kwargs: forwarded unchanged to ``HailoElementwiseMultOnMac``.
        """
        super().__init__(name=name, activation=activation, logger=logger, **kwargs)
        # Replace the multiply op set up by the parent with the mask-specific op.
        self.ew_mult_op = SoftmaxMaskOnMacOp(name=f"{name}/elementwise_mult_op", logger=logger)
        # The flow is rebuilt after the op swap, presumably so that it references the
        # replacement ``ew_mult_op`` rather than the parent's one.
        # TODO: a bit hacky
        self._layer_flow = self._build_flow()

    @classmethod
    def get_default_params(cls):
        """
        Return the parent's default params with ``is_softmax_mask`` set to True.

        The parent's defaults are copied *before* the flag is added, so the
        mapping returned by ``super().get_default_params()`` is never mutated.
        """
        # TODO: this is temporary solution until we have pydantic scheme
        # Copy first: writing into the parent's dict directly could leak the flag
        # into HailoElementwiseMultOnMac's defaults if that dict is shared.
        defaults = dict(super().get_default_params())
        defaults["is_softmax_mask"] = True
        return defaults

    def to_hn(self, out_degree=None):
        """
        Serialize the layer via the parent and flag it as a softmax mask.

        Ensures ``hn["params"]["is_softmax_mask"]`` is True, creating the
        ``params`` dict if the parent's output has none.
        """
        hn = super().to_hn(out_degree)
        hn.setdefault("params", {})
        hn["params"]["is_softmax_mask"] = True
        return hn

    def enforce_io_encoding(self, training=False, **kwargs):
        """
        Tie the output encoding to that of input 0 (the softmax input).

        Output 0 takes the scale and zero point of input 0, presumably so the
        masked values remain on the same quantization grid as the unmasked input.
        ``training`` and ``kwargs`` are not used here.
        """
        self.set_output_scale(self.input_scales[0], 0)
        self.set_output_zero_point(self.input_zero_points[0], 0)

    def create_hw_params(self, weights_clipping, optimization_target, hw_shifts=None):
        """
        Create the hardware params of the layer's ops.

        Args:
            weights_clipping: not used by this layer.
            optimization_target: forwarded to the activation op.
            hw_shifts: currently ignored; the multiply op is created with
                ``force_shift=0`` (see TODO below).
        """
        self._enforce_output_encoding()
        # TODO: propagate the pre-accumulator shift (``hw_shifts[0]`` when
        # ``hw_shifts`` is given) to ew_mult_op instead of forcing a zero shift.
        self.ew_mult_op.create_hw_params(force_shift=0)
        self.enforce_internal_encoding()

        # The bias-add op's pre-accumulator shift is the multiplier shift chosen
        # by the multiply op above.
        self.bias_add_op.pre_acc_shift = self.ew_mult_op._multiplier_shift
        self.bias_add_op.create_hw_params()
        self.enforce_internal_encoding()

        self._shift_delta = 0

        # NOTE(review): the activation is given optional_reduce_sum_op's output
        # scale as its input scale; presumably that op feeds the activation in the
        # flow — confirm against the parent class.
        self.act_op.create_hw_params(
            self.optional_reduce_sum_op.output_scale,
            optimization_target,
            nudging=False,
        )
        self.output_op.create_hw_params()
