from dataclasses import dataclass

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp, BaseWeightLossyElements
from hailo_model_optimization.acceleras.atomic_ops.initializers import Zeros
from hailo_model_optimization.acceleras.encoding.encoding_data import EncodingType
from hailo_model_optimization.acceleras.lossy_elements.base_lossy_element import BaseLossyElement
from hailo_model_optimization.acceleras.lossy_elements.identity_element import IdentityElement
from hailo_model_optimization.acceleras.lossy_elements.quant_element import BaseQuantElement, BiasQuantElement
from hailo_model_optimization.acceleras.utils.acceleras_definitions import MAX_NUM_REPEATS
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import (
    AccelerasBiaseDecompositionError,
    AccelerasDecompositionError,
    AccelerasImportParamConfigMismatch,
    AccelerasPrematureQuantOperation,
)


@dataclass()
class BiasWeightsLossy(BaseWeightLossyElements):
    """Weight lossy elements of an AddBiasOp: the element applied to the bias at the accumulator."""

    # Element that is called on the accumulator-domain bias
    # (an IdentityElement until a BiasQuantElement is configured).
    bias_decompose: BaseLossyElement


class AddBiasOp(BaseAtomicOp):
    """
    Atomic op that describes bias addition operation.

    Implements both the full-precision bias operation and the numerized behavior in the MAC unit.

    Args:
        bias_initializer: keras initializer for the bias; if None it is initialized to zeros.
        axis: Axes along which the bias is defined. The addition is broadcast over all the other axes.
        trainable: If True the biases will be trainable.
        is_correctable: Stored as ``self.is_correctable``.

    Attributes:
        pretrained_bias: original bias, as passed to ``import_weights``.

    """

    weight_lossy_elements: BiasWeightsLossy

    num_inputs = 1
    num_outputs = 1

    # Debug Tensors, created after build / call
    pretrained_bias: tf.Tensor
    _bias_at_accumulator: tf.Tensor
    _bias_at_accumulator_pre_merge: tf.Tensor
    _bias_numerized: tf.Tensor
    _bias_after_wrap: tf.Tensor

    def __init__(
        self,
        name,
        bias_initializer=None,
        axis=(-1,),
        trainable=True,
        is_correctable=True,
        logger=None,
        fully_native=None,
        **kwargs,
    ):
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        # Identity element until create_weight_quant_element replaces it with a BiasQuantElement.
        self.weight_lossy_elements = BiasWeightsLossy(
            bias_decompose=IdentityElement(name=f"{self.full_name}/ie:bias_decompose"),
        )
        self.trainable = trainable
        self.merge_residue_into_bias = True
        self.bias_initializer = bias_initializer if bias_initializer is not None else Zeros()
        self._max_feed_repeat = MAX_NUM_REPEATS
        self._bias = None
        self.axis = axis
        self.is_correctable = is_correctable
        self._precision_split_zp = False
        # Note that if the bias wrapper is lossless we can get more than
        # 16 bits in the accumulator, and this is ok

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "name": self.full_name,
                "bias_initializer": self.bias_initializer,
                "axis": self.axis,
                "trainable": self.trainable,
                "is_correctable": self.is_correctable,
                "max_feed_repeat": self._max_feed_repeat,
                "precision_split_zp": self._precision_split_zp,
                "fully_native": self.fully_native,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        """
        Rebuild the op from a dict produced by ``get_config``.

        Constructor arguments are popped from ``config``. ``max_feed_repeat`` and
        ``precision_split_zp`` are also serialized by ``get_config``, so they are restored
        here as well. Without this, a round trip silently reset them to their defaults.
        """
        valid_kwargs = {
            "bias_initializer": config.pop("bias_initializer"),
            "name": config.pop("name"),
            "axis": config.pop("axis"),
            "trainable": config.pop("trainable"),
            "is_correctable": config.pop("is_correctable"),
            "fully_native": config.pop("fully_native"),
        }

        instance = cls(**valid_kwargs)
        # set_max_feed_repeat ignores None, so configs without the key keep MAX_NUM_REPEATS.
        instance.set_max_feed_repeat(config.get("max_feed_repeat"))
        instance._precision_split_zp = config.get("precision_split_zp", False)

        return instance

    @property
    def bias(self):
        return self._bias

    @property
    def pre_acc_shift(self):
        return self.weight_lossy_elements.bias_decompose.pre_acc_shift

    @pre_acc_shift.setter
    def pre_acc_shift(self, value):
        self.weight_lossy_elements.bias_decompose.pre_acc_shift = value

    @property
    def residue(self):
        return -self.input_zero_points[0] * 2**self.pre_acc_shift

    @property
    def bit_exact_supported(self) -> bool:
        return True

    def create_weight_quant_element(self, kernel_bits, signed, num_decomposition=0, accumulator_bits=None):
        """
        Replace the bias lossy element with a BiasQuantElement.

        If ``accumulator_bits`` is None it is taken from the input lossy element's bits.
        """
        if accumulator_bits is None:
            accumulator_bits = self.input_lossy_elements[0].bits

        # A 32-bit accumulator uses 15 data bits; otherwise the data bits are half the accumulator width.
        if accumulator_bits == 32:
            data_bits = 15
        else:
            data_bits = accumulator_bits // 2
        weight_bits = accumulator_bits // 2
        decompose_element = BiasQuantElement(
            accumulator_bits,
            data_bits,
            weight_bits,
            num_decomposition,
            self._max_feed_repeat,
            kernel_bits,
            signed,
            name=f"{self.full_name}/qe:bias",
        )
        self.weight_lossy_elements = BiasWeightsLossy(
            bias_decompose=decompose_element,
        )

    def import_weights(self, bias, **kwargs):
        """Store ``bias`` as the op's bias: assign it to the weight if built, else keep it as a constant."""
        self.pretrained_bias = bias
        if self.built:
            self._bias.assign(tf.cast(bias, self.FLOAT_TYPE_TF))
        else:
            self._bias = tf.constant(tf.cast(bias, self.FLOAT_TYPE_TF))
        # Keep the initializer in sync so a later (re)build reproduces the imported values.
        self.bias_initializer = tf.keras.initializers.Constant(bias)

    def _build(self, input_shape):
        if len(input_shape) == 1:
            shape = input_shape[self.axis[0]]
        else:
            shape = [input_shape[ax] for ax in self.axis]
        self._bias = self.add_weight(
            shape=shape,
            trainable=self.trainable,
            initializer=self.bias_initializer,
            name="bias",
        )

    def _compute_output_shape(self, input_shapes):
        return input_shapes

    def create_hw_params(self, *args, **kwargs):
        """
        Handles decomposition if needed.
        Assumes that enforce_encoding has been run,
        specifically all "fold into bias" stuff.

        Raises:
            AccelerasBiaseDecompositionError: if the bias decomposition fails.
        """
        bias_at_accumulator = self.encode_bias()
        try:
            self.weight_lossy_elements.bias_decompose.decompose(bias_at_accumulator)
        except AccelerasDecompositionError as e:
            raise AccelerasBiaseDecompositionError(self.full_name, e)

    def enforce_encoding(self, forward=True):
        """
        Implements the following pieces of quantization scheme:
        A. The "folding" of asymmetry of accumulator encoding (embodied in the input zero point)
            into the [desired] bias, so input to APU is just scaled not skewed.
        B. ["pre"]wrap-around of the numeric bias candidate (the one we'd like to add to accumulator),
            following the current MAC behavior at addition overflow.

        The scale always passes through unchanged. The zero point is kept as-is when the residue is
        merged into the bias, and otherwise passes through from the other side.

        NOTE:  For the "decomposed" cases, the #feed-repeats is NOT updated here,
               computed once in the "create_hw_params" stage by a non-differentiable method.
               This creates an assumption that the changes during fine-tuning algos
                   won't be large enough to render the current feed-repeat unworkable.
        """
        if forward:
            self.output_scale = self.input_scale
            self.output_zero_point = self.output_zero_point if self.merge_residue_into_bias else self.input_zero_point
        else:
            self.input_scale = self.output_scale
            self.input_zero_point = self.input_zero_point if self.merge_residue_into_bias else self.output_zero_point

    def set_max_feed_repeat(self, max_feed_repeat):
        """Override the max feed repeat; None leaves the current value unchanged."""
        if max_feed_repeat is not None:
            self._max_feed_repeat = max_feed_repeat

    def call_bit_exact(self, inputs, training=False, **kwargs):
        inp = inputs[0]
        bias_at_accumulator = self.encode_bias()
        bias = tf.cast(
            self.weight_lossy_elements.bias_decompose(bias_at_accumulator, training=training),
            self.INT_TYPE_TF,
        )
        if not self._precision_split_zp:
            ret = inp + bias
        else:
            # Split width into (width/2, 2) so each of the two bias rows hits alternating columns.
            input_reshape = tf.reshape(inp, (-1, inp.shape[1], inp.shape[2] // 2, 2, inp.shape[3]))
            input_reshape += bias
            ret = tf.reshape(input_reshape, (-1, inp.shape[1], inp.shape[2], inp.shape[3]))
        ret = tf.cast(ret, self.INT_TYPE_TF)
        ret = self.hw_simulation_by_lossy_element(ret, self.output_lossy_element)
        return ret

    def call_hw_sim(self, inputs, training=False, **kwargs):
        """
        Simulating both "initialization" and "decomposition" bias-add HW implementations,
        using abstraction of 'total_factor' which is 1 for "initialization" implementation.
        """
        bias_at_accumulator = self.encode_bias()
        if not self._precision_split_zp:
            ret = inputs[0] + self.weight_lossy_elements.bias_decompose(bias_at_accumulator, training=training)
        else:
            bias_at_accumulator_q = self.weight_lossy_elements.bias_decompose(bias_at_accumulator, training=training)
            # input shape  [batch, height, width, channels] to [batch, height, width/2, 2, channels]
            input_reshape = tf.reshape(
                inputs[0], (-1, inputs[0].shape[1], inputs[0].shape[2] // 2, 2, inputs[0].shape[3])
            )
            # bias_at_accumulator_q shape [2, channels]
            input_reshape += bias_at_accumulator_q
            # reshape back to [batch, height, width, channels]
            ret = tf.reshape(input_reshape, (-1, inputs[0].shape[1], inputs[0].shape[2], inputs[0].shape[3]))
        return ret

    def call_native(self, inputs, **kwargs):
        return inputs[0] + self.bias

    def encode_bias(self):
        """
        Convert the bias to the accumulator domain and merge the residue into it.

        The bias is divided by ``output_scale``. Then the residue ``input_zero_points[0] - output_zero_point``
        is subtracted. Whether that residue is non-zero depends on ``output_zero_point``, which
        ``enforce_encoding`` sets according to ``merge_residue_into_bias``.

        Returns
            bias_at_accumulator, after the residue has been merged

        """
        scaled_bias = self.bias / tf.cast(self.output_scale, dtype=self.FLOAT_TYPE_TF)

        merged_residue_val = self.input_zero_points[0] - self.output_zero_point
        return scaled_bias - merged_residue_val

    @property
    def bias_q(self):
        """The encoded bias after passing through the bias lossy element."""
        bias_at_accumulator = self.encode_bias()
        bias_decompose_element = self.weight_lossy_elements.bias_decompose
        bias_at_accumulator = bias_decompose_element(bias_at_accumulator)
        return bias_at_accumulator

    @classmethod
    def get_passthru_bias(cls, name, logger=None):
        """Create a non-trainable, non-correctable op whose bias is a single zero."""
        # The ew_mult bias operates as zp before multiplication
        bias_op = cls(
            name=name,
            bias_initializer=Zeros(),
            trainable=False,
            logger=logger,
        )
        bias_op.is_correctable = False
        bias_op._bias = bias_op.bias_initializer((1,), dtype=bias_op.FLOAT_TYPE_TF)
        return bias_op

    def export_hw_params(self):
        """
        Export the integer bias and its factor / feed-repeat values.

        Raises:
            RuntimeError: if the quant element is not a BiasQuantElement yet, or if
                num_decomposition is not 0, 1 or 2.
        """
        if not isinstance(self.weight_lossy_elements.bias_decompose, BiasQuantElement):
            raise RuntimeError(f"Can't export hw params before configuration is loaded - {self.full_name}")
        bias_decompose_element = self.weight_lossy_elements.bias_decompose
        bias_at_accumulator = self.bias_q

        if self._precision_split_zp:  # because we have 2 zp, this is how the compiler wants it
            bias_at_accumulator = tf.transpose(bias_at_accumulator)

        if bias_decompose_element.num_decomposition == 0:
            bias_factor = 2**bias_decompose_element.pre_acc_shift
            params = {
                # TODO remove this does not have a use
                "bias": np.int32(bias_at_accumulator),
                "bias_q": np.int32(bias_at_accumulator),
                "bias_factor": np.uint16(bias_factor),
                "bias_feed_repeat": np.uint16(1),
                "precision_split_zp": np.bool_(self._precision_split_zp),
            }
        else:
            int_bias = bias_decompose_element.get_decomposed_int(bias_at_accumulator, training=False)
            factors = bias_decompose_element.factors
            if bias_decompose_element.num_decomposition == 1:
                params = {
                    "bias": np.int32(int_bias[0]),
                    "bias_q": np.int32(int_bias[0]),
                    "bias_factor": np.uint16(factors[0]),
                    "bias_feed_repeat": np.uint16(bias_decompose_element.repeats),
                    "precision_split_zp": np.bool_(self._precision_split_zp),
                }
            # TODO remove this or change it; they don't have support yet
            elif bias_decompose_element.num_decomposition == 2:
                bias_total_value = bias_at_accumulator * 2**bias_decompose_element.pre_acc_shift
                params = {
                    "bias": np.int32(int_bias[0]),
                    "bias_q": np.int32(int_bias[0]),
                    "bias_factor": np.uint16(factors[0]),
                    "bias_q_total_value": np.int32(bias_total_value),
                    "bias_factor_a": np.uint16(factors[0]),
                    "bias_factor_b": np.uint16(factors[1]),
                    "bias_q_int8_vec_a": np.int8(int_bias[0]),
                    "bias_q_int8_vec_b": np.int8(int_bias[1]),
                    "bias_feed_repeat": np.uint16(bias_decompose_element.repeats),
                    "precision_split_zp": np.bool_(self._precision_split_zp),
                }
            else:
                raise RuntimeError(f"Unexpected number of decomposition {bias_decompose_element.num_decomposition}")

        return params

    def export_independent_params(self):
        """Export quantization parameters as numpy arrays (inverse of ``import_independent_params``)."""
        bias_decompose_element = self.weight_lossy_elements.bias_decompose
        params = {
            "bias_feed_repeat": np.array(bias_decompose_element.repeats, np.float32),
            "num_decomposition": np.array(bias_decompose_element.num_decomposition, np.float32),
            "mac_shift": np.array(bias_decompose_element.pre_acc_shift, np.float32),
        }
        for i, factor in enumerate(bias_decompose_element.factors):
            params[f"bias_factor_{i}"] = np.array(factor, np.float32)

        # Without decomposition, bias_factor_0 is exported as 2**mac_shift.
        if self.weight_lossy_elements.bias_decompose.num_decomposition == 0:
            params["bias_factor_0"] = np.array(2 ** params["mac_shift"], np.float32)

        params["precision_split_zp"] = np.array(self._precision_split_zp, bool)
        params["merge_residue_into_bias"] = np.array(self.merge_residue_into_bias, bool)

        return params

    def import_independent_params(self, params):
        """
        Load the quantization parameters produced by ``export_independent_params``.

        ``params`` is not modified.

        Raises:
            AccelerasPrematureQuantOperation: if the quant element has not been created yet.
            AccelerasImportParamConfigMismatch: if the imported ``num_decomposition`` differs
                from the configured one.
        """
        bias_decompose = self.weight_lossy_elements.bias_decompose
        if not isinstance(bias_decompose, BaseQuantElement):
            raise AccelerasPrematureQuantOperation("import_independent_params", self.full_name)
        num_decomposition = bias_decompose.num_decomposition
        imported_num_decomposition = params["num_decomposition"]
        if num_decomposition != imported_num_decomposition:
            raise AccelerasImportParamConfigMismatch(
                "num_decomposition",
                num_decomposition,
                imported_num_decomposition,
                self.full_name,
            )
        bias_decompose.repeats = params["bias_feed_repeat"]
        bias_decompose.pre_acc_shift = params["mac_shift"]
        bias_factors = [params[f"bias_factor_{i}"] for i in range(max(num_decomposition, 1))]
        if num_decomposition == 0:
            # Export stores 2**mac_shift in bias_factor_0 here; undo that scaling.
            # Divide out-of-place: `/=` on the numpy array from `params` would mutate the
            # caller's dict (double-dividing on re-import) and fails for integer arrays.
            bias_factors[0] = bias_factors[0] / 2**bias_decompose.pre_acc_shift
        bias_decompose.factors = bias_factors
        self._precision_split_zp = params.get("precision_split_zp", False)
        self.merge_residue_into_bias = params.get("merge_residue_into_bias", True)

    def export_quant_weights(self):
        params = {}
        bias_at_accumulator = self.encode_bias()
        bias_quant = self.weight_lossy_elements.bias_decompose.export_as_quant(bias_at_accumulator)
        for i, bias_q in enumerate(bias_quant):
            params[f"bias_{i}"] = np.array(bias_q, np.float32)
        return params

    def export_weights(self):
        """Return the bias as a numpy array, or zeros shaped from ``input_shape`` along ``axis`` if unset."""
        if self.bias is not None:
            bias = self.bias.numpy()
        else:
            bias = np.zeros(tuple(self.input_shape[ax] for ax in self.axis))
        return bias

    def define_encodings(self, flow):
        super().define_encodings(flow)
        # mac_shift should always depend on another mac_shift of the layer, as its range could vary.
        # Therefore we define neither an initializer nor a quant.
        flow.add_encoding(f"{self.full_name}/mac_shift:0", EncodingType.Scale, scalar=False, shape=())

    def define_constraints(self, enc):
        super().define_constraints(enc)
        enc.identity(f"{self.full_name}/output_scale:0", f"{self.full_name}/input_scale:0")
        if self.merge_residue_into_bias:
            enc.identity(f"{self.full_name}/output_zero_point:0", self.output_zero_point)
        else:
            enc.identity(f"{self.full_name}/output_zero_point:0", f"{self.full_name}/input_zero_point:0")

    def define_const_constraints(self, enc):
        super().define_const_constraints(enc)
        enc.identity(f"{self.full_name}/mac_shift:0", self.pre_acc_shift)

    def update_encoding(self, encodings):
        super().update_encoding(encodings)
        self.pre_acc_shift = encodings[f"{self.full_name}/mac_shift:0"]

    def _encode_inputs(self, inputs):
        """
        Encode inputs. With precision_split_zp, even/odd width columns use zero points [0]/[1].
        """
        if not self._precision_split_zp:
            return super()._encode_inputs(inputs)

        res_even = inputs[0][:, :, ::2, :] / self.input_scale + self.input_zero_point[0]
        res_odd = inputs[0][:, :, 1::2, :] / self.input_scale + self.input_zero_point[1]

        # [batch, h, w/2, 2, c] -> interleave back to [batch, h, w, c].
        res = tf.stack([res_even, res_odd], axis=-2)
        # -1 for batch: the static batch dim may be None, which tf.reshape rejects.
        res = tf.reshape(res, [-1, res.shape[1], 2 * res.shape[2], res.shape[4]])
        return [res]


class AddBiasDeconvOp(AddBiasOp):
    """
    Bias op used for Deconv.

    The deconv bias must be the same across all the interleaves. So only a short bias is
    stored, and it is tiled ``number_of_repeats`` times (per group) to form the full bias.

    Args:
        number_of_repeats: Number of repetitions of the bias across the convolution;
            this number should be strides[0] * strides[1] of the deconv kernel.
        groups: Number of groups; each group's slice of the short bias is tiled separately.
    Attributes:
        short_bias: Bias of the original Deconv (input size along ``axis`` divided by
            ``number_of_repeats``).

    """

    short_bias: tf.Tensor

    def __init__(
        self,
        name,
        number_of_repeats=4,
        groups=1,
        bias_initializer=None,
        axis=(-1,),
        trainable=True,
        is_correctable=True,
        logger=None,
        fully_native=None,
        **kwargs,
    ):
        super().__init__(
            name,
            bias_initializer=bias_initializer,
            axis=axis,
            trainable=trainable,
            is_correctable=is_correctable,
            logger=logger,
            fully_native=fully_native,
            **kwargs,
        )
        self.groups = groups
        self.number_of_repeats = number_of_repeats

    def get_config(self):
        """Extend the parent config with the deconv-specific tiling parameters."""
        config = super().get_config()
        config.update(
            {
                "number_of_repeats": self.number_of_repeats,
                "groups": self.groups,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        """
        Rebuild the op from ``get_config`` output.

        The parent ``from_config`` constructs the op without ``number_of_repeats``/``groups``,
        so they are restored here. Defaults match ``__init__`` for configs that predate these keys.
        """
        number_of_repeats = config.pop("number_of_repeats", 4)
        groups = config.pop("groups", 1)
        instance = super().from_config(config)
        instance.number_of_repeats = number_of_repeats
        instance.groups = groups
        return instance

    def _build(self, input_shape):
        # NOTE(review): floor division silently truncates if the size is not a multiple of number_of_repeats.
        if len(input_shape) == 1:
            shape = input_shape[self.axis[0]] // self.number_of_repeats
        else:
            shape = [input_shape[ax] // self.number_of_repeats for ax in self.axis]
        self.short_bias = self.add_weight(
            shape=shape,
            trainable=self.trainable,
            initializer=self.bias_initializer,
            name="short_bias",
        )

    def import_weights(self, bias, **kwargs):
        """Store ``bias`` as the short bias: assign it if built, else keep it as a constant."""
        self.pretrained_bias = bias
        if self.built:
            self.short_bias.assign(tf.cast(bias, self.FLOAT_TYPE_TF))
        else:
            self.short_bias = tf.constant(tf.cast(bias, self.FLOAT_TYPE_TF))
        self.bias_initializer = tf.keras.initializers.Constant(bias)

    @property
    def bias(self):
        """Full bias: each group's slice of short_bias repeated number_of_repeats times, flattened.

        e.g. short_bias=[a, b, c, d], groups=2, number_of_repeats=2 -> [a, b, a, b, c, d, c, d].
        """
        grouped_bias = tf.reshape(self.short_bias, [self.groups, -1])
        tiled_bias = tf.tile(grouped_bias, [1, self.number_of_repeats])
        return tf.reshape(tiled_bias, [-1])
