from dataclasses import dataclass

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp, BaseWeightLossyElements
from hailo_model_optimization.acceleras.atomic_ops.inverse_lut_calc import InverseFuncLUTCalculator
from hailo_model_optimization.acceleras.lossy_elements.base_lossy_element import BaseLossyElement
from hailo_model_optimization.acceleras.lossy_elements.identity_element import IdentityElement
from hailo_model_optimization.acceleras.lossy_elements.quant_element import QuantElement
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ACTIVATION_CLIP_BITS_HAILO_LAYER_NORM,
    EmulationType,
)
from hailo_model_optimization.acceleras.utils.flow_state_utils import AtomicOpState
from hailo_model_optimization.acceleras.utils.hn_npz_utils import LayerParams

# Number of fractional bits of the fixed-point multiplier built by
# FinalizeNormOp.get_div_cfg (mlt_cfg = round(mantissa * 2**MANTISA_BITS)).
MANTISA_BITS = 16


@dataclass
class NormWeightsLossy(BaseWeightLossyElements):
    """Lossy elements applied to intermediate values inside FinalizeNormOp."""

    # Applied to the computed mean; its ``bits`` also bound the mean in the bit-exact path.
    mu_clip: BaseLossyElement
    # Applied to the normalized value right before it is returned as the op output.
    clip_before_apu: BaseLossyElement


# NOTE(review): not referenced anywhere in this file chunk; confirm whether other
# modules import it before removing.
DEBUG = True


class FinalizeNormOp(BaseAtomicOp):
    """
    Emulate the final stage of a normalization op (layer-norm or RMS-norm).

    Inputs, in order:
        0. ``x`` - the tensor normalized along its last axis.
        1. ``sum_x`` - sum of ``x`` over that axis (produced upstream).
        2. ``sum_x2_n`` - second-moment term produced upstream. ``call_native``
           computes ``var = (sum_x2_n - sum_x**2) / f_out**2``, so this is
           presumably ``f_out * sum(x**2)`` -- confirm against the producing op.

    Output: ``(x - mean) / sqrt(var + epsilon)``.

    Execution modes:
        * ``call_native`` - plain floating-point math.
        * ``call_hw_sim`` - emulation of the quantized pipeline
          (``layer_norm_logic``); while ``is_lossless`` is True it instead runs
          decode -> ``call_native`` -> encode.
        * ``call_bit_exact`` - integer pipeline with bit-width checks.
    """

    num_inputs = 3
    num_outputs = 1

    def __init__(self, name: str, logger=None, fully_native=None, **kwargs):
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)

        # Length of the normalized (last) axis; set in _compute_output_shape.
        self.f_out = None
        self.is_lossless = True
        # Even power-of-two output gain: subtracted from the inv-sqrt exponent and
        # divided out of the output scale. Chosen in create_hw_params.
        self._shift_to_plus = 0
        # Multiplier/shift pair turning sum_x into the mean:
        # mean = sum_x * _mlt_cfg / 2**_shift_cfg.
        self._mlt_cfg = None
        self._shift_cfg = 0
        self.lut_inv_sqrt = InverseFuncLUTCalculator(sqrt=True)

        # Pass-through elements until create_weight_quant_element installs quantizers.
        self.weight_lossy_elements = NormWeightsLossy(
            mu_clip=IdentityElement(name=f"{self.full_name}/ie:mu_clip"),
            clip_before_apu=IdentityElement(name=f"{self.full_name}/ie:clip_before_apu"),
        )

        self.set_type_emulation(EmulationType.DOUBLE)

    def _build(self, input_shape):
        # Until create_hw_params installs an integer configuration, use the exact
        # float multiplier 1 / f_out (with _shift_cfg == 0 this yields the mean).
        if self._mlt_cfg is None:
            self._mlt_cfg = 1.0 / input_shape[0][-1]

    def _compute_output_shape(self, input_shape):
        """Output has the shape of ``x``; also records the normalized axis length."""
        self.f_out = input_shape[0][-1]
        return input_shape[0]

    def disable_lossy(self, **kwargs):
        super().disable_lossy(**kwargs)
        self.is_lossless = True

    def enable_lossy(self, **kwargs):
        super().enable_lossy(**kwargs)
        self.is_lossless = False

    def get_div_cfg(self, factor):
        """
        Encode ``factor`` as a fixed-point pair ``mlt_cfg / 2**shift_cfg``.

        Args:
            factor: positive scalar to encode.

        Returns:
            ``(mlt_cfg, shift_cfg)``. ``mlt_cfg`` is an integer strictly below
            ``2**MANTISA_BITS`` (it is exported as uint16), and ``shift_cfg`` is
            capped at 63.
        """
        bits = MANTISA_BITS
        mantissa, exp = np.frexp(factor)  # factor == mantissa * 2**exp, |mantissa| in [0.5, 1)
        mlt_cfg = int(np.round(mantissa * 2**bits))
        # A mantissa just below 1.0 rounds up to exactly 2**bits, which does not fit
        # in `bits` bits. Renormalize to the equal value 2**(bits-1) * 2**(exp+1).
        if abs(mlt_cfg) == 2**bits:
            mlt_cfg //= 2
            exp += 1
        shift_cfg = min(63, bits - exp)
        return mlt_cfg, shift_cfg

    def _get_scale(self, rmax=2, rmin=0.5, bins=255):
        """Return ``(zero_point, scale)`` mapping ``[rmin, rmax]`` onto ``bins`` steps."""
        scale = (rmax - rmin) / bins
        zp = -rmin / scale
        return zp, scale

    @property
    def bit_exact_supported(self) -> bool:
        return True

    @property
    def factor(self):
        return 1 / self.f_out

    def _calc_shift_to_plus(self):
        """
        Pick an even output shift from the calibrated output range.

        Compares the largest value representable by the output lossy element
        (converted to float units with the pre-shift output scale) against the
        largest observed absolute output, and returns
        ``floor(log2(ratio) / 2 - 1) * 2``.
        """
        output_scale = self.scale_x * self.scale_out_inv_sqrt
        max_value = np.maximum(
            np.max(np.abs(self.get_output_stats(0).max)),
            np.max(np.abs(self.get_output_stats(0).min)),
        )
        if max_value == 0:
            # All-zero calibration output: any shift preserves it. Avoid dividing
            # by zero, which would give an infinite shift and break the encoding.
            return 0.0
        current_max = self.output_lossy_element.max_value * output_scale
        return np.floor(np.log2(current_max / max_value) / 2 - 1) * 2

    def create_hw_params(self, rms_norm=False, **kwargs):
        """
        Derive the integer configuration and switch the op to lossy mode.

        Args:
            rms_norm: when True the mean multiplier/shift are zeroed, so
                ``_compute_mu`` produces a zero mean before clipping.
        """
        if rms_norm:
            mlt, shift = 0, 0
        else:
            # Mean multiplier expressed in the scale of x: sum_x * s_sum / (s_x * f_out).
            ratio = self.scale_x_sum / self.scale_x * (1.0 / self.f_out)
            mlt, shift = self.get_div_cfg(ratio)
        self._mlt_cfg = mlt
        self._shift_cfg = shift
        self._shift_to_plus = self._calc_shift_to_plus()
        self.enforce_encoding()
        self.is_lossless = False

    @property
    def _qunat_epsilon(self):
        """Epsilon in the integer domain of the variance, clipped to [1, 2**32 - 1] (int64)."""
        epsil = self.epsilon * (self.f_out**2) / self.scale_x2_sum
        epsil = tf.cast(epsil, self.FLOAT_TYPE_TF)
        epsil_1_clipped = tf.clip_by_value(tf.round(epsil), 1, 2**32 - 1)
        val = tf.cast(epsil_1_clipped, tf.int64)
        return val

    def _record_intermediates(self, mu, variance, inv_sqrt_mantisa, inv_sqrt_exponent):
        """Keep the last intermediate tensors on the op (presumably for inspection/debugging)."""
        self.mu = mu
        self._variance = variance
        self.inv_sqrt_mantisa = inv_sqrt_mantisa
        self.inv_sqrt_exponent = inv_sqrt_exponent

    def layer_norm_logic(self, inputs):
        """Quantized-pipeline emulation used by ``call_hw_sim`` in lossy mode."""
        inputs_zp = inputs[0]
        sum_x = inputs[1]
        sum_x2_n = inputs[2]
        mu = self._compute_mu(sum_x)
        variance = self._compute_variance(sum_x2_n, sum_x, self._qunat_epsilon)
        inv_sqrt_mantisa, inv_sqrt_exponent = self._get_mantissa_exp_vals(variance, bit_exact=False)
        inv_sqrt_exponent = inv_sqrt_exponent - self._shift_to_plus
        norm_value = self._compute_norm(inputs_zp, mu, inv_sqrt_mantisa, inv_sqrt_exponent)

        norm_value = self.weight_lossy_elements.clip_before_apu(norm_value)

        self._record_intermediates(mu, variance, inv_sqrt_mantisa, inv_sqrt_exponent)
        return norm_value

    def _compute_mu_bit_exact(self, sum_x):
        """Integer mean: (sum_x * _mlt_cfg) >> _shift_cfg (banker's rounding), clipped to mu_clip bits."""
        self._verify_data_dtype(sum_x, 20, True, "sum_x")
        self._verify_data_dtype(self._mlt_cfg, 16, False, "_mlt_cfg")

        mean_x_tag = tf.multiply(sum_x, self._mlt_cfg)
        self._verify_data_dtype(mean_x_tag, 36, True, "mean_x_tag")

        mean_x = self.signed_shift_bankers_rounding(mean_x_tag, self._shift_cfg)

        bits = self.weight_lossy_elements.mu_clip.bits - 1
        mean_x_clipped = tf.clip_by_value(mean_x, -(2**bits), 2**bits - 1)
        self._verify_data_dtype(mean_x_clipped, 16, True, "mean_x_clipped")

        return mean_x_clipped

    def _compute_mu(self, sum_x):
        """Emulated mean: sum_x * _mlt_cfg / 2**_shift_cfg, passed through mu_clip."""
        mean_x_tag = tf.multiply(sum_x, self._mlt_cfg)
        mean_x = mean_x_tag / (2**self._shift_cfg)
        mean_x_clipped = self.weight_lossy_elements.mu_clip(mean_x)
        return mean_x_clipped

    def _compute_variance(self, sum_x2_n, sum_x, epsilon_quant):
        """Emulated (scaled) variance: sum_x2_n - sum_x**2 + epsilon_quant."""
        epsilon_quant = tf.cast(epsilon_quant, self.FLOAT_TYPE_TF)
        sum_x_sqaure = tf.square(sum_x)
        var = sum_x2_n - sum_x_sqaure + epsilon_quant
        return var

    def _compute_variance_bit_exact(self, sum_x2_n, sum_x, epsilon_quant):
        """Integer (scaled) variance with bit-width checks on every intermediate."""
        self._verify_data_dtype(epsilon_quant, 32, False, "epsilon_quant")
        sum_x_sqaure = tf.square(sum_x)
        self._verify_data_dtype(sum_x_sqaure, 40, False, "sum_x_sqaure")
        var1 = tf.subtract(sum_x2_n, sum_x_sqaure)
        self._verify_data_dtype(var1, 56, False, "var1")
        var = var1 + epsilon_quant
        self._verify_data_dtype(var, 56, False, "var")
        return var

    def _compute_norm_bit_exact(self, inputs, mu, inv_sqrt_mantisa, inv_sqrt_exponent):
        """floor((x - mu) * mantissa / 2**exponent), with bit-width checks."""
        self._verify_data_dtype(inputs, 16, True, "inputs")
        self._verify_data_dtype(mu, 16, True, "mu")
        self._verify_data_dtype(inv_sqrt_mantisa, 16, False, "inv_sqrt_mantisa")
        self._verify_data_dtype(inv_sqrt_exponent, 4, True, "inv_sqrt_exponent")

        diff = tf.subtract(inputs, mu)
        self._verify_data_dtype(diff, 16, True, "diff")

        diff = tf.cast(diff, self.FLOAT_TYPE_TF)
        mul = tf.multiply(diff, inv_sqrt_mantisa)
        self._verify_data_dtype(mul, 32, True, "mul")

        # floor (not round, as in _compute_norm) -- presumably mirrors an arithmetic
        # right shift in hardware.
        mul_after_shift = tf.floor(mul / (2.0**inv_sqrt_exponent))
        mul_after_shift = tf.cast(mul_after_shift, self.INT_TYPE_TF)

        self._verify_data_dtype(mul_after_shift, 34, True, "mul_after_shift")
        return mul_after_shift

    def _compute_norm(self, inputs, mu, inv_sqrt_mantisa, inv_sqrt_exponent):
        """Emulated normalization: round((x - mu) * mantissa / 2**exponent)."""
        diff = tf.subtract(inputs, mu)
        mul = tf.multiply(diff, inv_sqrt_mantisa)
        mul_after_shift = tf.round(mul / (2**inv_sqrt_exponent))
        return mul_after_shift

    ## all lut functions
    @property
    def lut_inv_sqrt_scale_out(self):
        return self.lut_inv_sqrt.lut_scale_out

    def _get_mantissa_exp_vals(self, sum_e_x, bit_exact=False):
        """Look up the inverse-sqrt of ``sum_e_x`` as a (mantissa, exponent) pair."""
        # NOTE(review): the ``bit_exact`` argument is ignored and False is always
        # forwarded. _compute_norm_bit_exact multiplies a float-cast tensor by the
        # mantissa and evaluates ``2.0 ** exponent``, which presumably relies on
        # the float-typed outputs of the non-bit-exact LUT path -- confirm before
        # wiring the flag through.
        return self.lut_inv_sqrt.lut_vals_helper(
            sum_e_x,
            float_type=self.FLOAT_TYPE_TF,
            int_type=self.INT_TYPE_TF,
            bit_exact=False,
            is_lossless=self.is_lossless,
        )

    def export_quant_weights(self):
        return {"epsilon_quant": self._qunat_epsilon.numpy()}

    def call_native(self, inputs, **kwargs):
        """Floating-point reference: (x - mean) / sqrt(var + epsilon)."""
        inputs_x = inputs[0]
        sum_x = inputs[1]
        sum_x2_n = inputs[2]

        sum_x_2 = tf.square(sum_x)
        f_out = tf.cast(self.f_out, self.FLOAT_TYPE_TF)

        mu = sum_x / f_out
        # Clamp at 0 so float cancellation can never yield a negative variance.
        numerator = tf.maximum(sum_x2_n - sum_x_2, 0)
        var = numerator / (f_out**2)
        inv_sqrt = 1.0 / (tf.sqrt(var + self.epsilon))
        diff = inputs_x - mu

        res1 = tf.multiply(diff, inv_sqrt)
        return res1

    def call_hw_sim(self, inputs, **kwargs):
        if self.is_lossless:
            inputs_native = self._decode_inputs(inputs)
            result = self.call_native(inputs_native, **kwargs)
            result = result if isinstance(result, list) else [result]
            res = self._encode_outputs(result)
            return res

        return self.layer_norm_logic(inputs)

    def call_bit_exact(self, inputs, **kwargs):
        """Integer pipeline mirroring layer_norm_logic, with bit-width checks."""
        inputs_zp = inputs[0]
        sum_x = inputs[1]
        sum_x2_n = inputs[2]

        mu = self._compute_mu_bit_exact(sum_x)
        variance = self._compute_variance_bit_exact(sum_x2_n, sum_x, self._qunat_epsilon)
        inv_sqrt_mantisa, inv_sqrt_exponent = self._get_mantissa_exp_vals(variance, bit_exact=True)
        inv_sqrt_exponent = inv_sqrt_exponent - self._shift_to_plus
        norm_value = self._compute_norm_bit_exact(inputs_zp, mu, inv_sqrt_mantisa, inv_sqrt_exponent)
        norm_value = self.hw_simulation_by_lossy_element(norm_value, self.weight_lossy_elements.clip_before_apu)

        self._record_intermediates(mu, variance, inv_sqrt_mantisa, inv_sqrt_exponent)
        return norm_value

    def is_differentiable(self) -> bool:
        return False

    def create_weight_quant_element(self, mu_bits):
        """Replace the identity elements with signed, saturating quantizers."""
        self.weight_lossy_elements = NormWeightsLossy(
            mu_clip=QuantElement(signed=True, bits=mu_bits, wraparound=False, name=f"{self.full_name}/qe:mu_clip"),
            clip_before_apu=QuantElement(
                signed=True,
                bits=ACTIVATION_CLIP_BITS_HAILO_LAYER_NORM,
                wraparound=False,
                name=f"{self.full_name}/qe:clip_before_apu",
            ),
        )

    def export_independent_params(self):
        return {
            "mlt_cfg": np.float32(self._mlt_cfg),
            "shift_cfg": np.float32(self._shift_cfg),
            "shift_to_plus": np.float32(self._shift_to_plus),
        }

    def export_hw_params(self):
        lut_table = self.lut_inv_sqrt.get_hw_lut()

        return {
            "lut_table": lut_table,
            "mult_mu": np.uint16(self._mlt_cfg),
            "shift_mu": np.int8(self._shift_cfg),
            # NOTE(review): _qunat_epsilon is clipped to [1, 2**32 - 1] and verified as
            # 32-bit in _compute_variance_bit_exact, but is exported as uint8 here, so
            # values above 255 would wrap -- confirm the hardware field width.
            "epsilon_quant": np.uint8(self._qunat_epsilon),
            "shift_to_plus": np.int8(self._shift_to_plus),
        }

    def import_independent_params(self, params):
        self._mlt_cfg = params["mlt_cfg"]
        self._shift_cfg = params["shift_cfg"]
        self._shift_to_plus = params["shift_to_plus"]

    def import_weights(self, layer_params: LayerParams, **kwargs):
        epsilon = layer_params.get("epsilon", 1e-06)
        self.epsilon = epsilon

    def export_weights(self):
        return {"epsilon": self.epsilon}

    def _scalar_input_scale(self, index):
        """Return ``input_scales[index]``, taking its first element when it is not a scalar."""
        scale = self.input_scales[index]
        if np.array(scale).shape != ():
            scale = scale[0]
        return scale

    @property
    def scale_x(self):
        return self._scalar_input_scale(0)

    @property
    def scale_x_sum(self):
        return self._scalar_input_scale(1)

    @property
    def scale_x2_sum(self):
        return self._scalar_input_scale(2)

    @property
    def scale_out_inv_sqrt(self):
        """Scale of the inverse-sqrt LUT output, multiplied by f_out.

        The f_out factor presumably compensates for the f_out**2 carried by the
        integer variance (see call_native) -- confirm.
        """
        scale_to_calc = tf.cast(self.lut_inv_sqrt.function(self.scale_x2_sum), self.FLOAT_TYPE_TF)
        lut_inv_sqrt_scale_out = tf.cast(self.lut_inv_sqrt.lut_scale_out, self.FLOAT_TYPE_TF)
        f_out = tf.cast(self.f_out, self.FLOAT_TYPE_TF)
        return lut_inv_sqrt_scale_out * scale_to_calc * f_out

    def enforce_encoding(self, *args, **kwargs):
        """Set a per-channel output scale (zero point 0) including the 2**shift_to_plus gain."""
        output_channels = self.output_shape[-1]
        self.output_scale = np.repeat(
            self.scale_x * self.scale_out_inv_sqrt / (2**self._shift_to_plus), output_channels
        )
        self.output_zero_point = np.array(0, np.float32)

    def import_flow_state(self, atomic_state: AtomicOpState):
        super().import_flow_state(atomic_state)
        self.is_lossless = atomic_state.aops_dict_kwgs["is_lossless"]

    def export_flow_state(self) -> AtomicOpState:
        aops_state = super().export_flow_state()
        aops_state.aops_dict_kwgs["is_lossless"] = self.is_lossless
        return aops_state
