from dataclasses import dataclass

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp, BaseWeightLossyElements
from hailo_model_optimization.acceleras.encoding.encoding_data import EncodingType
from hailo_model_optimization.acceleras.encoding.encoding_layer import TensorInitializer
from hailo_model_optimization.acceleras.lossy_elements.base_lossy_element import BaseLossyElement
from hailo_model_optimization.acceleras.lossy_elements.identity_element import IdentityElement
from hailo_model_optimization.acceleras.lossy_elements.quant_element import BaseQuantElement, MACDataQuantElement
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import (
    AccelerasImportParamConfigMismatch,
    AccelerasPrematureQuantOperation,
)


@dataclass
class MockConvOpWeightsLossy(BaseWeightLossyElements):
    """Lossy elements applied to the weights of ``MockConvOp``."""

    # Lossy element applied to the (scaled) kernel before it is used in the op.
    kernel: BaseLossyElement

class MockConvOp(BaseAtomicOp):
    """
    This class emulates the mock kernel operation: the single input is multiplied by ``kernel``.

    The quantized flavours multiply by the quantized kernel (``kernel / kernel_scale`` passed
    through the kernel lossy element) and divide the result by ``2 ** pre_acc_shift``.
    """

    weight_lossy_elements: MockConvOpWeightsLossy
    num_inputs = 1
    num_outputs = 1

    # Weight bit-widths for which the default pre-accumulator shift is zero:
    # 16 bit quantization doesn't support activation shift.
    _NO_SHIFT_WEIGHT_BITS = (15, 16)

    def __init__(self, name: str, logger=None, fully_native=None, **kwargs):
        """Create the op with an identity kernel lossy element and neutral quantization params."""
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        self.weight_lossy_elements = MockConvOpWeightsLossy(
            kernel=IdentityElement(name=f"{self.full_name}/ie:mock_conv_op")
        )
        self.kernel_size = (1, 1)
        self.kernel = 1.0
        self.kernel_scale = 1.0
        self.pre_acc_shift = 0
        self.kernel_zero_point = 0
        self.shift_delta = 0

    @property
    def kernel_shape(self):
        """Return ``kernel_size + (1, input_channels)``, with channels taken from the input shape."""
        input_channels = int(self.input_shape[-1])
        return self.kernel_size + (1, input_channels)

    def get_quant_kernel(self, training=False):
        """Return ``kernel / kernel_scale`` as float32, passed through the kernel lossy element."""
        quantized_value = tf.cast(self.kernel / self.kernel_scale, tf.float32)
        return self.weight_lossy_elements.kernel(quantized_value, training=training)

    def create_weight_quant_element(self, kernel_bits=8, signed=True):
        """Replace the kernel lossy element with a ``MACDataQuantElement`` of the given width."""
        self.weight_lossy_elements = MockConvOpWeightsLossy(
            kernel=MACDataQuantElement(bits=kernel_bits, signed=signed, name=f"{self.full_name}/qe:kernel"),
        )

    def _apply_scales(self, kernel_value, factor, pre_acc_shift, default_shift_as_tensor):
        """
        Shared implementation of ``create_hw_params`` and ``create_scales``.

        Args:
            kernel_value: desired quantized kernel value; the scale becomes ``|kernel| / kernel_value``.
            factor: alternative to ``kernel_value``; the scale becomes ``factor * 2 ** -pre_acc_shift``.
            pre_acc_shift: explicit shift. When ``None`` the default is 0 for 15/16 weight bits, else 1.
            default_shift_as_tensor: store the *default* shift as a ``tf.constant`` instead of an int.

        Raises:
            ValueError: unless exactly one of ``kernel_value`` / ``factor`` is given.
        """
        if (kernel_value is None) == (factor is None):
            raise ValueError(f"Only the kernel_value (={kernel_value}) or the factor (={factor}) can be applied.")
        weight_bits = self.weight_lossy_elements.kernel.bits
        if pre_acc_shift is not None:
            shift = pre_acc_shift
            self.pre_acc_shift = pre_acc_shift
        else:
            # 16 bit quantization doesn't support activation shift and in that case it's set to zero.
            shift = 0 if weight_bits in self._NO_SHIFT_WEIGHT_BITS else 1
            self.pre_acc_shift = tf.constant(shift) if default_shift_as_tensor else shift
        if kernel_value is not None:
            self.kernel_scale = np.abs(self.kernel) / kernel_value
        else:
            # Use the plain `shift` (not the possibly-int32 tensor stored on self) so the default
            # path does not mix a Python float with an integer tensor.
            self.kernel_scale = factor * 2 ** (-1.0 * shift)

        self.kernel_zero_point = 0
        self.shift_delta = 0
        self.enforce_encoding()

    def create_hw_params(self, kernel_value=None, factor=None, pre_acc_shift=None, **kwargs):
        """
        Set ``pre_acc_shift`` and ``kernel_scale`` from either ``kernel_value`` or ``factor``,
        reset the kernel zero point and shift delta, and re-derive the output encoding.

        A defaulted ``pre_acc_shift`` is stored as a ``tf.constant``. Extra ``kwargs`` are ignored.

        Raises:
            ValueError: unless exactly one of ``kernel_value`` / ``factor`` is given.
        """
        self._apply_scales(kernel_value, factor, pre_acc_shift, default_shift_as_tensor=True)

    def create_scales(self, kernel_value=None, factor=None, pre_acc_shift=None):
        """
        Same as ``create_hw_params`` but a defaulted ``pre_acc_shift`` is stored as a plain int.

        Raises:
            ValueError: unless exactly one of ``kernel_value`` / ``factor`` is given.
        """
        self._apply_scales(kernel_value, factor, pre_acc_shift, default_shift_as_tensor=False)

    def enforce_encoding(self, training=False):
        """
        Infers the output zp and output scale based on the inputs scales and inputs zp

        ``output_scale = input_scale * kernel_scale * 2 ** pre_acc_shift`` and
        ``output_zero_point = input_zero_point * quant_kernel / 2 ** pre_acc_shift``.
        """
        shift_val = tf.cast(
            tf.pow(tf.constant(2.0, dtype=self.FLOAT_TYPE_TF), tf.cast(self.pre_acc_shift, dtype=self.FLOAT_TYPE_TF)),
            self.input_scale.dtype,
        )
        self.output_scale = self.input_scale * self.kernel_scale * shift_val
        # Evaluate the quantized kernel once; both its dtype and its value are needed below.
        quant_kernel = self.get_quant_kernel(training=training)
        self.output_zero_point = (
            tf.cast(self.input_zero_point, quant_kernel.dtype)
            * quant_kernel
            / tf.cast(2**self.pre_acc_shift, quant_kernel.dtype)
        )

    def _compute_output_shape(self, input_shape):
        """The op does not change the shape of its input."""
        return input_shape

    def export_independent_params(self):
        """Return the quantization params of this op as float32 numpy arrays."""
        return {
            "kernel_scale": np.array(self.kernel_scale, np.float32),
            "kernel_zero_point": np.array(self.kernel_zero_point, np.float32),
            "mac_shift": np.array(self.pre_acc_shift, np.float32),
            "shift_delta": np.array(self.shift_delta, np.float32),
            "weight_bits": np.array(self.weight_lossy_elements.kernel.bits, np.float32),
        }

    def import_independent_params(self, params):
        """
        Load the params produced by ``export_independent_params``.

        Raises:
            AccelerasPrematureQuantOperation: the kernel lossy element is not a quant element yet.
            AccelerasImportParamConfigMismatch: imported ``weight_bits`` differ from the configured bits.
        """
        if not isinstance(self.weight_lossy_elements.kernel, BaseQuantElement):
            raise AccelerasPrematureQuantOperation("import_independent_params", self.full_name)
        kernel_bits = self.weight_lossy_elements.kernel.bits
        imported_kernel_bits = params["weight_bits"]
        if kernel_bits != imported_kernel_bits:
            raise AccelerasImportParamConfigMismatch("kernel_bits", kernel_bits, imported_kernel_bits, self.full_name)
        self.pre_acc_shift = params["mac_shift"]
        self.shift_delta = params["shift_delta"]
        self.kernel_scale = params["kernel_scale"]
        self.kernel_zero_point = params["kernel_zero_point"]

    def _expanded_quant_kernel(self):
        """Broadcast the quantized kernel to shape ``kernel_size + (kernel_shape[-1], 1)``."""
        size_kernel = self.kernel_size + (self.kernel_shape[-1], 1)
        return np.ones(size_kernel) * self.get_quant_kernel().numpy()

    def export_quant_weights(self):
        """Return the expanded quantized kernel as float32."""
        return {
            "quant_kernel": np.float32(self._expanded_quant_kernel()),
        }

    def export_hw_params(self):
        """Return the expanded quantized kernel (int8 for <= 8 bits, else int16), its zp and the shift."""
        w_type = np.int8 if self.weight_lossy_elements.kernel.bits <= 8 else np.int16
        kernel_q = self._expanded_quant_kernel().astype(w_type)
        return {
            "kernel": kernel_q,
            "zp_kernel": np.array(self.kernel_zero_point, np.int32),
            "output_stage/mult_shift": np.array(self.pre_acc_shift, np.uint8),
        }

    def export_weights(self):
        """Return the native kernel as a numpy array."""
        return {"kernel": np.array(self.kernel)}

    def import_weights(self, layer_params):
        """Set the native kernel from ``layer_params["kernel"]`` when it is present and not None."""
        kernel = layer_params.get("kernel")
        if kernel is not None:
            self.kernel = kernel

    def call_native(self, inputs, **kwargs):
        """Native flavour: multiply the input by the native kernel."""
        return inputs[0] * self.kernel

    def call_hw_sim(self, inputs, training=False, **kwargs):
        """Numeric flavour: multiply by the quantized kernel and divide by ``2 ** pre_acc_shift``."""
        kernel_q = self.get_quant_kernel(training=training)
        multiplication_res = inputs[0] * kernel_q / tf.cast(2 ** (self.pre_acc_shift), kernel_q.dtype)
        return multiplication_res

    def call_bit_exact(self, inputs, training=False, **kwargs):
        """Bit-exact flavour: integer multiply, rounded shift, then the output lossy element."""
        kernel_q = tf.cast(self.get_quant_kernel(training=training), self.INT_TYPE_TF)
        multiplication_res = self.bankers_round_with_shift(inputs[0] * kernel_q, self.pre_acc_shift, signed=True)
        return self.hw_simulation_by_lossy_element(multiplication_res, self.output_lossy_element)

    def define_encodings(self, flow):
        """Register the ``kernel_scale`` and ``mac_shift`` encodings of this op on ``flow``."""
        super().define_encodings(flow)
        flow.add_encoding(
            f"{self.full_name}/kernel_scale:0",
            EncodingType.Scale,
            scalar=False,
            shape=(),
            initializer=TensorInitializer(self.kernel_scale),
        )
        flow.add_encoding(
            f"{self.full_name}/mac_shift:0",
            EncodingType.Scale,
            scalar=False,
            shape=(),
            initializer=TensorInitializer(self.pre_acc_shift),
            quant=True,
            quant_min=1.0,
            quant_max=4.0,
        )

    def define_constraints(self, enc):
        """
        Register the constraints tying the output scale / zero point to the input encodings,
        the kernel scale and the mac shift (mirrors ``enforce_encoding``).
        """
        super().define_constraints(enc)
        # 16 bit quantization doesn't support activation shift and in that case it's set to zero.
        if self.weight_lossy_elements.kernel.bits in self._NO_SHIFT_WEIGHT_BITS:
            enc.identity(f"{self.full_name}/mac_shift:0", np.float32(0.0))

        # compute output_scale
        enc.mul(enc.dummy(0), f"{self.full_name}/input_scale:0", f"{self.full_name}/kernel_scale:0", inverse=True)
        enc.shift(f"{self.full_name}/output_scale:0", enc.dummy(0), f"{self.full_name}/mac_shift:0")

        # compute output_zero_point
        enc.div(enc.dummy(1), self.kernel, f"{self.full_name}/kernel_scale:0")
        enc.cast(enc.dummy(2), enc.dummy(1))
        enc.lossy_element(enc.dummy("kernel_q"), enc.dummy(2), self.weight_lossy_elements.kernel)
        enc.mul(enc.dummy(3), f"{self.full_name}/input_zero_point:0", enc.dummy("kernel_q"))
        # NOTE(review): the operand order differs from the output_scale shift above; presumably
        # intentional (the zero point is divided by the shift rather than multiplied) - confirm.
        enc.shift(enc.dummy(3), f"{self.full_name}/output_zero_point:0", f"{self.full_name}/mac_shift:0")

    def define_const_constraints(self, enc):
        """Pin ``kernel_scale`` and ``mac_shift`` to the values currently held by this op."""
        super().define_const_constraints(enc)
        enc.identity(f"{self.full_name}/kernel_scale:0", self.kernel_scale)
        enc.identity(f"{self.full_name}/mac_shift:0", self.pre_acc_shift)

    def update_encoding(self, encodings):
        """Read ``kernel_scale`` and ``pre_acc_shift`` back from the ``encodings`` mapping."""
        super().update_encoding(encodings)
        self.kernel_scale = encodings[f"{self.full_name}/kernel_scale:0"]
        self.pre_acc_shift = encodings[f"{self.full_name}/mac_shift:0"]

    @property
    def bit_exact_supported(self) -> bool:
        """This layer supports bit exact emulation."""
        return True
