from typing import Iterable

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp
from hailo_model_optimization.acceleras.atomic_ops.base_non_arithmetic_op import BaseNonArithmeticAtomicOp
from hailo_model_optimization.acceleras.utils.acceleras_definitions import ZP_LOW_SPLIT_PRECISION_PIXEL

# Divisor used by PrecisionAddOp.get_kernel_scales: each kernel scale is
# max_scale / input_scale / MAX_PRECISION_ADD_VALUE, so the quantized kernel
# (1 / kernel_scale) of the decomposition with the largest scale equals this value.
MAX_PRECISION_ADD_VALUE = 1 << 14


class SplitPrecisionLow(BaseAtomicOp):
    """
    Extract the low-precision part of the input.

    The low part is the input modulo ``2**low_bits`` (``low_bits`` starts at
    ``LOW_BITS`` and can be replaced through ``import_weights``). While
    ``trivial_split`` is set, the native path returns its input untouched.
    """

    num_inputs = 1
    num_outputs = 1
    LOW_BITS = 8

    def __init__(self, name, logger=None, fully_native=None, **kwargs):
        super().__init__(name, logger, fully_native, **kwargs)
        self._low_precision_th = 1
        self._native_zp = 0
        self._trivial_split = True
        self.low_bits = self.LOW_BITS

    @property
    def trivial_split(self):
        """Flag that makes ``call_native`` pass the input through unchanged."""
        return self._trivial_split

    @trivial_split.setter
    def trivial_split(self, value):
        self._trivial_split = value

    @property
    def bit_exact_supported(self) -> bool:
        """This layer supports bit exact emulation."""
        return True

    def _get_th_numeric(self):
        """Return the split threshold, ``2**low_bits``."""
        return 2**self.low_bits

    def call_native(self, inputs, **kwargs):
        """
        Native flow: encode, clamp negatives to zero, take the remainder modulo
        the threshold and decode back. A trivial split returns ``inputs[0]`` as is.
        """
        if not self.trivial_split:
            encoded = tf.nn.relu(self._encode_inputs(inputs))
            low_part = tf.math.floormod(encoded[0], self._get_th_numeric())
            return self._decode_output([low_part])[0]
        return inputs[0]

    def call_hw_sim(self, inputs, **kwargs):
        """Return the first input modulo the split threshold."""
        modulus = self._get_th_numeric()
        return tf.math.floormod(inputs[0], modulus)

    def call_bit_exact(self, inputs, **kwargs):
        """Keep the low bits by AND-ing the first input with ``2**low_bits - 1``."""
        data = inputs[0]
        low_mask = tf.cast(self._get_th_numeric() - 1, data.dtype)
        return tf.bitwise.bitwise_and(data, low_mask)

    def enforce_encoding(self):
        """
        Derive the output encoding from the first input: the scale is copied and
        the zero point is taken modulo ``2**low_bits`` and cast to float32.
        """
        self.output_scale = self.input_scales[0]
        self.output_zero_point = tf.math.floormod(
            self.input_zero_points[0],
            2**self.low_bits,
        )
        self.output_zero_point = tf.cast(self.output_zero_point, tf.float32)

    def create_hw_params(self, *args, **kwargs):
        """Refresh the output encoding, then delegate to the base implementation."""
        self.enforce_encoding()
        return super().create_hw_params(*args, **kwargs)

    def create_weight_quant_element(self, **kwargs):
        """Intentionally a no-op: non arithmetic ops shouldn't have any weights."""

    def import_weights(self, low_bits):
        """Set the number of bits that form the low part."""
        self.low_bits = low_bits

    def _compute_output_shape(self, input_shape):
        """The output has the same shape as the input."""
        return input_shape

    def export_independent_params(self):
        """Export the ``trivial_split`` flag as a boolean numpy array."""
        flag = np.array(self.trivial_split, bool)
        return {"trivial_split": flag}

    def import_independent_params(self, params):
        """Load ``trivial_split`` from ``params``; a missing key means ``False``."""
        self.trivial_split = bool(params.get("trivial_split", False))


class SplitPrecisionHigh(BaseAtomicOp):
    """
    Produce the high-precision part of the input.

    Takes two inputs: input 0 is the original data and input 1 is its
    low-precision split. The output is derived from their difference.
    """

    num_inputs = 2
    num_outputs = 1
    LOW_BITS = 8

    def __init__(self, name, logger=None, fully_native=None, **kwargs):
        super().__init__(name, logger, fully_native, **kwargs)
        self.low_bits = self.LOW_BITS

    @property
    def bit_exact_supported(self) -> bool:
        """This layer supports bit exact emulation."""
        return True

    def call_native(self, inputs, **kwargs):
        """Subtract the low-precision split (input 1) from the original data (input 0)."""
        original = inputs[0]
        low_part = inputs[1]
        return original - low_part

    def call_hw_sim(self, inputs, **kwargs):
        """Return ``(input 0 - input 1)`` divided by ``2**low_bits``."""
        difference = inputs[0] - inputs[1]
        return difference / 2**self.low_bits

    def call_bit_exact(self, inputs, **kwargs):
        """
        Clear the lowest ``low_bits`` bits of input 0 and shift the remainder right
        by ``low_bits``. Only input 0 is read on this path.
        """
        data = inputs[0]
        dtype = data.dtype
        low_mask = tf.cast(2**self.low_bits - 1, dtype)
        high_only = tf.bitwise.bitwise_and(data, tf.bitwise.invert(low_mask))
        shift_amount = tf.cast(self.low_bits, dtype)
        return tf.bitwise.right_shift(high_only, shift_amount)

    def enforce_encoding(self):
        """
        Derive the output encoding: the zero point is the difference of the two
        input zero points divided by ``2**low_bits`` (cast to float32), and the
        scale is the first input scale multiplied by ``2**low_bits``.
        """
        zp_difference = self.input_zero_points[0] - self.input_zero_points[1]
        self.output_zero_point = zp_difference / 2**self.low_bits
        self.output_zero_point = tf.cast(self.output_zero_point, tf.float32)
        self.output_scale = self.input_scales[0] * 2**self.low_bits

    def create_weight_quant_element(self, **kwargs):
        """Intentionally a no-op: non arithmetic ops shouldn't have any weights."""

    def import_weights(self, low_bits):
        """Set the number of bits that form the low part."""
        self.low_bits = low_bits

    def _compute_output_shape(self, input_shape):
        """The output takes the shape of the first input."""
        return input_shape[0]


class PrecisionAddOp(BaseAtomicOp):
    """
    Sum the precision decompositions into a double precision result.

    The single input carries ``num_decompositions`` equally sized channel groups
    concatenated along the last axis (for the input and the weights:
    HighHigh, HighLow, LowHigh, LowLow). The groups are added together, so the
    output has ``1 / num_decompositions`` of the input channels.
    """

    num_inputs = 1
    num_outputs = 1

    def __init__(self, name, num_decompositions=4, logger=None, fully_native=None, **kwargs):
        super().__init__(name, logger, fully_native, **kwargs)
        self.kernel_scales = np.ones(num_decompositions)
        self.num_decompositions = num_decompositions
        self.shift = 0

    def _kernel_q_per_channel(self):
        """Repeat each quantized kernel value over its channel group, as float32."""
        return tf.cast(tf.repeat(self.kernel_q, self.output_channels), tf.float32)

    def _kernel_scales_per_channel(self):
        """Repeat each kernel scale over its channel group, as float32."""
        return tf.cast(tf.repeat(self.kernel_scales, self.output_channels), tf.float32)

    def call_native(self, inputs, **kwargs):
        """
        Add the channel groups of ``inputs[0]`` together.

        The last axis is cut into ``num_decompositions`` consecutive groups which
        are summed in order. The input is indexed as a rank-4 tensor.
        """
        inp = inputs[0]
        output_channels = self.input_shapes[0][-1] // self.num_decompositions
        res = inp[:, :, :, :output_channels]
        for i in range(1, self.num_decompositions):
            # Out-of-place add: `res` starts as a slice of the input, and an in-place
            # `+=` would write through that view when the input is a numpy array.
            res = res + inp[:, :, :, i * output_channels : (i + 1) * output_channels]
        return res

    def call_hw_sim(self, inputs, **kwargs):
        """
        Apply the per-group kernel, the shift and the bias to ``inputs[0]``, then
        sum the channel groups with ``call_native``.
        """
        inp = inputs[0]

        mul_res = inp * self._kernel_q_per_channel()
        shift_res = mul_res * 2**-self.shift
        bias_res = shift_res + self.bias_q
        return self.call_native([bias_res])

    def _build(self, input_shapes):
        """Record the number of channels in a single decomposition group."""
        self.output_channels = input_shapes[-1] // self.num_decompositions

    @property
    def bias_q(self):
        """Bias that cancels the input zero point after the kernel and shift are applied."""
        bias_q = -self._kernel_q_per_channel() * self.input_zero_points[0] * 2**-self.shift
        return bias_q

    def create_hw_params(self, *args, **kwargs):
        """Reset the shift to zero; no other hardware parameter is computed here."""
        self.shift = 0

    def get_kernel_scales(self):
        """
        Compute one kernel scale per decomposition.

        Each scale is ``max_scale / input_scale / MAX_PRECISION_ADD_VALUE`` where
        ``max_scale`` is the maximum of the first input's scales and ``input_scale``
        is the scale at the first channel of the group (or the scalar scale itself).

        Returns:
            A tensor holding ``num_decompositions`` kernel scales.
        """
        input_scales = self.input_scales[0]
        max_scale = np.max(input_scales)  # hh scale
        max_weight = MAX_PRECISION_ADD_VALUE
        # 0-d arrays and scalar tensors pass the Iterable check but cannot be
        # indexed, so they must be handled as plain scalars.
        per_channel = isinstance(input_scales, Iterable) and np.ndim(input_scales) > 0
        kernel_scales = []
        for i in range(self.num_decompositions):
            input_scale = input_scales[i * self.output_channels] if per_channel else input_scales
            kernel_scales.append(max_scale / input_scale / max_weight)
        return tf.convert_to_tensor(kernel_scales)

    @property
    def kernel_q(self):
        """Quantized kernel values: the reciprocal of the kernel scales."""
        return 1 / self.kernel_scales

    def enforce_encoding(self, forward=True):
        """
        Propagate the encoding through the op.

        Args:
            forward: when true, recompute the kernel scales and derive the output
                scale (zero point is set to 0) from the input scale; otherwise
                derive the input scale from the current output scale.
        """
        if forward:
            self.kernel_scales = self.get_kernel_scales()
            kernel_scale_repeat = self._kernel_scales_per_channel()
            self.output_scale = tf.cast(self.input_scales[0], tf.float32) * kernel_scale_repeat * 2**self.shift
            # Keep the scales of the first group only.
            self.output_scale = self.output_scale[0 : self.output_channels]
            self.output_zero_point = np.array(0.0)
        else:
            kernel_scale_repeat = self._kernel_scales_per_channel()
            self.input_scale = (
                tf.concat([tf.cast(self.output_scale, self.FLOAT_TYPE_TF)] * self.num_decompositions, -1)
                / kernel_scale_repeat
                / 2**self.shift
            )

    def _get_data_low_high_ratio(self):
        return 1

    def _get_weight_low_high_ratio(self):
        return 1

    def create_weight_quant_element(self, **kwargs):
        """Intentionally a no-op: no weight quantization element is created for this op."""

    def export_independent_params(self):
        """Export the kernel scales and the shift as float32 numpy arrays."""
        return {
            "kernel_scales": np.array(self.kernel_scales, np.float32),
            "shift": np.array(self.shift, np.float32),
        }

    def export_hw_params(self):
        """Export the quantized kernel (uint16), bias (int32) and shift (uint8)."""
        return {
            "kernel": np.array(self.kernel_q, np.uint16),
            "bias": np.array(self.bias_q, np.int32),
            "output_stage/mult_shift": np.array(self.shift, np.uint8),
        }

    def import_independent_params(self, params):
        """Load the kernel scales and the shift; both keys are required."""
        self.kernel_scales = params["kernel_scales"]
        self.shift = params["shift"]


class PrecisionSplitPixelOp(BaseNonArithmeticAtomicOp):
    """
    Split the input into low and high precision, while doubling the pixel dimension.

    Low and high values are interleaved along axis 2: even positions hold the
    low part and odd positions hold the high part. With ``trivial_split`` set,
    the high part is the input itself and the low part is all zeros.
    """

    num_inputs = 1
    num_outputs = 1
    LOW_BITS = 8

    def __init__(
        self,
        name: str,
        logger=None,
        fully_native=None,
        **kwargs,
    ):
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        self._trivial_split = True
        self.low_bits = self.LOW_BITS

    @property
    def trivial_split(self):
        """Flag selecting the pass-through split (high = input, low = zeros)."""
        return self._trivial_split

    @trivial_split.setter
    def trivial_split(self, value):
        self._trivial_split = value

    def call_hw_sim(self, inputs, **kwargs):
        """
        Split ``inputs[0]`` and interleave the two parts along axis 2.

        For a non-trivial split negatives are clamped to zero first; the low part
        is the remainder modulo ``2**low_bits`` and the high part is what is left,
        divided by ``2**low_bits``. The input is indexed as a rank-4 tensor.
        """
        inp = inputs[0]
        if self.trivial_split:
            high = inp
            low = tf.zeros_like(high)
        else:
            threshold = 2**self.low_bits
            inp = tf.nn.relu(inp)
            low = tf.math.floormod(inp, threshold)
            high = (inp - low) / threshold
        # Stacking on a new axis before the channels and folding it into axis 2
        # yields the order low, high, low, high, ... along that axis.
        interleaved = tf.stack([low, high], axis=-2)
        return tf.reshape(interleaved, [-1, inp.shape[1], 2 * inp.shape[2], inp.shape[3]])

    @property
    def _zp_comp_low(self):
        """Zero point subtracted from the low part when decoding (0.0 for a trivial split)."""
        if self._trivial_split:
            return 0.0
        return ZP_LOW_SPLIT_PRECISION_PIXEL

    def call_native(self, inputs, **kwargs):
        """
        Native flow: a trivial split goes straight through ``call_hw_sim``;
        otherwise the input is encoded, split and decoded back.
        """
        if self.trivial_split:
            return self.call_hw_sim(inputs)

        input_encoded = self._encode_inputs(inputs)
        output_encoded = self.call_hw_sim(input_encoded)
        return self._decode_output([output_encoded])[0]

    def _decode_output(self, outputs):
        """
        Decode the interleaved output.

        Even positions along axis 2 (low part) use ``_zp_comp_low`` as zero point,
        odd positions (high part) use ``output_zero_point``; both are multiplied
        by ``output_scale`` and interleaved again.

        Returns:
            A single-element list holding the decoded tensor.
        """
        res_even = (outputs[0][:, :, ::2, :] - self._zp_comp_low) * self.output_scale
        res_odd = (outputs[0][:, :, 1::2, :] - self.output_zero_point) * self.output_scale
        res = tf.stack([res_even, res_odd], axis=-2)
        # Use -1 for the batch dimension (as call_hw_sim does): the static batch
        # size is None when it is unknown, which tf.reshape cannot accept.
        res = tf.reshape(res, [-1, res.shape[1], 2 * res.shape[2], res.shape[4]])
        return [res]

    def _compute_output_shape(self, input_shape):
        """Same as the input shape with axis 2 doubled."""
        return [input_shape[0], input_shape[1], 2 * input_shape[2], input_shape[3]]

    def enforce_encoding(self):
        """
        Derive the output encoding from the input encoding.

        A trivial split copies scale and zero point. Otherwise the scale is
        multiplied by ``2**low_bits`` and the zero point is divided by
        ``2**low_bits`` and floored.
        """
        if self.trivial_split:
            self.output_scale = self.input_scale
            self.output_zero_point = self.input_zero_point
        else:
            self.output_scale = self.input_scale * 2**self.low_bits
            self.output_zero_point = np.floor(np.float32(self.input_zero_point / (2**self.low_bits)))

    def export_independent_params(self):
        """Export the ``trivial_split`` flag as a boolean numpy array."""
        return {
            "trivial_split": np.array(self.trivial_split, bool),
        }

    def import_independent_params(self, params):
        """Load ``trivial_split`` from ``params``; a missing key means ``False``."""
        self.trivial_split = bool(params.get("trivial_split", False))
