from dataclasses import dataclass

import numpy as np

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp, BaseWeightLossyElements
from hailo_model_optimization.acceleras.lossy_elements.base_lossy_element import BaseLossyElement
from hailo_model_optimization.acceleras.lossy_elements.identity_element import IdentityElement
from hailo_model_optimization.acceleras.lossy_elements.quant_element import QuantElement
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    APU_MANTISSA_BITS,
    APU_OFFSET_BITS,
    APU_OFFSET_BITS_D,
)


@dataclass
class ActivationWeightsLossy(BaseWeightLossyElements):
    """Lossy elements applied to the weights of ``ActivationLinearOp``.

    ``ActivationLinearOp`` starts with identity elements here and replaces them with
    ``QuantElement`` instances in ``create_weight_quant_element``.
    """

    # Applied to the multiplicative mantissa before it is used in the HW simulation.
    mantissa: BaseLossyElement
    # Applied to the additive offset before it is used in the HW simulation.
    offset: BaseLossyElement


class ActivationLinearOp(BaseAtomicOp):
    """
    Emulates a linear activation function. Meaning APU with only linear function that is used only for scaling and shifting.

    In the native (float) path the op is the identity. In the HW-simulation path it
    applies::

        out = ((x * mantissa_q) * 2**-exponent + offset_q) * 2**-final_shift

    ``mantissa`` and ``offset`` go through lossy elements (identity until
    ``create_weight_quant_element`` installs quantizers). ``exponent`` and
    ``final_shift`` are used as-is.
    """

    num_inputs = 1
    num_outputs = 1

    def __init__(self, name: str, logger=None, fully_native=None, **kwargs):
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        # Identity lossy elements until create_weight_quant_element() swaps in quantizers.
        self.weight_lossy_elements = ActivationWeightsLossy(
            mantissa=IdentityElement(name=f"{self.full_name}/ie:mantissa"),
            offset=IdentityElement(name=f"{self.full_name}/ie:offset"),
        )
        self.mantissa = 1.0
        self.exponent = 0.0
        self.offset = 0.0
        self.final_shift = 0.0
        self.scale_factor = 1.0  # scale_factor = self.output_scale / self.input_scale

    def call_native(self, inputs, **kwargs):
        """Native (float) path: the linear activation is the identity."""
        return inputs

    def call_hw_sim(self, inputs, **kwargs):
        """Emulate the APU arithmetic on ``inputs[0]``.

        Args:
            inputs: sequence whose first element is the tensor/array to process.

        Returns:
            ``((inputs[0] * mantissa_q) * 2**-exponent + offset_q) * 2**-final_shift``
        """
        inp = inputs[0]
        post_mantissa = inp * self.mantissa_q
        # A right shift by `exponent` is a multiplication by 2**-exponent. This matches the
        # scale relation in enforce_encoding(). A float base avoids numpy's error on
        # integer ** negative-integer.
        post_exponent = post_mantissa * 2.0 ** (-self.exponent_q)
        post_offset = post_exponent + self.offset_q
        post_final_shift = post_offset * 2.0 ** (-self.final_shift_q)
        return post_final_shift

    def enforce_encoding(self, *args, **kwargs):
        """Derive the output scale and zero point from the input encoding.

        The output scale inverts the HW multiplier ``mantissa * 2**-(exponent + final_shift)``.
        The output zero point is the input zero point run through the HW simulation.
        """
        self.output_scale = self.input_scale / self.mantissa * 2 ** (self.exponent + self.final_shift)
        # call_hw_sim() reads inputs[0], so the zero point must be wrapped in a sequence.
        # Passing it bare would index into the zero point itself.
        self.output_zero_point = self.call_hw_sim([self.input_zero_point])

    def create_hw_params(self, *args, **kwargs):
        """Compute offset, exponent and mantissa from ``scale_factor`` and ``final_shift``.

        ``np.frexp`` splits ``2**final_shift / scale_factor`` into ``m * 2**e`` with
        ``m`` in [0.5, 1). The mantissa becomes ``m * 2**APU_MANTISSA_BITS`` and the
        exponent ``APU_MANTISSA_BITS - e``, so ``mantissa * 2**-exponent`` reproduces
        the original ratio.
        """
        # Offset is added before the final shift, so pre-scale it by 2**final_shift.
        self.offset = self.output_zero_point * 2**self.final_shift
        mantissa, exponent = np.frexp(2**self.final_shift / self.scale_factor)
        self.exponent = APU_MANTISSA_BITS - exponent
        self.mantissa = mantissa * 2**APU_MANTISSA_BITS

    def create_weight_quant_element(self, optimization_target, **kwargs):
        """Replace the identity lossy elements with quantizers for mantissa and offset.

        The offset width is ``APU_OFFSET_BITS`` for a (16-bit in, 8-bit out) APU mode,
        otherwise ``APU_OFFSET_BITS_D``.
        """
        apu_io_mode = (self.input_lossy_element.bits, self.output_lossy_element.bits)
        offset_bits = APU_OFFSET_BITS if apu_io_mode == (16, 8) else APU_OFFSET_BITS_D
        self.weight_lossy_elements = ActivationWeightsLossy(
            mantissa=QuantElement(
                signed=False,
                bits=APU_MANTISSA_BITS,
                wraparound=False,
                name=f"{self.full_name}/qe:mantissa",
            ),
            offset=QuantElement(signed=True, bits=offset_bits, wraparound=False, name=f"{self.full_name}/qe:offset"),
        )

    @property
    def mantissa_q(self):
        """Mantissa after its lossy element."""
        return self.weight_lossy_elements.mantissa(self.mantissa)

    @property
    def offset_q(self):
        """Offset after its lossy element."""
        return self.weight_lossy_elements.offset(self.offset)

    @property
    def exponent_q(self):
        """Exponent as used by the HW simulation (no lossy element is applied to it)."""
        return self.exponent

    @property
    def final_shift_q(self):
        """Final shift as used by the HW simulation (no lossy element is applied to it)."""
        return self.final_shift

    def export_independent_params(self):
        """Export the float parameters.

        ``scale_factor`` keeps its own dtype; the others are cast to float32.
        """
        return {
            "mantissa": np.array(self.mantissa, np.float32),
            "exponent": np.array(self.exponent, np.float32),
            "offset": np.array(self.offset, np.float32),
            "final_shift": np.array(self.final_shift, np.float32),
            "scale_factor": np.array(self.scale_factor),
        }

    def import_independent_params(self, params):
        """Load parameters produced by ``export_independent_params``."""
        self.mantissa = params["mantissa"]
        self.exponent = params["exponent"]
        self.offset = params["offset"]
        self.final_shift = params["final_shift"]
        self.scale_factor = params["scale_factor"]

    def export_hw_params(self):
        """Export the values consumed by hardware: quantized mantissa/offset and raw exponent.

        ``final_shift`` is not part of this export.
        """
        return {
            "mantissa_q": self.mantissa_q,
            "exponent": self.exponent,
            "offset": self.offset_q,
        }
