import numpy as np

from hailo_model_optimization.acceleras.lossy_elements.base_lossy_element import BaseLossyElement
from hailo_model_optimization.acceleras.utils.flow_state_utils import (
    LossyState,
)


class PruneElement(BaseLossyElement):
    """
    Lossy element that prunes its input by multiplying it with a mask.

    This is a workaround (for now) to inject pruning to the "native" graph
    """

    def __init__(self, mask, **kwargs):
        """
        the mask is a binary mask at the size of the input to the element
        if 0 the value will be forced to zero

        Any additional keyword arguments are forwarded unchanged to the
        base class constructor.
        """
        super().__init__(**kwargs)
        self.mask = mask

    def lossy_call(self, inp, training=False):
        """
        Return ``inp`` multiplied element-wise by the stored mask.

        ``training`` is accepted for interface compatibility and is not used here.
        """
        return inp * self.mask

    def _is_eq(self, other) -> bool:
        """
        Return True when ``other`` holds a mask with the same shape and values.

        ``np.array_equal`` is used instead of ``np.all(a == b)`` so that masks of
        different shapes compare as "not equal" rather than raising on
        non-broadcastable shapes or comparing equal through broadcasting.
        """
        return bool(np.array_equal(self.mask, other.mask))

    def export_flow_state(self) -> LossyState:
        """
        Export the base flow state with the mask stored under the ``"mask"`` key
        of ``lossy_dict_kwgs``.
        """
        flow_state = super().export_flow_state()
        flow_state.lossy_dict_kwgs = {"mask": self.mask}
        return flow_state

    def import_flow_state(self, lossy_state: LossyState) -> None:
        """
        Restore the base state and then the mask from
        ``lossy_state.lossy_dict_kwgs["mask"]``.
        """
        super().import_flow_state(lossy_state)
        self.mask = lossy_state.lossy_dict_kwgs["mask"]
