import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.concat_op import ConcatOp
from hailo_model_optimization.acceleras.atomic_ops.passthru_op import PassthruOp
from hailo_model_optimization.acceleras.atomic_ops.split_precision_op import SplitPrecisionHigh, SplitPrecisionLow
from hailo_model_optimization.acceleras.hailo_layers.hailo_depthwise import HailoDepthwise
from hailo_model_optimization.acceleras.hailo_layers.layer_flow import LayerFlow
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerPrecisionConfig,
    LayerWeightsClippingConfig,
)
from hailo_model_optimization.acceleras.statistics.statistics_base import BasicTypeTuple
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    BiasMode,
    ConcatAxis,
    DataPath,
    EquivClassification,
    LayerHandlerType,
    LayerType,
    OptimizationTarget,
    PrecisionMode,
)
from hailo_model_optimization.acceleras.utils.hn_npz_utils import LayerParams
from hailo_model_optimization.acceleras.utils.opt_utils import (
    get_decomposition_count_by_bias_mode,
    get_kernel_bits_and_sign_by_precision_mode,
)


class HailoPrecisionSplitSigned(HailoDepthwise):
    """
    Split a 15-bit layer input into two 8-bit parts (low and high) and concatenate them.

    The output scales are set channel-wise as interleaved pairs
    ``[scale, ratio * scale]`` (see ``enforce_io_encoding``), so consumers can sum
    the low and high parts back together, once per group.
    """

    SUPPORTED_PRECISION_MODE = {
        PrecisionMode.a16_w8_a8,
    }
    SUPPORTED_BIAS_MODE = {
        BiasMode.double_scale_initialization,
        BiasMode.double_scale_decomposition,
        BiasMode.single_scale_decomposition,
    }
    SUPPORTED_QUANTIZATION_GROUPS = False
    _hn_type = LayerType.PRECISION_SPLITTER_SIGNED
    # Bit width of each split part; also the default exponent of the high/low scale ratio.
    LOW_BITS = 8
    # Smallest allowed high/low output-scale ratio (the name's typo is kept for compatibility).
    MINIUM_SCALE_RATIO = 8

    def __init__(
        self,
        name: str,
        logger=None,
        groups: int = 1,
        **kwargs,
    ):
        """
        Build the atomic ops of the splitter layer.

        Args:
            name: layer name, also used as the prefix of the internal atomic op names.
            logger: optional logger forwarded to the atomic ops and the base layer.
            groups: number of groups; the concat op gets one group of size 1 per group.
            **kwargs: forwarded to ``HailoDepthwise.__init__``.
        """
        # These ops are created before the base constructor runs; presumably the base
        # constructor builds the flow via _build_flow, which references them - TODO confirm.
        self.input_op = PassthruOp(f"{name}/input_passthough_op", logger=logger)
        self.split_precision_low = SplitPrecisionLow(f"{name}/split_precision_low_op", logger=logger)
        self.split_precision_high = SplitPrecisionHigh(f"{name}/split_precision_high_op", logger=logger)
        self.concat_op = ConcatOp(
            f"{name}/concat_op",
            concat_elements=2,
            axis=ConcatAxis.features,
            group_sizes=[1] * groups,
            logger=logger,
        )

        super().__init__(name=name, kernel_size=[1, 1], logger=logger, **kwargs)
        # The 1x1 kernel is fixed to 1 and, like the bias, frozen (not trainable).
        self.conv_op.kernel = np.array([1])
        self.conv_op.trainable = False
        self.bias_add_op.trainable = False
        self.ratio = 2**self.LOW_BITS
        self.low_bits = self.LOW_BITS
        self.number_of_groups = groups

    @classmethod
    def get_default_precision_mode(cls):
        """Return the default precision mode (``a16_w8_a8``)."""
        return PrecisionMode.a16_w8_a8

    @classmethod
    def get_default_bias_mode(cls):
        """Return the default bias mode (``double_scale_initialization``)."""
        return BiasMode.double_scale_initialization

    def set_output_scales_ratio(self, ratio: int):
        """
        Set the ratio between the high-part and the low-part output scales.

        Raises:
            ValueError: if ``ratio`` is smaller than ``MINIUM_SCALE_RATIO``.
        """
        if ratio < self.MINIUM_SCALE_RATIO:
            raise ValueError(f"Ratio is too low, this will give a lossy multiplication - Ratio {ratio}")
        self.ratio = ratio

    def get_equalization_handler_type(self, predecessor_index=None):
        """Report this layer as transparent (non-source) for equalization."""
        return EquivClassification(LayerHandlerType.transparent, is_source=False)

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """Create the layer from an HN element, reading ``params.groups`` (default 1)."""
        layer = cls(name=lname, groups=hn_element.get("params", {}).get("groups", 1), logger=logger)
        layer.finalize_from_hn(hn_element)
        return layer

    def to_hn(self, out_degree=None):
        """Serialize to HN; ``params`` contains only ``groups``."""
        params = super().to_hn(out_degree=out_degree)
        # NOTE(review): this replaces any "params" dict produced by the base class
        # rather than merging into it - confirm that is intended for this layer type.
        params.update({"params": {"groups": self.number_of_groups}})
        return params

    def _build_flow(self) -> LayerFlow:
        """
        Wire the atomic ops into the layer graph.

        input -> input_op -> {split_low, split_high}. split_low also feeds split_high
        as its second input. Both splits are concatenated and then go through
        conv -> bias_add -> act -> output.
        """
        layer_flow = self._init_flow()
        in1 = layer_flow.add_input()
        out1 = layer_flow.add_output()

        layer_flow.add_edge(in1, self.input_op, DataPath.LAYER_IN)

        # Feed the full-precision input to both splitters.
        layer_flow.add_edge(self.input_op, self.split_precision_low, DataPath.LAYER_IN)
        layer_flow.add_edge(self.input_op, self.split_precision_high, DataPath.LAYER_IN, input_index=0)

        # The high splitter also consumes the low part (input 1).
        layer_flow.add_edge(
            self.split_precision_low, self.split_precision_high, DataPath.LAYER_SPLIT_INPUT, input_index=1
        )

        # Concatenate the splits: low first, then high.
        layer_flow.add_edge(self.split_precision_low, self.concat_op, DataPath.LAYER_SPLIT_INPUT, input_index=0)
        layer_flow.add_edge(self.split_precision_high, self.concat_op, DataPath.LAYER_SPLIT_INPUT, input_index=1)

        # Apply shifts
        layer_flow.add_edge(self.concat_op, self.conv_op, DataPath.LAYER_SPLIT_INPUT)

        # Conv part of the layer
        layer_flow.add_edge(self.conv_op, self.bias_add_op, DataPath.ACCUMULATOR)
        layer_flow.add_edge(self.bias_add_op, self.act_op, DataPath.ACCUMULATOR)
        layer_flow.add_edge(self.act_op, self.output_op, DataPath.LAYER_OUT_WEIGHTS)

        layer_flow.add_edge(self.output_op, out1, DataPath.LAYER_OUT_WEIGHTS)

        return layer_flow

    def create_quant_element_custom_behavior(
        self,
        precision_config: LayerPrecisionConfig,
        optimization_target: OptimizationTarget,
    ):
        """
        Create the quantization elements of the internal ops from ``precision_config``.

        The split-input data path uses ``self.low_bits``. Kernel bits and sign come from
        the precision mode, and the bias decomposition count comes from the bias mode.
        """
        bias_mode = precision_config.bias_mode
        precision_mode = precision_config.precision_mode
        quant_groups = precision_config.quantization_groups

        kernel_bits, signed = get_kernel_bits_and_sign_by_precision_mode(precision_mode)
        num_decomposition = get_decomposition_count_by_bias_mode(bias_mode)
        self.create_quant_element_by_data_path(DataPath.LAYER_SPLIT_INPUT, self.low_bits)
        self.conv_op.create_weight_quant_element(kernel_bits, signed)
        self.bias_add_op.create_weight_quant_element(kernel_bits, signed, num_decomposition)
        self.act_op.create_weight_quant_element(optimization_target)

        # set quantization groups
        self.conv_op.quantization_groups_num = quant_groups
        self.act_op.set_quantization_groups(quant_groups)

    def start_stats_collection(self, stats_cfg: tuple = BasicTypeTuple, output_hist=False, preact_hist=False):
        """Start stats collection on the layer and on both splitters (inputs and outputs)."""
        super().start_stats_collection(stats_cfg, output_hist, preact_hist)
        self.split_precision_low.start_stats_collection(stats_cfg=stats_cfg, collect_inputs=True, collect_output=True)
        self.split_precision_high.start_stats_collection(stats_cfg=stats_cfg, collect_inputs=True, collect_output=True)

    def create_splits(self, factor=None):
        """
        Create the splits with a fixed zero point.

        The zero point is chosen so that values are centered in both the high and the
        low 8-bit parts (see ``calculate_zp_scales``). ``factor`` is unused and kept
        for interface compatibility.
        """
        self.split_precision_low.trivial_split = False

        zp, input_scales = self.calculate_zp_scales()
        self.input_op.input_scales[0] = input_scales
        self.input_op.input_zero_points[0] = zp
        self._enforce_input_encoding()

    def calculate_zp_scales(self):
        """
        Compute the input zero point and per-group input scales.

        ``zp = 2**(bits - 1) + 2**(LOW_BITS - 1)`` so that the values are centered in
        each 8-bit segment. For each group, the scale is the larger of
        ``|min| / zp`` and ``max / (max_x - zp)``, with ``max_x = 2**bits - 1``. Under
        an affine ``q = x / scale + zp`` mapping (assumption - TODO confirm), this
        keeps both the group min and the group max inside ``[0, max_x]``.

        Returns:
            Tuple ``(zp, scales)``, where ``scales`` has one entry per group as returned
            by ``get_group_input_limvals``.
        """
        input_bits = self.get_input_lossy_elements()[0].bits
        zp = (2 ** (input_bits - 1)) + 2 ** (self.LOW_BITS - 1)
        max_x = (2**input_bits) - 1
        # Keep the limits strictly away from zero so the scales are never zero.
        eps = 1e-5

        min_vals, max_vals = self.get_group_input_limvals(self.number_of_groups)[0]
        min_based_scales = np.abs(np.minimum(min_vals, -1 * eps)) / zp
        max_based_scales = np.maximum(max_vals, eps) / (max_x - zp)

        scales = np.maximum(min_based_scales, max_based_scales)
        # NOTE(review): an `np.repeat(scales, channels / groups)` call used to sit here.
        # Its result was discarded, and a float repeat count makes np.repeat raise
        # TypeError, so it was removed. If per-channel scales were intended, use
        # `scales = np.repeat(scales, self.input_shapes[0][-1] // self.number_of_groups)`.
        return zp, scales

    def create_hw_params(
        self, weights_clipping: LayerWeightsClippingConfig, optimization_target: OptimizationTarget, hw_shifts=None
    ):
        """
        Create the hardware parameters of the conv, bias and activation ops.

        Enforces the I/O and input encodings first, then derives the conv parameters
        from the kernel scale, and finally finalizes the internal encoding.
        """
        self._enforce_output_encoding()
        self._enforce_input_encoding()

        max_final_accumulator_by_channel = 127 * self.output_scale

        kernel_scale_matrix_component = self.conv_op.calc_kernel_scale(
            self.conv_op.input_scales, self.act_op.output_scale, 0
        )

        self.conv_op.create_hw_params(
            max_final_accumulator_by_channel,
            weights_clipping,
            optimization_target,
            kernel_scale_matrix_component=kernel_scale_matrix_component,
            hw_shifts=hw_shifts,
            shift_calculate_buffer=0,
        )

        self.bias_add_op.pre_acc_shift = self.conv_op.pre_acc_shift

        # Don't nudge if the kernel scale was forced (kernel_scale_forced_to_save is True).
        nudging = not (self.kernel_scale_forced_to_save)
        self.act_op.create_hw_params(self.conv_op.accumulator_scale_candidate, optimization_target, nudging=nudging)
        self.enforce_internal_encoding()
        self._create_hw_params_finalize()
        self._has_hw_params = True

    def _enforce_input_encoding(self):
        """Propagate scales and zero points from the input op through the splitters and concat to the conv input."""
        self.input_op.enforce_encoding()
        self.split_precision_low.input_scales[0] = self.input_op.output_scales[0]
        self.split_precision_low.input_zero_points[0] = self.input_op.output_zero_points[0]
        self.split_precision_low.enforce_encoding()

        # High splitter input 0: the full-precision (15-bit) input.
        self.split_precision_high.input_scales[0] = self.input_op.output_scales[0]
        self.split_precision_high.input_zero_points[0] = self.input_op.output_zero_points[0]

        # High splitter input 1: the low part produced by split_precision_low.
        self.split_precision_high.input_scales[1] = self.split_precision_low.output_scales[0]
        self.split_precision_high.input_zero_points[1] = self.split_precision_low.output_zero_points[0]
        self.split_precision_high.enforce_encoding()

        # concat op: input 0 is the low part, input 1 is the high part.
        self.concat_op.input_scales[0] = self.split_precision_low.output_scales[0]
        self.concat_op.input_zero_points[0] = self.split_precision_low.output_zero_points[0]

        self.concat_op.input_scales[1] = self.split_precision_high.output_scales[0]
        self.concat_op.input_zero_points[1] = self.split_precision_high.output_zero_points[0]
        self.concat_op.enforce_encoding()

        self.conv_op.input_scales[0] = self.concat_op.output_scales[0]
        self.conv_op.input_zero_points[0] = self.concat_op.output_zero_points[0]

    @property
    def _trivial_scales(self):
        """
        Return True when the input scales are a non-scalar array that is not all ones.

        Despite the name, True means the scales are *non*-trivial (presumably they
        have already been set - TODO confirm).
        """
        return self.input_op.input_scales[0].shape != () and not all(self.input_op.input_scales[0] == 1)

    def enforce_internal_encoding(self, training=False, **kwargs):
        """
        Forward-path enforcement of the internal encodings.

        This runs only when ``_trivial_scales`` is True. It propagates the input encoding
        through the splitters and concat, then delegates to the base class (always with
        ``training=False``), and finally marks the low split as non-trivial.
        """
        if self._trivial_scales:
            self._enforce_input_encoding()
            super().enforce_internal_encoding(training=False, **kwargs)
            self.split_precision_low.trivial_split = False

    def enforce_io_encoding(self, training=False, **kwargs):
        """Set the output scales to interleaved ``[s, ratio * s]`` pairs, one pair per input scale ``s``."""
        scales = self.input_scales[0]
        self.set_output_scale(tf.reshape(tf.stack([scales, self.ratio * scales], axis=1), [-1]), 0)

    @property
    def consumer_input_scale(self):
        """Always False for this layer."""
        return False

    def import_weights(self, layer_params: LayerParams):
        """
        Import weights plus ``low_bits`` and ``scale_ratio`` from ``layer_params``.

        Missing keys fall back to ``LOW_BITS`` and ``2**LOW_BITS``. The base class
        import runs only when ``layer_params`` is non-empty.
        """
        if layer_params.keys():
            super().import_weights(layer_params)
        self.low_bits = layer_params.get("low_bits", self.LOW_BITS)
        self.split_precision_low.import_weights(self.low_bits)
        self.split_precision_high.import_weights(self.low_bits)
        self.ratio = layer_params.get("scale_ratio", 2**self.LOW_BITS)

    def _export_weights(self):
        """Export the base-class weights plus ``low_bits`` and ``scale_ratio``."""
        dict_params = super()._export_weights()
        dict_params.update(
            {
                "low_bits": self.low_bits,
                "scale_ratio": self.ratio,
            }
        )
        return dict_params

    def optimize_ratio(self):
        """
        Choose ``self.ratio`` (a power of two) from the collected input statistics.

        The exponent is set as high as the statistics allow, shifting the bits toward
        the MSB of the high part. It is then clamped to
        ``[log2(MINIUM_SCALE_RATIO), LOW_BITS]``, which is ``[3, 8]`` with the class
        defaults, so the ratio ends up in ``[8, 256]``.
        """
        stats = self.get_input_stats()[0]
        max_abs_input = np.maximum(
            np.abs(stats.min),
            np.abs(stats.max),
            dtype=np.float32,
        )

        value = np.ceil(np.log2(max_abs_input / self.input_scale)) - self.low_bits
        # Plus one because the high vector has 7 bits instead of 8 (Uint15).
        min_exponent = int(np.log2(self.MINIUM_SCALE_RATIO))
        new_ratio = 2 ** max(min(np.max(value) + 1, self.LOW_BITS), min_exponent)
        self.ratio = new_ratio
