from dataclasses import dataclass

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp, BaseWeightLossyElements
from hailo_model_optimization.acceleras.lossy_elements.identity_element import IdentityElement
from hailo_model_optimization.acceleras.lossy_elements.quant_element import QuantElement, RoundDownQuantElement


@dataclass
class ShiftWeightsLossy(BaseWeightLossyElements):
    """Lossy elements used by ``ShiftOp``; ``factor`` is applied to the shifted values."""

    # Lossy element applied to the result of the shift in hw_sim / bit-exact modes.
    # NOTE(review): annotated as RoundDownQuantElement, but ShiftOp assigns an
    # IdentityElement (in __init__) or a QuantElement (in create_weight_quant_element)
    # — confirm the intended type.
    factor: RoundDownQuantElement


class ShiftOp(BaseAtomicOp):
    """
    Emulate a right-shift of the single input by ``self.shift`` bits.

    Modes:
    - native: returns the input unchanged; the shift is reflected only in
      ``output_scale`` (see ``enforce_encoding``).
    - hw_sim: divides the input by ``2**shift`` and applies the ``factor`` lossy element.
    - bit_exact: integer right shift by ``shift`` followed by the ``factor`` lossy element.
    """

    weight_lossy_elements: ShiftWeightsLossy
    num_inputs = 1
    num_outputs = 1

    def __init__(self, name: str, logger=None, fully_native=None, **kwargs):
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        # Until create_weight_quant_element is called, the output is not quantized.
        self.weight_lossy_elements = ShiftWeightsLossy(
            factor=IdentityElement(),
        )
        self.shift = 0

    def call_native(self, inputs, **kwargs):
        """Return the single input unchanged."""
        return inputs[0]

    def call_hw_sim(self, inputs, **kwargs):
        """Divide the input by ``2**shift`` and pass it through the ``factor`` lossy element."""
        # NOTE(review): bit-exact mode floors via an integer right shift; whether the
        # ``factor`` element reproduces that floor here depends on its rounding — confirm.
        after_shift = inputs[0] / 2.0**self.shift
        return self.weight_lossy_elements.factor(after_shift)

    def call_bit_exact(self, inputs, **kwargs):
        """Right-shift the integer input by ``shift`` bits and apply the ``factor`` element.

        Raises whatever ``_verify_data_dtype`` raises if the result does not fit the
        ``factor`` element's unsigned bit width.
        """
        data = tf.convert_to_tensor(inputs[0])
        # ``import_independent_params`` stores ``shift`` as np.float32, but
        # ``tf.bitwise.right_shift`` needs an integer shift of the same dtype as the data.
        shift = tf.cast(self.shift, data.dtype)
        after_shift = tf.bitwise.right_shift(data, shift)
        results = self.hw_simulation_by_lossy_element(after_shift, self.weight_lossy_elements.factor)
        self._verify_data_dtype(results, self.weight_lossy_elements.factor.bits, False, "shift_out")
        return results

    def is_differentiable(self) -> bool:
        """This op is reported as non-differentiable."""
        return False

    def export_quant_weights(self):
        """The op has no quantized weights to export."""
        return {}

    def export_weights(self):
        """The op has no weights to export."""
        return {}

    def enforce_encoding(self, *args, **kwargs):
        """Set ``output_scale`` to ``input_scale * 2**shift`` to compensate for the shift."""
        self.output_scale = self.input_scale * 2**self.shift

    def create_weight_quant_element(self, bits=16):
        """Replace the ``factor`` element with an unsigned ``QuantElement`` of ``bits`` bits."""
        self.weight_lossy_elements = ShiftWeightsLossy(
            factor=QuantElement(bits=bits, signed=False, name=f"{self.full_name}/qe:factor"),
        )

    def create_hw_params(self, shift):
        """Set the shift amount and refresh ``output_scale`` accordingly."""
        self.shift = shift
        self.enforce_encoding()

    def export_independent_params(self):
        """Return ``{"shift": shift}`` with the shift stored as np.float32."""
        return {
            "shift": np.float32(self.shift),
        }

    def import_independent_params(self, params):
        """Load ``shift`` from ``params`` (stored as np.float32; output_scale is not refreshed here)."""
        self.shift = np.float32(params["shift"])

    def export_hw_params(self):
        """Return the shift as a uint8 keyed by ``<name>_shift``."""
        params = self.export_independent_params()
        return {
            f"{self.name}_shift": params["shift"].astype(np.uint8),
        }

    @property
    def bit_exact_supported(self) -> bool:
        """This layer supports bit exact emulation."""
        return True

    def _compute_output_shape(self, input_shape):
        """The shift is element-wise, so the output shape equals the input shape."""
        return input_shape
