import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp
from hailo_model_optimization.acceleras.utils.opt_utils import get_scalar_vector


class ShiftAddOp(BaseAtomicOp):
    """
    Sum four partial-product inputs into a single double-precision result.

    The op gets 4 inputs built from the high/low parts of the input and the
    weights (HighHigh, HighLow, LowHigh, LowLow). In hardware simulation each
    input ``i`` is multiplied by ``2 ** shifts[i]`` and the results are summed.

    ``self.shifts`` is ordered LL, LH, HL, HH, and ``block_ll`` zeroes index 0.
    NOTE(review): the input list above is written HH-first; confirm callers
    pass ``inputs`` in the same order as ``self.shifts`` (LL first).
    """

    FLOAT_TYPE_TF = tf.float64
    num_inputs, num_outputs = 4, 1
    # Largest shift accepted by ``create_hw_params`` (shifts must lie in [0, MAX_SHIFT]).
    MAX_SHIFT = 16

    def __init__(self, name, logger=None, fully_native=None, **kwargs):
        super().__init__(name, logger, fully_native, **kwargs)
        self.shifts = np.array((0, 0, 0, 0))  # LL, LH, HL, HH
        self.block_ll = False  # TODO: Add logic to set this on.

    def call_native(self, inputs, **kwargs):
        """Return the element-wise sum of ``inputs`` (no shifting)."""
        values = tf.add_n(inputs)
        return values

    def call_hw_sim(self, inputs, **kwargs):
        """Shift each input by its power-of-two factor and sum the results.

        Args:
            inputs: sequence of ``num_inputs`` tensors, ordered like ``self.shifts``.

        Returns:
            ``call_native`` applied to the shifted inputs.
        """
        # Cast explicitly: ``self.shifts`` may be an integer array (see __init__)
        # while the base is float64, and tf.pow requires matching dtypes.
        shifts = tf.cast(self.shifts, self.FLOAT_TYPE_TF)
        scales_shift = tf.pow(tf.constant(2.0, dtype=self.FLOAT_TYPE_TF), shifts)
        if self.block_ll:
            # Tensors are immutable, so rebuild the factors with the LL (index 0)
            # term zeroed instead of assigning into the tensor.
            scales_shift = tf.concat(
                [tf.zeros([1], dtype=self.FLOAT_TYPE_TF), scales_shift[1:]],
                axis=0,
            )
        unstack_scales_shift = tf.unstack(scales_shift)
        shifted_inputs = [inp * sca for inp, sca in zip(inputs, unstack_scales_shift)]

        return self.call_native(shifted_inputs)

    def create_hw_params(self):
        """Derive the per-input shifts from the input and output scales.

        Each shift is ``log2(input_scale / output_scale)``. The ratio is collapsed
        by ``get_scalar_vector`` (assumed to reduce a per-channel ratio to one
        scalar; TODO confirm).

        Raises:
            ValueError: if any shift is not an integer (the scale ratio is not a
                power of two) or falls outside ``[0, MAX_SHIFT]``.
        """
        scalar_val = np.array(
            [
                get_scalar_vector(scale / self.output_scale, name=f"{self.full_name}/shift_{ind}")
                for ind, scale in enumerate(self.input_scales)
            ]
        )
        shifts = np.log2(scalar_val)
        # Element-wise checks. A chained comparison such as ``0 <= arr <= 16``
        # is ambiguous on ndarrays. NaN (non-positive ratio) fails the
        # integer check, so it is rejected as well.
        is_integer = np.all(shifts % 1 == 0)
        in_range = np.all((shifts >= 0) & (shifts <= self.MAX_SHIFT))
        if not (is_integer and in_range):
            raise ValueError("Input Scale and output scales dont match for a perfect shift")
        self.shifts = shifts

    def create_weight_quant_element(self, **kwargs):
        """Intentionally a no-op for this op."""

    def export_hw_params(self):
        """Return the shifts as ``uint8`` under the ``"shift_add_shift"`` key."""
        return {"shift_add_shift": np.array(self.shifts, np.uint8)}

    def enforce_encoding(self, forward=True):
        """Make the input and output scales agree with ``self.shifts``.

        Args:
            forward: if True, derive ``output_scale`` from the first input scale
                and shift. Otherwise, overwrite every input scale as
                ``output_scale * 2 ** shift``.
        """
        if forward:
            self.output_scale = self.input_scales[0] / 2 ** self.shifts[0]
        else:
            for ind, shift in enumerate(self.shifts):
                self.input_scales[ind] = self.output_scale * 2**shift

    def export_independent_params(self):
        """Return the op's serializable state: the shifts."""
        return {"shifts": self.shifts}

    def import_independent_params(self, params):
        """Restore the shifts from a dict produced by ``export_independent_params``."""
        self.shifts = params["shifts"]
