"""
This module defines the basic building blocks for simulating everything LOSSY (i.e., bit-reducing),
encompassing the following OPS:
  - Rounding & Clipping (aka quantization, aka grid-pinning)
  - TODO Pruning, Vector-Quantization, etc.

and the following FEATURES:
- Differentiability solutions (currently, STE - straight-through estimator)
- Parameterized op construction (e.g. variable bits quantizer) via config

"""

import uuid
from abc import ABC, abstractmethod

import tensorflow as tf

from hailo_model_optimization.acceleras.utils.flow_state_utils import LossyState


class BaseLossyElement(tf.keras.layers.Layer, ABC):
    """
    Abstract base class for a lossy (bit-reducing) element, exposed as a Keras layer.

    A freshly built element is lossless: calling it returns ``tf.identity`` of its input.
    After ``enable()`` (or after importing a flow state whose ``is_lossless`` is False),
    calls are forwarded to the subclass-provided ``lossy_call``.

    Subclasses must implement ``lossy_call`` and ``_is_eq``.
    """

    def __init__(self, name=None):
        """
        Build the element in its lossless state.

        Args:
            name: full name of the element, possibly containing "/" separators. The
                complete string is stored in ``full_name`` (which ``__hash__`` uses);
                only the part after the last "/" is handed to Keras as the layer name.
                When omitted, a random ``"qe:<uuid4>"`` name is generated.
        """
        # Lossless until enable() / import_flow_state() switches it.
        self.is_lossless = True
        # A unique name keeps the element hashable for Keras (see __hash__).
        # TODO: use the op name + location (in / out / weight + index)
        if name is None:
            name = f"qe:{uuid.uuid4()!s}"
        # NOTE(review): is_lossless / full_name are assigned before Layer.__init__ runs;
        # presumably deliberate (hashing relies on full_name) - confirm before reordering.
        self.full_name = name
        super().__init__(name=name.rsplit("/", 1)[-1])

    def __call__(self, inp, training=False):
        """
        Apply the element to ``inp``.

        This overrides Keras' ``Layer.__call__`` and does not delegate to it.
        A lossy element forwards to ``lossy_call(inp, training)``; a lossless one
        returns ``tf.identity(inp)``.
        """
        if not self.is_lossless:
            return self.lossy_call(inp, training)
        return tf.identity(inp)

    @abstractmethod
    def lossy_call(self, inp, training=False):
        """Compute the lossy output for ``inp``; subclasses must implement this."""

    def enable(self):
        """Switch the element to lossy mode (calls go through ``lossy_call``)."""
        self.is_lossless = False

    def disable(self):
        """Switch the element to lossless mode (calls return ``tf.identity`` of the input)."""
        self.is_lossless = True

    def get_state(self):
        """Return True when the element is lossless, False when it is lossy."""
        return self.is_lossless

    def __eq__(self, other):
        """
        Delegate to ``_is_eq`` when ``other`` is an instance of this element's class.

        Any other object compares unequal.
        NOTE(review): the check is one-sided, so ``base == sub`` may be True while
        ``sub == base`` is False - confirm this is intended.
        """
        return isinstance(other, type(self)) and self._is_eq(other)

    @abstractmethod
    def _is_eq(self, other):
        """Subclass hook for ``__eq__``; ``other`` is an instance of ``type(self)``."""

    def __str__(self):
        """Report whether the element is currently lossy."""
        enabled = not self.is_lossless
        return f"enabled={enabled}"

    def __hash__(self):
        # NOTE(review): hashing uses full_name while __eq__ uses _is_eq; two elements
        # that _is_eq reports as equal but that have different full_names would break
        # the hash/eq contract - confirm subclasses never do that.
        return hash(self.full_name)

    def export_flow_state(self) -> LossyState:
        """
        Snapshot the element's flow parameters.

        Returns:
            A ``LossyState`` holding ``full_name``, the concrete class name and the
            current ``is_lossless`` flag.
        """
        return LossyState(
            full_name=self.full_name,
            lossy_class_type=type(self).__name__,
            is_lossless=self.is_lossless,
        )

    def import_flow_state(self, lossy_state: LossyState) -> None:
        """
        Restore the lossless/lossy flag from ``lossy_state``.

        Only ``is_lossless`` is read. ``full_name`` and ``lossy_class_type`` are not
        checked against this element.
        """
        self.is_lossless = lossy_state.is_lossless
