import tensorflow as tf

from hailo_model_optimization.acceleras.lossy_elements.base_lossy_element import BaseLossyElement
from hailo_model_optimization.acceleras.utils.flow_state_utils import LossyState


class ClipElement(BaseLossyElement):
    """
    Lossy element that clips its input to the range ``[min_vals, max_vals]``.

    This is a workaround (for now) to inject clipping into the "native" graph.
    The element is enabled (lossy) as soon as it is constructed.
    """

    def __init__(self, min_vals, max_vals, **kwargs):
        """
        Args:
            min_vals: Lower clipping bound. Currently a scalar; in the future it
                may become a per-channel vector.
            max_vals: Upper clipping bound. Currently a scalar; in the future it
                may become a per-channel vector.
            **kwargs: Forwarded unchanged to ``BaseLossyElement.__init__``.
        """
        super().__init__(**kwargs)
        self.min_vals = min_vals
        self.max_vals = max_vals
        self.enable()  # the element is set lossy by default

    def lossy_call(self, inp, training=False):
        """Return ``inp`` with every value clipped to ``[min_vals, max_vals]``.

        ``training`` is accepted but not used by this element.
        """
        return tf.clip_by_value(inp, self.min_vals, self.max_vals)

    def _is_eq(self, other):
        """Return True when ``other`` has the same clipping bounds as this element."""
        # NOTE(review): plain ``==`` is correct for scalar bounds; if the bounds
        # become per-channel arrays this will need an element-wise comparison.
        return (self.min_vals == other.min_vals) and (self.max_vals == other.max_vals)

    def export_flow_state(self) -> LossyState:
        """Export the base flow state, adding the clipping bounds to ``lossy_dict_kwgs``."""
        flow_state = super().export_flow_state()
        flow_state.lossy_dict_kwgs = {"min_vals": self.min_vals, "max_vals": self.max_vals}
        return flow_state

    def import_flow_state(self, lossy_state: LossyState) -> None:
        """Restore the state written by ``export_flow_state``, including the clipping bounds.

        Args:
            lossy_state: State whose ``lossy_dict_kwgs`` holds ``"min_vals"``
                and ``"max_vals"``.
        """
        # Forward the state to the parent implementation (presumably it has the
        # same signature as this override); calling it with no argument would
        # not hand it the state to import.
        super().import_flow_state(lossy_state)
        self.min_vals = lossy_state.lossy_dict_kwgs["min_vals"]
        self.max_vals = lossy_state.lossy_dict_kwgs["max_vals"]
