from dataclasses import dataclass

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp, BaseWeightLossyElements
from hailo_model_optimization.acceleras.lossy_elements.base_lossy_element import BaseLossyElement
from hailo_model_optimization.acceleras.lossy_elements.identity_element import IdentityElement
from hailo_model_optimization.acceleras.lossy_elements.quant_element import APUOutputQuantElement
from hailo_model_optimization.acceleras.utils.hn_npz_utils import LayerParams
from hailo_model_optimization.acceleras.utils.opt_utils import limvals_to_zp_scale


@dataclass
class InputWeightsLossy(BaseWeightLossyElements):
    """
    Container of the weight lossy elements used by ConstOp.

    It holds a single element, ``inputs``, which ConstOp applies to its
    const data after scaling it into the quantized domain.
    """

    # Lossy element applied to the const data: ConstOp sets it to an IdentityElement
    # on construction and to an APUOutputQuantElement in create_weight_quant_element.
    inputs: BaseLossyElement


class ConstOp(BaseAtomicOp):
    """
    This class emulates a constant-data operation.

    The values of the single input are ignored; only its batch size is used.
    The stored const data is tiled according to ``input_tiles[0]`` and then
    repeated along a new leading batch axis.
    """

    num_inputs = 1
    num_outputs = 1

    def __init__(self, name: str, input_tiles: list, logger=None, fully_native=None, **kwargs) -> None:
        """
        Args:
            name: name of the op, forwarded to the base class.
            input_tiles: list whose first entry holds the per-dimension tiling
                multiples applied to the const data (used by ``tf.tile``).
            logger: forwarded to the base class.
            fully_native: forwarded to the base class.
            **kwargs: forwarded to the base class.
        """
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        self._input_tiles = input_tiles
        # Until create_weight_quant_element is called the const data passes through an identity element.
        self.weight_lossy_elements = InputWeightsLossy(
            inputs=IdentityElement(name=f"{self.full_name}/ie:input"),
        )
        # Filled by import_weights.
        self.const_data = None

    def enforce_encoding(self):
        """
        This implementation does nothing.

        NOTE(review): presumably the output scale / zero point of this op are
        assigned elsewhere - confirm against the callers.
        """

    def create_weight_quant_element(self, input_bits: int = 8) -> None:
        """
        Replaces the lossy element of the const data with an APUOutputQuantElement.

        Args:
            input_bits: number of bits passed to the quant element.
        """
        self.weight_lossy_elements = InputWeightsLossy(
            inputs=APUOutputQuantElement(bits=input_bits, name=f"{self.full_name}/qe:const"),
        )

    def import_input_encoding(self, params, input_index):
        """This implementation does nothing; both arguments are ignored."""
        pass

    def import_weights(self, layer_params: LayerParams) -> None:
        """
        Loads the const data from the "const_data" entry of the layer params and casts it to float32.
        """
        self.const_data = tf.cast(layer_params["const_data"], tf.float32)

    @property
    def const_data_q(self):
        """
        The const data in the quantized domain.

        Computed as ``const_data / output_scale + output_zero_point`` and then
        passed through the const-data lossy element.
        """
        # slices the output scale to match the const data shape if needed
        # (a scalar output scale is used as is)
        output_scale = (
            self.output_scale[: self.const_data.shape[-1]] if len(self.output_scale.shape) != 0 else self.output_scale
        )
        const_data_q = self.const_data / output_scale + self.output_zero_point
        return self.weight_lossy_elements.inputs(const_data_q)

    def export_weights(self) -> dict:
        """Returns the native const data as a numpy array under the "const_data" key."""
        return {"const_data": self.const_data.numpy()}

    def export_hw_params(self) -> dict:
        """Returns the quantized const data converted to int16 under the "const_data" key."""
        return {"const_data": np.array(self.const_data_q.numpy(), dtype=np.int16)}

    def create_input_encoding_candidates(self, input_index, input_lossy_external=None, translation_config=None):
        """
        Derives the zero point and scale of the const data from its min / max values.

        The range is forced to cover zero. The results are stored in
        ``input_zero_points`` and ``input_scales`` at ``input_index``.

        Args:
            input_index: index of the input whose encoding is set.
            input_lossy_external: not used by this implementation.
            translation_config: not used by this implementation.
        """
        limvals = [np.min(self.const_data), np.max(self.const_data)]
        bit_reducer = self.weight_lossy_elements.inputs
        # the third returned value is not needed here
        zp, scale, _ = limvals_to_zp_scale(
            limvals,
            bit_reducer,
            name="const",
            logger=None,
            force_range_to_cover_zero=True,
        )
        self.input_zero_points[input_index] = zp
        # store the scale at the same index as the zero point (it was hard-coded to 0),
        # so export_input_encoding reads a matching pair
        self.input_scales[input_index] = scale

    def export_quant_weights(self) -> dict:
        """Returns the quantized const data as a numpy array under the "const_data" key."""
        return {"const_data": self.const_data_q.numpy()}

    def export_input_encoding(self, input_index) -> dict:
        """
        Returns the scale, zero point and bit width of the given input as float32 numpy arrays.

        The bit width is taken from the const-data lossy element.
        """
        return {
            f"input_scale:{input_index}": np.array(self.input_scales[input_index], np.float32),
            f"input_zero_point:{input_index}": np.array(self.input_zero_points[input_index], np.float32),
            f"input_bits:{input_index}": np.array(self.weight_lossy_elements.inputs.bits, np.float32),
        }

    def _general_call(self, inputs, const_data):
        """
        Shared implementation of the native and the hw-sim calls.

        Args:
            inputs: the op inputs; only the size of the first axis of ``inputs[0]`` is used.
            const_data: the const data to emit (native or quantized).

        Returns:
            The tiled const data with a new leading axis, repeated once per batch element.
        """
        # tiles the const data according to the tiling ratios
        const_data = tf.tile(const_data, self._input_tiles[0])
        # expands to the batch size
        return tf.repeat(tf.expand_dims(const_data, axis=0), tf.shape(inputs[0])[0], axis=0)

    def call_native(self, inputs, **kwargs):
        """Emits the native (float) const data, tiled and expanded to the batch size."""
        return self._general_call(inputs, self.const_data)

    def call_hw_sim(self, inputs, **kwargs):
        """Emits the quantized const data, tiled and expanded to the batch size."""
        return self._general_call(inputs, self.const_data_q)

    def _encode_inputs(self, inputs):
        """
        The input in const op is ignored, but the batch size is used.
        The inputs are therefore returned unchanged.
        """
        return inputs

    @property
    def bit_exact_supported(self) -> bool:
        """This layer supports bit exact emulation."""
        return True

    def _compute_output_shape(self, input_shape):
        """
        Returns the output shape: the first entry of the input shape followed by
        each const data dimension multiplied by its tiling ratio.
        """
        return (
            input_shape[0],
            *[int(dim * ratio) for dim, ratio in zip(self.const_data.shape, self._input_tiles[0])],
        )
