import numpy as np

from hailo_model_optimization.acceleras.atomic_ops.const_op import ConstOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_single_atomic import BaseHailoSingleAtomic
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerPrecisionConfig,
    LayerWeightsClippingConfig,
)
from hailo_model_optimization.acceleras.statistics.statistics_base import BasicTypeTuple
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    BiasMode,
    EquivClassification,
    LayerHandlerType,
    LayerType,
    OptimizationTarget,
    PrecisionMode,
)
from hailo_model_optimization.acceleras.utils.opt_utils import get_kernel_bits_and_sign_by_precision_mode


class HailoConst(BaseHailoSingleAtomic):
    """
    Single-op layer whose output is a constant.

    The layer wraps one ``ConstOp`` (named ``"<name>/const_op"``) as its core op.
    It has no input ops (``iterate_input_ops`` yields nothing). Statistics are
    collected on its output only. Only output-side quantization parameters are
    exported.

    Args:
        name: Layer name. Also used as the prefix of the wrapped op's name.
        input_tiles: Forwarded to ``ConstOp``. ``from_hn`` defaults it to
            ``[[1, 1, 1]]`` when the HN element does not specify it.
        logger: Logger forwarded to both the op and the base layer. May be ``None``.
        **kwargs: Forwarded to ``BaseHailoSingleAtomic.__init__``.

    Examples:
        >>> layer = HailoConst(name="const1", input_tiles=[[1, 1, 1]], logger=None)
        >>> layer = HailoConst.from_hn("const1", hn_element)
    """

    SUPPORTED_PRECISION_MODE = {
        PrecisionMode.a8_w8,
        PrecisionMode.a16_w16,
        PrecisionMode.a8_w8_a8,
        PrecisionMode.a8_w8_a16,
        PrecisionMode.a16_w16_a16,
    }
    SUPPORTED_BIAS_MODE = {
        BiasMode.double_scale_initialization,
        BiasMode.single_scale_decomposition,
        BiasMode.double_scale_decomposition,
    }
    SUPPORTED_QUANTIZATION_GROUPS = False
    _hn_type = LayerType.CONST_INPUT
    OP_NAME = "const_op"

    def __init__(self, name, input_tiles, logger, **kwargs):
        op = ConstOp(f"{name}/{self.OP_NAME}", input_tiles, logger=logger)
        # Set before the base initializer runs.
        # NOTE(review): presumably so the base class can observe it during init; confirm.
        self.is_const_input = True
        super().__init__(name=name, core_op=op, logger=logger, **kwargs)
        self.encoding_const = False

    def create_hw_params(
        self, weights_clipping: LayerWeightsClippingConfig, optimization_target: OptimizationTarget, hw_shifts=None
    ):
        """No-op: this layer creates no hardware params."""

    def get_equalization_handler_type(self, predecessor_index=None):
        """Return an ``unexpected`` equalization handler, flagged as a source."""
        return EquivClassification(LayerHandlerType.unexpected, is_source=True)

    def get_quarot_handler_type(self, predecessor_index=None):
        """Return an ``unexpected`` QuaRot handler, flagged as *not* a source."""
        # NOTE(review): is_source differs from get_equalization_handler_type; confirm intentional.
        return EquivClassification(LayerHandlerType.unexpected, is_source=False)

    def enforce_io_encoding(self, **kwargs):
        """No-op: no I/O encoding constraints are enforced for a const layer."""

    def enforce_internal_encoding(self, training=False, **kwargs):
        """No-op: no internal encoding constraints are enforced for a const layer."""

    def validate_shape(self, input_data):
        """No-op: ``input_data`` is not validated for a const layer."""

    def start_stats_collection(self, stats_cfg: tuple = BasicTypeTuple, output_hist=False, preact_hist=False):
        """Start collecting output statistics on the wrapped op.

        Input statistics are never collected because the layer has no inputs.
        ``output_hist`` and ``preact_hist`` are accepted but ignored here.
        """
        self.atomic_op.start_stats_collection(stats_cfg=stats_cfg, collect_inputs=False, collect_output=True)

    def create_quant_element_custom_behavior(
        self, precision_config: LayerPrecisionConfig, optimization_target: OptimizationTarget
    ):
        """Create the const op's weight quant element.

        The bit-width is the kernel bit-width implied by the configured precision mode.
        """
        precision_mode = precision_config.precision_mode
        # Element [0] of the (bits, sign) result is the kernel bit-width.
        kernel_bits = get_kernel_bits_and_sign_by_precision_mode(precision_mode)[0]
        self.atomic_op.create_weight_quant_element(kernel_bits)

    def _export_layers_qp_defaults(self) -> dict:
        """Export the layer's default quantization params.

        A const layer has no input quantization params, so only output-side
        values are exported:

        - the entries returned by ``get_qp_out()``;
        - the limvals, cast to float32;
        - ``zero_point_out``, which is the first element of ``qp_out`` as int32.
        """
        params = {}
        params.update(self.get_qp_out())
        params.update({key: np.array(val, dtype=np.float32) for key, val in self.get_limvals().items()})
        params["zero_point_out"] = np.int32(params["qp_out"][0])
        return params

    def _export_layer_params(self) -> dict:
        """Export layer params.

        A Hailo constant has no input scales, so only output scales, output zero
        points and the negative-slopes correction factor are exported.
        """
        # NOTE(review): presumably expands scalar zero points to match output_scales; confirm.
        output_zero_points = self._update_scalar_to_scale(self.output_zero_points, self.output_scales)
        output_scales = self._change_list_to_np_array(self.output_scales)
        output_zero_points = self._change_list_to_np_array(output_zero_points)

        layer_params = {
            "layer_params/output_scales": output_scales,
            "layer_params/output_zero_points": output_zero_points,
            "layer_params/negative_slopes_correction_factor": np.array(
                self._negative_slope_exponent_fix_shift,
                np.float32,
            ),
        }
        return layer_params

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """Build the layer from an HN element and finalize it from that element.

        ``input_tiles`` is read from ``hn_element["params"]["input_tiles"]``
        when present. Otherwise it defaults to ``[[1, 1, 1]]``.
        """
        input_tiles = (
            hn_element["params"]["input_tiles"]
            if "params" in hn_element and "input_tiles" in hn_element["params"]
            else [[1] * 3]
        )
        layer = cls(name=lname, input_tiles=input_tiles, logger=logger)
        layer.finalize_from_hn(hn_element)
        return layer

    def iterate_input_ops(self):
        """Yield nothing: a const layer has no input ops."""
        yield from ()

    def get_precision_mode(self):
        """Derive the precision mode from the output lossy elements' bit-width.

        The bit-width ``b`` is the smallest distinct value across the output
        lossy elements, because ``np.unique`` returns sorted values. A width of
        15 is reported as 16. The result is ``PrecisionMode("a{b}_w{b}_a{b}")``.
        """
        # NOTE(review): raises IndexError when there are no output lossy elements.
        # If several distinct widths exist, only the smallest is used.
        out_bits = np.unique([out_lossy_elem.bits for out_lossy_elem in self.get_output_lossy_elements()])
        out_bits = out_bits[0]
        if out_bits == 15:
            out_bits = 16
        precision_mode = PrecisionMode(f"a{out_bits}_w{out_bits}_a{out_bits}")
        return precision_mode

    def _get_hn_input_shapes(self):
        """Return the ``input_shapes`` entry of the stored HN element."""
        return self._hn_element["input_shapes"]
