import copy
from collections import OrderedDict

import tensorflow as tf
from tensorflow.keras.utils import deserialize_keras_object

from hailo_model_optimization.acceleras.model_optimization_config.mo_config import update_nested
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerPrecisionConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import BiasMode, PrecisionMode
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasException
from hailo_model_optimization.acceleras.utils.logger import default_logger


class BaseAccelerasLayer(tf.keras.layers.Layer):
    """Common base class for acceleras Keras layers.

    Holds the layer's HN description (``_hn_element``), exports it through :meth:`to_hn`,
    and provides class-level defaults used to build the layer's precision config.
    """

    # Number of data outputs this layer produces; output edge indices are folded onto it.
    num_outputs = 1
    # HN keys that must be present (and not None) for the layer to be exportable.
    mandatory_hn_fields = ["type", "input_shapes", "output_shapes"]

    def __init__(self, name, logger=None, **kwargs):
        """Create the layer.

        Args:
            name: Full, possibly "/"-scoped, layer name. The full string is kept in
                ``self.full_name``; only its last "/"-separated component is passed to
                Keras as the layer name. ``None`` is tolerated here (Keras then picks a
                name) and is reported as an error by :meth:`_verify_exportable` on export.
            logger: Logger to use; defaults to ``default_logger()``.
            **kwargs: Forwarded unchanged to ``tf.keras.layers.Layer``.
        """
        self._logger = logger if logger is not None else default_logger()
        self.full_name = name
        # Only the last scope component becomes the Keras name (presumably because Keras
        # does not accept "/" in layer names — confirm against the supported Keras versions).
        keras_name = None if name is None else name.split("/")[-1]
        super().__init__(name=keras_name, **kwargs)
        self._hn_element = OrderedDict()
        self._in_emulation_graph = True

    @property
    def in_emulation_graph(self):
        """Flag stored on the layer; initialized to True in ``__init__``."""
        return self._in_emulation_graph

    @property
    def hn_element(self):
        """Return the stored HN description.

        Under Keras >= 3 the stored value is passed through ``deserialize_keras_object``
        (presumably to unwrap Keras' attribute-tracking container back into a plain dict).
        If that call raises for any reason, the stored value is returned unchanged, so
        older Keras versions keep working.
        """
        try:
            # For Keras >= 3
            return deserialize_keras_object(self._hn_element)
        except Exception:
            # Deliberate best-effort fallback: use the raw stored mapping.
            return self._hn_element

    def resolve_output_index(self, output_index):
        """
        Returns the data output index of a layer based on the output edge index.
        If the layer has a single output, the index will be 0, otherwise the index will be
        the output edge index modulo ``num_outputs``.
        """
        return output_index % self.num_outputs

    def finalize_from_hn(self, hn_element):
        """Store a deep copy of ``hn_element`` as this layer's HN description.

        The deep copy keeps later changes to the caller's mapping from leaking into the layer.
        """
        self._hn_element = copy.deepcopy(hn_element)

    def to_hn(self, out_degree=None):
        """Export the layer's HN description.

        Args:
            out_degree: Not used by this base implementation; accepted so callers can pass it
                uniformly (subclasses may use it).

        Returns:
            The value returned by ``update_nested`` when merging ``self.hn_element`` into a
            fresh, empty dict.

        Raises:
            AccelerasException: If the layer is not exportable (see ``_verify_exportable``).
        """
        self._verify_exportable()
        hn = dict()
        return update_nested(hn, self.hn_element)

    def import_acceleras(self, params):
        """No-op in the base class; presumably overridden by layers that load parameters."""
        pass

    def _verify_exportable(self, hn_element=None):
        """Raise if the layer cannot be exported to HN.

        Args:
            hn_element: HN mapping to check; defaults to ``self.hn_element``.

        Raises:
            AccelerasException: If ``self.full_name`` is None, or any key in
                ``mandatory_hn_fields`` is missing from ``hn_element`` or maps to None.
        """
        if hn_element is None:
            hn_element = self.hn_element
        # Generator expression: stops at the first missing field without building a list.
        has_missing_field = any(
            field not in hn_element or hn_element[field] is None for field in self.mandatory_hn_fields
        )
        if has_missing_field or self.full_name is None:
            raise AccelerasException(f"{self.full_name} Missing layer name or {self.mandatory_hn_fields}")

    @classmethod
    def get_default_precision_mode(cls):
        """Default precision mode for this layer type (``PrecisionMode.a8_w8`` in the base class)."""
        return PrecisionMode.a8_w8

    @classmethod
    def get_default_bias_mode(cls):
        """Default bias mode for this layer type (``BiasMode.single_scale_decomposition`` here)."""
        return BiasMode.single_scale_decomposition

    @classmethod
    def get_default_quantization_groups(cls):
        """Default number of quantization groups for this layer type (1 in the base class)."""
        return 1

    @classmethod
    def get_default_precision_config(cls):
        """Build a ``LayerPrecisionConfig`` from the class's default precision/bias/group hooks.

        Subclasses that override the individual ``get_default_*`` classmethods automatically
        get a matching default config, because the hooks are looked up through ``cls``.
        """
        return LayerPrecisionConfig(
            precision_mode=cls.get_default_precision_mode(),
            bias_mode=cls.get_default_bias_mode(),
            quantization_groups=cls.get_default_quantization_groups(),
        )

    def is_jit_compile_supported(self, training=False):
        """Whether the layer supports JIT compilation; always True in the base class.

        Args:
            training: Not used by this base implementation.
        """
        return True
