import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp

# NOTE(review): not referenced within this module; presumably imported by other
# code as a shift amount related to the mean (``mu``) input — confirm meaning.
MU_SHIFT = 2


class NormFinalOp(BaseAtomicOp):
    """Final normalization stage: compute ``(x - mu) * inv_sqrt``.

    The op takes three positional inputs: ``x``, ``mu`` and ``inv_sqrt``.
    NOTE(review): from the names, ``mu`` looks like a mean and ``inv_sqrt``
    like an inverse square root (e.g. of a variance) produced upstream —
    confirm against the producing ops.

    In hardware simulation ``x`` is multiplied by ``2**mult_shift`` before
    normalizing. ``mult_shift`` starts at 0. It is set by
    ``create_hw_params`` or ``import_independent_params``.
    """

    num_inputs = 3
    num_outputs = 1

    def __init__(self, name: str, logger=None, fully_native=None, **kwargs):
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        # Power-of-two exponent applied to ``x`` in ``call_hw_sim``.
        self._mult_shift = 0

    def _compute_norm_native(self, x, mu, inv_sqrt):
        """Return ``(x - mu) * inv_sqrt`` built from TensorFlow ops."""
        return tf.multiply(tf.subtract(x, mu), inv_sqrt)

    def call_native(self, inputs, **kwargs):
        """Float path: normalize ``inputs[0]`` using ``inputs[1]`` and ``inputs[2]``."""
        x, mu, inv_sqrt = inputs[0], inputs[1], inputs[2]
        return self._compute_norm_native(x, mu, inv_sqrt)

    def call_hw_sim(self, inputs, **kwargs):
        """Hardware-sim path: scale ``inputs[0]`` by ``2**mult_shift``, then normalize."""
        gain = 2**self._mult_shift
        scaled_x = inputs[0] * gain
        return self._compute_norm_native(scaled_x, inputs[1], inputs[2])

    def is_differentiable(self) -> bool:
        """Report this op as non-differentiable."""
        return False

    def export_quant_weights(self):
        """Return an empty dict: there are no quantized weights to export."""
        return dict()

    def export_weights(self):
        """Return an empty dict: there are no weights to export."""
        return dict()

    def enforce_encoding(self, *args, **kwargs):
        """Set ``output_scale`` to (scale of input 1) * (scale of input 2).

        Any positional or keyword arguments are accepted and ignored.
        """
        combined_scale = self.input_scales[1] * self.input_scales[2]
        self.output_scale = combined_scale

    def create_weight_quant_element(self):
        """Intentionally a no-op for this op."""

    def create_hw_params(self, shift=0):
        """Record ``shift`` as the hardware multiplier exponent, then refresh the output scale."""
        self._mult_shift = shift
        self.enforce_encoding()

    def export_independent_params(self):
        """Export the multiplier exponent under the ``"mult_shift"`` key as ``np.float32``."""
        exported_shift = np.float32(self._mult_shift)
        return {"mult_shift": exported_shift}

    def import_independent_params(self, params):
        """Restore the multiplier exponent from ``params["mult_shift"]``.

        NOTE(review): the value is stored as ``np.float32`` here, whereas
        ``create_hw_params`` stores ``shift`` unchanged (an ``int`` by
        default). Confirm downstream consumers accept both types.
        """
        raw_shift = params["mult_shift"]
        self._mult_shift = np.float32(raw_shift)
