from typing import Callable, Union

from hailo_model_optimization.acceleras.atomic_ops.softmax_mask_op import SoftmaxMaskOp
from hailo_model_optimization.acceleras.hailo_layers.hailo_element_wise_mult import HailoElementwiseMult
from hailo_model_optimization.acceleras.utils.acceleras_definitions import ActivationType, LayerType


class HailoSoftmaxMask(HailoElementwiseMult):
    """
    Masking layer that represents the masked operation of a softmax layer.

    Input 0 is the input to the softmax layer, and input 1 is a binary mask.
    The layer reuses the element-wise multiplication layer structure but
    replaces its element-wise atomic op with a ``SoftmaxMaskOp``.
    """

    _hn_type = LayerType.ELEMENTWISE_MULT

    def __init__(
        self,
        name: str,
        activation: Union[str, Callable, ActivationType] = ActivationType.LINEAR,
        logger=None,
        **kwargs,
    ):
        """
        Build the layer and swap its element-wise op for ``SoftmaxMaskOp``.

        Args:
            name: Layer name; also used as the prefix of the mask op's name.
            activation: Activation forwarded to ``HailoElementwiseMult``.
            logger: Logger forwarded to the parent layer and to the mask op.
            **kwargs: Extra arguments forwarded to ``HailoElementwiseMult``.
        """
        super().__init__(name=name, activation=activation, logger=logger, **kwargs)
        # Overwrite ``ew_mult_op`` (presumably created by the parent
        # constructor) with the softmax-mask op.
        self.ew_mult_op = SoftmaxMaskOp(name=f"{name}/elementwise_mult_op", logger=logger)
        self.bias_add_op1.merge_residue_into_bias = False
        # Rebuild the flow after the op swap above. NOTE(review): presumably
        # required because the parent already built a flow referencing the
        # replaced op — confirm against ``_build_flow``.
        self._layer_flow = self._build_flow()  # TODO: a bit hacky

    @classmethod
    def get_default_params(cls):
        """
        Return the parent's default params with ``is_softmax_mask`` set to True.

        The parent's mapping is copied *before* it is modified, so the object
        returned by the parent is never mutated and the result is a fresh dict.
        """
        # TODO: this is temporary solution until we have pydantic scheme
        defaults = dict(super().get_default_params())
        defaults["is_softmax_mask"] = True
        return defaults

    def enforce_io_encoding(self, training=False, **kwargs):
        """
        Make output 0 reuse the scale and zero point of input 0.

        Copies ``input_scales[0]`` and ``input_zero_points[0]`` to output
        index 0. ``training`` and ``**kwargs`` are accepted for signature
        compatibility and are not used here.
        """
        self.set_output_scale(self.input_scales[0], 0)
        self.set_output_zero_point(self.input_zero_points[0], 0)
