import tensorflow as tf

from hailo_model_optimization.acceleras.encoding.encoding_data import EncodingType
from hailo_model_optimization.acceleras.encoding.encoding_flow import EncodingInferenceFlowGraph, InferenceNodeType
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasEncodingError


class PosConstraint(tf.keras.constraints.Constraint):
    """Keras weight constraint that clips every element from below at ``epsilon``.

    With a positive ``epsilon`` this keeps the constrained weight positive.

    Args:
        epsilon: Element-wise lower bound applied to the weight tensor.
    """

    def __init__(self, epsilon=1e-9):
        self.epsilon = epsilon

    def __call__(self, w):
        # NOTE(review): 1e-9 is below float16's smallest subnormal, so for float16
        # weights the bound would round to 0 — confirm weights are float32/bfloat16.
        lower_bound = self.epsilon
        return tf.maximum(w, lower_bound)

    def get_config(self):
        """Return the constructor arguments so the constraint can be re-created."""
        config = dict(epsilon=self.epsilon)
        return config


class TensorInitializer(tf.keras.initializers.Initializer):
    """Initializer that initializes a weight from a given tensor.

    The stored tensor is adapted to the requested ``shape``:

    * same shape: the tensor is used as-is;
    * scalar requested: all elements must be equal up to the relative tolerance
      ``eps``; the first element is used;
    * scalar tensor, non-scalar shape requested: the value is broadcast to ``shape``;
    * otherwise: the tensor is reshaped to ``shape``.

    Args:
        tensor: Value to initialize from. Objects without a ``shape`` attribute are
            treated as scalars.
        eps: Maximum allowed relative difference between elements when the tensor
            has to be collapsed to a scalar.
    """

    def __init__(self, tensor, eps=0.0):
        self.tensor = tensor
        self.eps = eps

    def __call__(self, shape, dtype=None):
        """Return ``self.tensor`` adapted to ``shape`` (and cast to ``dtype`` if given).

        Raises:
            AccelerasEncodingError: if a scalar is requested but the tensor holds more
                than one distinct value (relative difference above ``eps``).
        """
        # Normalize the request so list / tuple / TensorShape compare consistently
        # (in Python ``[] == ()`` is False, which would skip the scalar branch).
        target_shape = tuple(shape)
        orig_shape = self.tensor.shape if hasattr(self.tensor, "shape") else ()
        if orig_shape == target_shape:
            tensor = self.tensor
        elif target_shape == ():
            tensor = self._collapse_to_scalar()
        elif orig_shape == ():
            tensor = tf.ones(shape) * self.tensor
        else:
            tensor = tf.reshape(self.tensor, shape)
        if dtype is not None:
            tensor = tf.cast(tensor, dtype)
        return tensor

    def _collapse_to_scalar(self):
        """Return the first element of ``self.tensor`` after checking all elements agree."""
        # Flatten so tensors of any rank are compared element-wise against one scalar.
        values = tf.reshape(self.tensor, [-1])
        first = values[0]
        values_f = values if values.dtype.is_floating else tf.cast(values, tf.float64)
        abs_diff = tf.abs(values_f - values_f[0])
        # Relative difference against each element's magnitude. Using abs() keeps
        # negative values (e.g. zero points) from yielding negative ratios that would
        # hide real differences. Elements equal to the first one count as 0 even when
        # they are 0, which avoids a 0/0 = NaN that would also slip past the check.
        rel_diff = tf.where(
            tf.equal(abs_diff, 0),
            tf.zeros_like(abs_diff),
            abs_diff / tf.abs(values_f),
        )
        max_dif = tf.reduce_max(rel_diff)
        if self.eps < max_dif:
            raise AccelerasEncodingError(
                f"TensorInitializer received tensor with more then one unique elements, "
                f"but expected to initialize a scalar (Max elements diff {max_dif}).",
            )
        return first

    def get_config(self):
        """Return the constructor arguments (``tensor`` and ``eps``)."""
        return {"tensor": self.tensor, "eps": self.eps}


class HailoModelEncoding(tf.keras.layers.Layer):
    """Keras layer that evaluates an encoding inference flow graph.

    Every independent node of the flow becomes a trainable weight of this layer
    (created in ``build``). Every other node is computed in ``call`` by applying the
    node's function to the values of its sorted inputs, in the flow's topological
    order.

    Args:
        name: Full, possibly ``/``-separated, name. Only the last ``/`` component is
            used as the Keras layer name; the full name is kept in ``self.full_name``.
        encoding_inference_flow: Graph describing the encodings and their dependencies.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    # Node function strings with these prefixes identify functions that accept a
    # ``training`` keyword argument.
    _TRAINING_AWARE_FUNC_PREFIXES = ("lossy(", "get_pieces_encoding(", "output_zp_callback(")

    def __init__(self, name, encoding_inference_flow: EncodingInferenceFlowGraph, **kwargs):
        self.full_name = name
        keras_name = name.split("/")[-1]
        super().__init__(name=keras_name, **kwargs)
        self.flow = encoding_inference_flow

        # Weights created in ``build``, keyed by independent node name.
        self.independent_vars = dict()
        # Value of every node (independent and computed), filled in by ``call``.
        self.dependant_tensors = dict()

    @staticmethod
    def _get_default_initializer(encoding_type: EncodingType):
        """Return the default initializer for ``encoding_type``.

        Scale returns RandomNormal(mean=1.0, stddev=0.05). ZeroPoint returns
        RandomNormal(mean=0.0, stddev=0.05). Any other type returns None.
        """
        if encoding_type == EncodingType.Scale:
            return tf.keras.initializers.RandomNormal(mean=1.0, stddev=0.05)
        if encoding_type == EncodingType.ZeroPoint:
            return tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.05)

    @staticmethod
    def _get_default_constraint(encoding_type: EncodingType):
        """Return the default constraint for ``encoding_type``.

        Scale returns ``PosConstraint()``. ZeroPoint and any other type return None.
        """
        if encoding_type == EncodingType.Scale:
            return PosConstraint()
        if encoding_type == EncodingType.ZeroPoint:
            return None

    def build(self, input_shape):
        """Create one trainable weight per independent node of the flow.

        The node's own initializer and constraint are used when set. Otherwise the
        encoding-type defaults are used. ``input_shape`` is only recorded; it does not
        affect the weight shapes, which come from each encoding.
        """
        for name in self.flow.independent_nodes:
            encoding = self.flow.get_encoding(name)
            initializer = encoding.initializer
            if initializer is None:
                initializer = self._get_default_initializer(encoding.encoding_type)
            constraint = encoding.constraint
            if constraint is None:
                constraint = self._get_default_constraint(encoding.encoding_type)
            self.independent_vars[name] = self.add_weight(
                name=f"{name}",
                shape=encoding.shape,
                initializer=initializer,
                constraint=constraint,
                regularizer=encoding.regularizer,
                trainable=True,
            )
        self._build_input_shape = input_shape
        self.built = True

    def call(self, inputs, return_hidden=False, training=False):
        """Evaluate every node of the flow and return their values.

        Args:
            inputs: Not used by the computation; the values derive only from the
                layer's weights.
            return_hidden: When True, nodes flagged ``hidden`` in the flow are included.
            training: Forwarded as ``training=`` to node functions whose string form
                starts with one of ``_TRAINING_AWARE_FUNC_PREFIXES``.

        Returns:
            A new dict mapping node name to value for the nodes evaluated in this call,
            in topological order.
        """
        order = list(self.flow.toposort())
        for name in order:
            if self.flow.get_node_type(name) == InferenceNodeType.INDEPENDENT:
                self.dependant_tensors[name] = self.independent_vars[name]
                continue
            func = self.flow.get_func(name)
            args = [self.dependant_tensors[inp] for inp in self.flow.inputs_sorted(name)]
            kwargs = {}
            if self.flow.get_func_string(name).startswith(self._TRAINING_AWARE_FUNC_PREFIXES):
                kwargs["training"] = training
            self.dependant_tensors[name] = func(*args, **kwargs)
        # Build the result from the nodes evaluated in this call, as a fresh dict.
        # Returning the persistent attribute would let later calls overwrite a
        # previously returned result. It would also leak entries for nodes no longer
        # in the flow, which would make the ``hidden`` lookup raise KeyError.
        if return_hidden:
            return {name: self.dependant_tensors[name] for name in order}
        return {
            name: self.dependant_tensors[name]
            for name in order
            if not self.flow.nodes[name]["hidden"]
        }
