import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.passthru_op import PassthruOp
from hailo_model_optimization.acceleras.atomic_ops.split_precision_op import (
    PrecisionSplitPixelOp,
    SplitPrecisionHigh,
    SplitPrecisionLow,
)
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_single_atomic import BaseHailoSingleAtomic
from hailo_model_optimization.acceleras.hailo_layers.layer_flow import LayerFlow
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerWeightsClippingConfig
from hailo_model_optimization.acceleras.statistics.statistics_base import BasicTypeTuple
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ZP_LOW_SPLIT_PRECISION_PIXEL,
    BiasMode,
    DataPath,
    EquivClassification,
    LayerHandlerType,
    LayerType,
    OptimizationTarget,
    PrecisionMode,
)
from hailo_model_optimization.acceleras.utils.hn_npz_utils import LayerParams


class HailoPrecisionSplit(BaseHailoLayer):
    """
    Splits a 16-bit layer into two 8-bit parts (low and high).

    The layer input goes through ``input_op`` and is then fed to two atomic ops:
    ``split_precision_low`` and ``split_precision_high``. The high op additionally
    receives the output of the low op as its second input. Each split result is routed
    through its own output passthrough op, and the layer flow exposes them as two
    outputs (the low part is added first, the high part second).
    """

    SUPPORTED_PRECISION_MODE = {PrecisionMode.a16_w16, PrecisionMode.a16_w16_a8}
    SUPPORTED_BIAS_MODE = {
        BiasMode.double_scale_initialization,
        BiasMode.double_scale_decomposition,
        BiasMode.single_scale_decomposition,
    }
    SUPPORTED_QUANTIZATION_GROUPS = False
    # Default value of ``low_bits``; used when the imported params carry no "low_bits" entry.
    LOW_BITS = 8
    _hn_type = LayerType.PRECISION_SPLITTER

    def __init__(
        self,
        name: str,
        logger=None,
        **kwargs,
    ):
        """
        Create the atomic ops of the layer and initialize the base layer.

        Args:
            name: layer name; also used as the prefix of every atomic op name.
            logger: optional logger forwarded to the atomic ops and to the base class.
            **kwargs: forwarded untouched to the base class constructor.
        """
        # The ops are created before super().__init__() is called.
        # NOTE(review): presumably the base constructor builds the flow (and therefore needs
        # these attributes to exist already) - keep this ordering.
        self.input_op = PassthruOp(f"{name}/input_passthough_op", logger=logger)
        self.split_precision_low = SplitPrecisionLow(f"{name}/split_precision_low_op", logger=logger)
        self.split_precision_high = SplitPrecisionHigh(f"{name}/split_precision_high_op", logger=logger)
        self.output_op1 = PassthruOp(f"{name}/output_passthough_op1", logger=logger)
        self.output_op2 = PassthruOp(f"{name}/output_passthough_op2", logger=logger)

        self.input_spec = tf.keras.layers.InputSpec(ndim=4)
        super().__init__(name=name, logger=logger, **kwargs)
        # Every atomic op starts as fully native; create_hw_params() turns this off.
        for op in self.atomic_ops:
            op.fully_native = True
        self.low_bits = self.LOW_BITS

    def get_equalization_handler_type(self, predecessor_index=None):
        """Report this layer as unsupported (and not a source) for equalization."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """Build a layer named ``lname`` and finalize it from the given hn element."""
        layer = cls(name=lname, logger=logger)
        layer.finalize_from_hn(hn_element)
        return layer

    def _build_flow(self) -> LayerFlow:
        """
        Build the internal flow: one input, two outputs.

        ``input_op`` feeds both split ops; the output of ``split_precision_low`` is also
        the second input (index 1) of ``split_precision_high``. The low result leaves
        through ``output_op1`` and the high result through ``output_op2``.
        """
        # A throw-away ``LayerFlow()`` used to be created here and immediately overwritten;
        # ``_init_flow()`` is the only flow object that is actually used.
        layer_flow = self._init_flow()
        in1 = layer_flow.add_input()
        out1 = layer_flow.add_output()
        out2 = layer_flow.add_output()

        layer_flow.add_edge(in1, self.input_op, DataPath.LAYER_IN)
        layer_flow.add_edge(self.input_op, self.split_precision_low, DataPath.LAYER_IN)
        layer_flow.add_edge(self.input_op, self.split_precision_high, DataPath.LAYER_IN, input_index=0)
        layer_flow.add_edge(self.split_precision_low, self.split_precision_high, DataPath.LAYER_OUT, input_index=1)
        layer_flow.add_edge(self.split_precision_low, self.output_op1, DataPath.LAYER_OUT)
        layer_flow.add_edge(self.split_precision_high, self.output_op2, DataPath.LAYER_OUT)
        layer_flow.add_edge(self.output_op1, out1, DataPath.LAYER_OUT)
        layer_flow.add_edge(self.output_op2, out2, DataPath.LAYER_OUT)

        return layer_flow

    def start_stats_collection(self, stats_cfg: tuple = BasicTypeTuple, output_hist=False, preact_hist=False):
        """
        Start statistics collection on the layer and on both split ops.

        On top of what the base class enables, both split ops are asked to collect
        statistics of their inputs and of their output.
        """
        super().start_stats_collection(stats_cfg, output_hist, preact_hist)
        self.split_precision_low.start_stats_collection(stats_cfg=stats_cfg, collect_inputs=True, collect_output=True)
        self.split_precision_high.start_stats_collection(stats_cfg=stats_cfg, collect_inputs=True, collect_output=True)

    def create_splits(self, factor=None, translation_config=None):
        """
        Create the input encoding on the low split op and propagate it to the high split op.

        Args:
            factor: forwarded to ``create_input_encoding_candidates`` of the low op.
            translation_config: forwarded to ``create_input_encoding_candidates`` of the low op.
        """
        self.split_precision_low.trivial_split = False
        self.split_precision_low.create_input_encoding_candidates(
            0,
            input_lossy_external=self.input_lossy_element_external,
            factor=factor,
            translation_config=translation_config,
        )

        self.split_precision_low.enforce_encoding()
        # The high op sees the same layer input (index 0) and the low op output (index 1).
        self.split_precision_high.input_scales[0] = self.split_precision_low.input_scale
        self.split_precision_high.input_scales[1] = self.split_precision_low.output_scale
        self.split_precision_high.enforce_encoding()

    def create_hw_params(
        self, weights_clipping: LayerWeightsClippingConfig, optimization_target: OptimizationTarget, hw_shifts=None
    ):
        """
        Mark the layer as having hw params.

        None of the arguments is used by this layer: the method only switches every atomic
        op out of fully-native mode and raises the ``_has_hw_params`` flag.
        """
        for op in self.atomic_ops:
            op.fully_native = False
        self._has_hw_params = True

    @property
    def _trivial_scales(self):
        """
        Whether the input scales of ``input_op`` are non-scalar and not all equal to 1.

        NOTE: despite the name, ``True`` means the scales are *not* trivial, i.e. there is
        a real encoding to enforce (see ``enforce_internal_encoding``).
        """
        # very hacky
        return self.input_op.input_scales[0].shape != () and not all(self.input_op.input_scales[0] == 1)

    def enforce_internal_encoding(self, training=False, **kwargs):
        """
        This is a forward path computation of the encoding enforcement.
        As we enforce that the output scale is equal to the input scale, we only need to sequentially enforce the
        encodings of the atomic ops in their natural order.

        Nothing is done while ``_trivial_scales`` evaluates to False. ``training`` and
        ``kwargs`` are accepted for interface compatibility and are not used here.
        """
        if self._trivial_scales:
            self.input_op.enforce_encoding()
            # Both split ops consume the output encoding of the input passthrough op.
            self.split_precision_low.input_scales[0] = self.input_op.output_scales[0]
            self.split_precision_high.input_scales[0] = self.input_op.output_scales[0]
            self.split_precision_low.input_zero_points[0] = self.input_op.output_zero_points[0]
            self.split_precision_low.enforce_encoding()
            self.split_precision_high.input_zero_points[0] = self.input_op.output_zero_points[0]

            # The second input of the high op is the output of the low op.
            self.split_precision_high.input_scales[1] = self.split_precision_low.output_scales[0]
            self.split_precision_high.input_zero_points[1] = self.split_precision_low.output_zero_points[0]
            self.split_precision_high.enforce_encoding()
            self.output_op1.input_scales[0] = self.split_precision_low.output_scales[0]
            self.output_op1.input_zero_points[0] = self.split_precision_low.output_zero_points[0]
            self.output_op1.forward_encoding()

            self.output_op2.input_scales[0] = self.split_precision_high.output_scales[0]
            self.output_op2.input_zero_points[0] = self.split_precision_high.output_zero_points[0]
            self.output_op2.forward_encoding()
            self.split_precision_low.trivial_split = False

    def fast_enforce_internal_encoding(self, **kwargs):
        """No-op for this layer."""
        pass

    def enforce_io_encoding(self, training=False, **kwargs):
        """Delegate to ``enforce_internal_encoding`` (called with its default arguments)."""
        self.enforce_internal_encoding()

    @property
    def consumer_input_scale(self):
        """Always ``False`` for this layer."""
        return False

    @property
    def homogeneous(self):
        """Always ``False`` for this layer."""
        return False

    def import_weights(self, layer_params: LayerParams):
        """
        Read ``low_bits`` from the layer params and hand it to both split ops.

        When the params have no ``"low_bits"`` entry, ``LOW_BITS`` is used.
        """
        self.low_bits = layer_params.get("low_bits", self.LOW_BITS)
        self.split_precision_low.import_weights(self.low_bits)
        self.split_precision_high.import_weights(self.low_bits)

    def _export_weights(self) -> dict:
        """Return the exportable params of the layer: only ``low_bits``."""
        dict_params = {"low_bits": self.low_bits}
        return dict_params

    @classmethod
    def get_default_precision_mode(cls):
        """Return the default precision mode of the layer (``a16_w16_a8``)."""
        return PrecisionMode.a16_w16_a8

    def optimize_ratio(self):
        """
        Compute a power-of-two ratio from the collected input statistics.

        The largest absolute input value (taken from the min/max stats of the first input)
        is divided by the input scale; the number of bits it needs, minus ``low_bits``, plus
        one, gives the exponent. The exponent is clamped to the range [3, 8].

        Returns:
            ``2 ** exponent`` with ``exponent`` as described above.
        """
        # This will Shift the bits to the MSB on the High value.
        stats = self.get_input_stats()[0]
        max_final_accumulator_by_channel = np.maximum(
            np.abs(stats.min),
            np.abs(stats.max),
            dtype=np.float32,
        )

        value = np.ceil(np.log2(max_final_accumulator_by_channel / self.input_scale)) - self.low_bits
        # Plus one is because the High vector had 7 bits instead of 8 (Uint15)
        new_ratio = 2 ** max(min(np.max(value) + 1, 8), 3)
        return new_ratio


class HailoPrecisionSplitPixels(BaseHailoSingleAtomic):
    """
    Single-atomic layer whose core op is a ``PrecisionSplitPixelOp``.

    All the encoding work is delegated to that op (reached through ``self.atomic_op``).
    """

    SUPPORTED_PRECISION_MODE = {PrecisionMode.a16_w16, PrecisionMode.a16_w16_a8}
    _hn_type = LayerType.PRECISION_SPLITTER

    def __init__(self, name: str, logger=None, **kwargs):
        """
        Create the core split op and initialize the base layer with it.

        Args:
            name: layer name; the core op is named ``"<name>/split_precision_op"``.
            logger: optional logger forwarded to the op and to the base class.
            **kwargs: forwarded untouched to the base class constructor.
        """
        core_op = PrecisionSplitPixelOp(f"{name}/split_precision_op", logger=logger)
        super().__init__(name=name, core_op=core_op, logger=logger, **kwargs)

    def get_equalization_handler_type(self, predecessor_index=None):
        """Report this layer as transparent (and not a source) for equalization."""
        handler_type = LayerHandlerType.transparent
        return EquivClassification(handler_type, is_source=False)

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """Build a layer named ``lname`` and finalize it from the given hn element."""
        instance = cls(name=lname, logger=logger)
        instance.finalize_from_hn(hn_element)
        return instance

    def create_splits(self, factor=None, translation_config=None):
        """
        Create the input encoding candidates on the core op and enforce its encoding.

        Args:
            factor: forwarded to ``create_input_encoding_candidates``.
            translation_config: forwarded to ``create_input_encoding_candidates``.
        """
        split_op = self.atomic_op
        split_op.trivial_split = False
        split_op.create_input_encoding_candidates(
            0,
            input_lossy_external=self.input_lossy_element_external,
            factor=factor,
            translation_config=translation_config,
            split_precision_zp=ZP_LOW_SPLIT_PRECISION_PIXEL,
        )
        split_op.enforce_encoding()

    @classmethod
    def get_default_precision_mode(cls):
        """Return the default precision mode of the layer (``a16_w16_a8``)."""
        return PrecisionMode.a16_w16_a8

    def _get_precision_mode_supported_in_hw(self, arch):
        """
        Return the precision modes supported for the given optimization target.

        The hardware targets (MERCURY, SAGE, PLUTO) support, by default, the same modes as
        the emulator, i.e. ``SUPPORTED_PRECISION_MODE``. Any other target gets an empty set.
        """
        known_targets = {
            OptimizationTarget.MERCURY,
            OptimizationTarget.SAGE,
            OptimizationTarget.PLUTO,
            OptimizationTarget.EMULATION,
        }
        if arch in known_targets:
            return self.SUPPORTED_PRECISION_MODE
        return set()
