import copy

import numpy as np

from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_none_nn_core_layer import BaseHailoNonNNCoreLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_const import HailoConst
from hailo_model_optimization.acceleras.hailo_layers.hailo_conv import HailoConv
from hailo_model_optimization.acceleras.hailo_layers.hailo_conv_add import HailoConvAdd
from hailo_model_optimization.acceleras.hailo_layers.hailo_dense import HailoDense
from hailo_model_optimization.acceleras.hailo_layers.hailo_fused_bbox_decoder import HailoFusedBboxDecoder
from hailo_model_optimization.acceleras.hailo_layers.hailo_layer_normalization import HailoLayerNormalization
from hailo_model_optimization.acceleras.hailo_layers.hailo_nms import HailoNMS
from hailo_model_optimization.acceleras.hailo_layers.hailo_postprocess import HailoPostprocess
from hailo_model_optimization.acceleras.hailo_layers.hailo_standalone_activation import HailoStandaloneActivation
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerPrecisionConfig
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_model import PrecisionConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    DEFAULT_OPTIMIZATION_TARGET,
    BiasMode,
    ExplicitPrecisionModes,
    PrecisionMode,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import (
    AccelerasImplementationError,
    AccelerasUnsupportedError,
)
from hailo_model_optimization.acceleras.utils.opt_utils import (
    get_input_bits_by_precision_mode,
    get_output_bits_by_precision_mode,
    get_output_predecessor_precision_mode_by_bits,
)
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm

# Lookup table read by CreateMixedPrecision.get_explicit_mode: the key is an implicit
# (input, weights) precision mode paired with the required output bit-width, the value is
# the matching explicit (input, weights, output) precision mode.
# Output bit-widths 15 and 16 both resolve to the "_a16" variant of the mode.
explicit_dict = {
    (implicit_mode, output_bits): (mode_out8 if output_bits == 8 else mode_out16)
    for implicit_mode, mode_out8, mode_out16 in (
        (PrecisionMode.a8_w8, PrecisionMode.a8_w8_a8, PrecisionMode.a8_w8_a16),
        (PrecisionMode.a8_w4, PrecisionMode.a8_w4_a8, PrecisionMode.a8_w4_a16),
        (PrecisionMode.a16_w16, PrecisionMode.a16_w16_a8, PrecisionMode.a16_w16_a16),
        (PrecisionMode.a16_w8, PrecisionMode.a16_w8_a8, PrecisionMode.a16_w8_a16),
    )
    for output_bits in (8, 15, 16)
}


class CreateMixedPrecision(OptimizationAlgorithm):
    """
    Assigns precision mode to every layer. Precision mode is passed to the compiler as a triplet (input_bit_width, weight_bit_width, output_bit_width).
    Determined by a traversal on the model graph and the precision modes provided by the alls.
    """

    # Layer classes eligible for the automatic 4-bit weights assignment. _is_supported_layer
    # matches them by exact type, so subclasses of these classes are not considered supported.
    SUPPORTED_AUTO4BIT = {HailoConv, HailoConvAdd, HailoDense}

    def __init__(self, model, model_config, logger_level: int, logger=None, for_infer=False):
        """
        Initialize the algorithm state and compute the output shapes of the model.

        Args:
            model: the model whose layers are assigned a precision mode.
            model_config: the model configuration; its ``precision_config`` is the algorithm config.
            logger_level: logging level, forwarded to the base class.
            logger: optional logger, forwarded to the base class.
            for_infer: when True, ``_run_int`` skips the auto-precision and the HW support steps.

        """
        super().__init__(
            model=model,
            model_config=model_config,
            name="Mixed Precision",
            logger_level=logger_level,
            logger=logger,
        )
        self._for_infer = for_infer
        self._generated_statistics = {}
        # Start counting from the number of precision_change layers that already exist in the model.
        self._shortcut_added_index = sum(1 for lname in self._model.layers.keys() if "/precision_change" in lname)
        input_shapes = [(None,) + shape for shape in self._model.get_input_shapes()]
        self._model.compute_output_shape(input_shapes)

    def get_algo_config(self):
        return self._model_config.precision_config

    def log_config(self):
        pass

    def _setup(self):
        cfg = self.get_algo_config()
        for lname, layer in self._model.layers.items():
            # TODO: maybe create empty config if not in dict and fill it?
            if lname not in cfg.layers:
                cfg.layers[lname] = layer.get_default_precision_config()
            elif lname in cfg.layers:
                cfg.layers[lname].fill_default_config(layer)

    def _run_int(self):
        if not self._for_infer:
            if self._model_config.compression_params.auto_4bit_weights_ratio > 0:
                self.greedy_auto4bit()
            elif self._model_config.compression_params.auto_16bit_weights_ratio > 0:
                self.auto16bit()
            self.fix_precision_config_hw_support_and_partial_verify()
        algo_cfg = self.get_algo_config()
        algo_cfg = self._propagate_16bits_config(algo_cfg)
        algo_cfg = self._propagate_precision_config(algo_cfg)
        self._model.import_config(self._model_config)

    def _switch_specific_layers_to_16bits(self, algo_cfg: PrecisionConfig):
        """
        Switch every HailoNMS layer and its direct neighbours to 16 bits precision modes.
        Nothing is changed unless the model contains at least one HailoFusedBboxDecoder layer.

        Args:
            algo_cfg: the precision config, modified in place.

        Returns:
            The same ``algo_cfg`` object.

        """
        # The same order serves both the decoder lookup and the NMS pass below.
        reversed_order = list(reversed(list(self._model.flow.toposort())))
        contain_hailo_fused_bbox_decoder = any(
            isinstance(self._model.layers[lname], HailoFusedBboxDecoder) for lname in reversed_order
        )
        if not contain_hailo_fused_bbox_decoder:
            return algo_cfg

        def _set_layer_and_neighbours_to_16bits(lname: str):
            """
            An auxiliary function that converts the lname layer and its successors to
            a16_w16_a16. A predecessor whose input is already 16 bits is switched to
            a16_w16_a16 as well, any other predecessor is switched to a8_w8_a16.
            """
            algo_cfg.layers[lname].precision_mode = PrecisionMode.a16_w16_a16
            for suc in self._model.flow.successors_sorted(lname):
                algo_cfg.layers[suc].precision_mode = PrecisionMode.a16_w16_a16
            for pred in self._model.flow.predecessors_sorted(lname):
                if self._is_input_16bits(algo_cfg.layers[pred].precision_mode):
                    # Changes a16_w16_a8 to a16_w16_a16, too.
                    # The predecessor itself is updated here (not a successor of lname).
                    algo_cfg.layers[pred].precision_mode = PrecisionMode.a16_w16_a16
                else:
                    algo_cfg.layers[pred].precision_mode = PrecisionMode.a8_w8_a16

        for lname in reversed_order:
            if isinstance(self._model.layers[lname], HailoNMS):
                _set_layer_and_neighbours_to_16bits(lname)

        return algo_cfg

    def _propagate_precision_config(self, algo_cfg: PrecisionConfig):
        """
        Set the explicit (input, weights, output) precision mode of every layer.

        Layers are visited in topological order; layers that already hold an explicit
        precision mode are left untouched. The output bits of a layer are taken from the
        input bits of its successors (the maximum, when they differ). A layer without
        successors takes them from the output bits of its single predecessor.
        A precision_change shortcut layer is added when the successors disagree on their
        input bits, or when the layer does not support the resulting explicit mode.

        Args:
            algo_cfg: the precision config, modified in place.

        Returns:
            The same ``algo_cfg`` object.

        Raises:
            AccelerasUnsupportedError: when a layer without successors has more than one
                predecessor, or when get_explicit_mode has no explicit mode for the layer.

        """
        for lname in self._model.flow.toposort():
            layer_precision_mode = algo_cfg.layers[lname].precision_mode
            if layer_precision_mode in ExplicitPrecisionModes:
                continue
            layer = self._model.layers[lname]
            successors = self._model.flow.successors_sorted(lname)
            suc_precision_mode = [algo_cfg.layers[suc].precision_mode for suc in successors]
            suc_input_bits = [get_input_bits_by_precision_mode(mode) for mode in suc_precision_mode]
            # Take care of the output layer with respect to the previous node.
            if len(suc_input_bits) == 0:
                # Get the input bits by the precision mode of the previous layer
                predecessors = self._model.flow.predecessors_sorted(lname)
                # Keep this for the case that an output layer has more than one predecessors (currently unsupported)
                if len(predecessors) > 1:
                    # The error used to be constructed without being raised, so the
                    # unsupported case silently continued with the first predecessor.
                    raise AccelerasUnsupportedError(
                        f"Output layer {lname} does not support multi predecessors, "
                        + f"but {len(predecessors)} predecessor layers are given.",
                    )
                pred_precision_mode = algo_cfg.layers[predecessors[0]].precision_mode
                pred_output_bits = get_output_bits_by_precision_mode(pred_precision_mode)
                desired_precision_mode = get_output_predecessor_precision_mode_by_bits(pred_output_bits)
                explicit_mode = self.get_explicit_mode(desired_precision_mode, pred_output_bits, lname)
            else:
                explicit_mode = self.get_explicit_mode(layer_precision_mode, np.max(suc_input_bits), lname)

            # add shortcut if the input precision of the successors are not matched
            if any(bits != suc_input_bits[0] for bits in suc_input_bits):
                if explicit_mode in layer.SUPPORTED_PRECISION_MODE:
                    # The layer keeps the wide output; the 8 bits successors are fed via the shortcut.
                    targets = [suc for suc, bits in zip(successors, suc_input_bits) if bits == 8]
                    self._add_shortcut_layer(lname, targets, PrecisionMode.a16_w16_a8)
                else:
                    # The shortcut takes over the unsupported mode, and the layer falls back
                    # to an output width that equals its own input width.
                    targets = [suc for suc, bits in zip(successors, suc_input_bits) if bits == 15]
                    self._add_shortcut_layer(lname, targets, explicit_mode)
                    layer_output_bits = get_input_bits_by_precision_mode(layer_precision_mode)
                    explicit_mode = self.get_explicit_mode(layer_precision_mode, layer_output_bits, lname)
            # add shortcut if the precision transition is not supported
            elif not isinstance(layer, BaseHailoNonNNCoreLayer):
                if explicit_mode not in layer.SUPPORTED_PRECISION_MODE:
                    self._add_shortcut_layer(lname, list(successors), explicit_mode)
                    layer_output_bits = get_input_bits_by_precision_mode(layer_precision_mode)
                    explicit_mode = self.get_explicit_mode(layer_precision_mode, layer_output_bits, lname)

            algo_cfg.layers[lname].precision_mode = explicit_mode
        return algo_cfg

    def _add_shortcut_layer(self, source, targets, explicit_mode):
        """
        Add linear standalone activation layers ("precision_change") on edges that leave source.

        When the source layer has a single output, one shortcut is placed on the edges to all
        the targets. Otherwise a separate shortcut is created for every target. Each new layer
        is registered in the precision config with the given explicit mode.

        Args:
            source: name of the layer the shortcut is fed from.
            targets: a target layer name, or a list of target layer names.
            explicit_mode: the precision mode assigned to the new layer(s).

        Returns:
            The name of the last shortcut layer that was created.

        """
        target_list = targets if isinstance(targets, list) else [targets]
        source_layer = self._model.layers[source]
        shape = list(source_layer.output_shapes[0])
        shape[0] = -1
        hn_element = {
            "type": "activation",
            "input": source,
            "output": target_list,
            "input_shapes": [shape],
            "output_shapes": [shape],
            "params": {"activation": "linear"},
        }
        # add layer to model
        scope_name, layer_name = source.split("/")
        block_name, _ = self.get_block_and_layer_names(layer_name)
        num_outputs = self._model.layers[source].num_outputs
        target_groups = [target_list] if num_outputs == 1 else [[target] for target in target_list]
        for group_index, group in enumerate(target_groups):
            # The per-output index is appended only for sources with several outputs.
            index_suffix = f"_{group_index}" if num_outputs > 1 else ""
            shortcut_name = f"{scope_name}/{block_name}precision_change{self._shortcut_added_index}{index_suffix}"
            shortcut_layer = HailoStandaloneActivation.from_hn(lname=shortcut_name, hn_element=hn_element)
            self._model.add_layer(shortcut_layer, [(source, target) for target in group])
            # add layer to configuration
            self._model_config.precision_config.layers[shortcut_name] = LayerPrecisionConfig(
                precision_mode=explicit_mode,
                bias_mode=BiasMode.single_scale_decomposition,
                quantization_groups=1,
            )
        self._shortcut_added_index += 1

        return shortcut_name

    def get_explicit_mode(self, current_mode, suc_input_bits, lname):
        """
        Resolve the explicit precision mode of a layer.

        Args:
            current_mode: the current (implicit) precision mode of the layer.
            suc_input_bits: the number of bits the output of the layer should have.
            lname: the layer name.

        Returns:
            The explicit precision mode. For a HailoConst layer it depends on the output
            bits only, for any other layer it is looked up in ``explicit_dict``.

        Raises:
            AccelerasUnsupportedError: when there is no explicit mode for the given combination.

        """
        # handle special case of const layer
        if isinstance(self._model.layers[lname], HailoConst):
            if suc_input_bits == 15:
                return PrecisionMode.a16_w16_a16
            if suc_input_bits == 8:
                return PrecisionMode.a8_w8_a8
            raise AccelerasUnsupportedError(f"not supported output bits {suc_input_bits} in {lname}")
        # handle rest of the layers
        key = (current_mode, suc_input_bits)
        if key not in explicit_dict:
            raise AccelerasUnsupportedError(
                f"precision modes {current_mode} with output bits {suc_input_bits} in {lname}",
            )
        return explicit_dict[key]

    def _propagate_16bits_config(self, algo_cfg: PrecisionConfig):
        """
        propagate precision config via non arithmetic layers that can't change the precision

        The layers are scanned in reversed topological order, repeatedly, until a full pass
        changes nothing (at most 100 passes). HailoPostprocess layers and output nodes may be
        switched to 16 bits according to their predecessors, and precision transparent
        predecessors of a layer with 16 bits input get a matching 16 bits mode.

        Args:
            algo_cfg: the precision config, modified in place.

        Returns:
            The same ``algo_cfg`` object.

        """
        reversed_order = list(reversed(list(self._model.flow.toposort())))
        for _ in range(100):  # using for to prevent infinite loop
            was_changed = False
            for lname in reversed_order:
                precision_mode = algo_cfg.layers[lname].precision_mode
                # set the output to be 16 bits if the predecessor is 16 bits
                # A change is accumulated rather than assigned, so a False result here can't
                # erase a change made earlier in the same pass and stop the loop too early.
                if isinstance(self._model.layers[lname], HailoPostprocess):
                    if self._propagate_16bit_from_pred(lname, algo_cfg, precision_mode):
                        was_changed = True
                if lname in self._model.flow.output_nodes:
                    if self._propagate_16bit_from_pred(lname, algo_cfg, precision_mode):
                        was_changed = True
                # set transparent layers precision mode
                # precision_mode was read before the calls above; a layer that was just
                # switched is handled in the next pass.
                if not self._is_input_16bits(precision_mode):
                    continue
                for pred in self._model.flow.predecessors_sorted(lname):
                    if self._model.layers[pred].is_precision_transparent:
                        pred_target_mode = self._get_pred_16bits_mode(precision_mode)
                        if algo_cfg.layers[pred].precision_mode != pred_target_mode:
                            algo_cfg.layers[pred].precision_mode = pred_target_mode
                            was_changed = True
            if not was_changed:
                break
        return algo_cfg

    def _propagate_16bit_from_pred(self, lname, algo_cfg, precision_mode):
        """
        Switch the layer to a16_w16 when any of its predecessors has a 16 bits input mode
        while the given precision mode of the layer does not.
        Nothing is changed when one of the predecessors is a HailoLayerNormalization.

        Args:
            lname: the layer name.
            algo_cfg: the precision config, modified in place.
            precision_mode: the precision mode of the layer, as read by the caller.

        Returns:
            True if the precision mode of the layer was changed, False otherwise.

        """
        layer_is_16bits = self._is_input_16bits(precision_mode)
        pred_names = self._model.flow.predecessors_sorted(lname)
        if any([isinstance(self._model.layers[pred], HailoLayerNormalization) for pred in pred_names]):
            return False
        pred_modes = [algo_cfg.layers[pred].precision_mode for pred in pred_names]
        has_16bits_pred = any([self._is_input_16bits(mode) for mode in pred_modes])
        if layer_is_16bits or not has_16bits_pred:
            return False
        algo_cfg.layers[lname].precision_mode = PrecisionMode.a16_w16
        return True

    def _is_input_16bits(self, precision_mode):
        """Return True when the precision mode is one of the 16 bits input modes listed below."""
        # NOTE(review): the implicit a16_w8 mode is a key of explicit_dict but is not listed
        # here - confirm that this is intentional.
        modes_with_16bits_input = (
            PrecisionMode.a16_w16_a16,
            PrecisionMode.a16_w16_a8,
            PrecisionMode.a16_w16,
            PrecisionMode.a16_w8_a8,
            PrecisionMode.a16_w8_a16,
            PrecisionMode.a16_w4_a16,
            PrecisionMode.a16_w4_a8,
        )
        return precision_mode in modes_with_16bits_input

    def _get_pred_16bits_mode(self, precision_mode):
        """
        Get the precision mode for a precision transparent predecessor of a 16 bits layer.

        Args:
            precision_mode: the precision mode of the layer that follows the predecessor.

        Returns:
            a16_w16 for an a16_w16 layer, a16_w16_a16 for any other mode with a 16 bits input
            (according to _is_input_16bits), and None for a mode without a 16 bits input.

        """
        if precision_mode == PrecisionMode.a16_w16:
            return precision_mode
        # Covers the a16_w8_* and a16_w4_* modes as well. They used to fall through and
        # return None, which _propagate_16bits_config then wrote into the layer config.
        if self._is_input_16bits(precision_mode):
            return PrecisionMode.a16_w16_a16
        return None

    def should_skip_algo(self):
        return False

    def export_statistics(self):
        return self._generated_statistics

    def _get_valid_layer_cfg(self, lname, cfg: LayerPrecisionConfig):
        """
        Get a layer config without the fields that the layer does not support in HW.

        A BaseHailoNonNNCoreLayer gets its config back as is. For any other layer a deep copy
        is filtered: bias_mode, precision_mode and quantization_groups are removed when the
        layer reports them as unsupported for the configured target (quantization_groups is
        removed when it is None, too). The filtered config is then verified by
        _verify_nms_precision_in_yolox_model.

        Args:
            lname: the layer name.
            cfg: the precision config of the layer; it is not modified.

        Returns:
            The given config for a BaseHailoNonNNCoreLayer, otherwise the filtered copy.

        """
        layer = self._model.layers[lname]
        if isinstance(layer, BaseHailoNonNNCoreLayer):
            return cfg
        arch = self.get_algo_config().target
        valid_cfg = copy.deepcopy(cfg)
        if valid_cfg.get("bias_mode") not in layer._get_bias_mode_supported_in_hw(arch):
            valid_cfg.pop("bias_mode", None)
        if valid_cfg.get("precision_mode") not in layer._get_precision_mode_supported_in_hw(arch):
            valid_cfg.pop("precision_mode", None)
        qgroups = valid_cfg.get("quantization_groups")
        if qgroups is None or not layer.is_quantization_groups_supported_in_hw(qgroups, arch):
            valid_cfg.pop("quantization_groups", None)
        self._verify_nms_precision_in_yolox_model(lname, valid_cfg)
        return valid_cfg

    def _verify_nms_precision_in_yolox_model(self, lname: str, cfg: LayerPrecisionConfig):
        """
        This function verifies that a user did not erroneously change the nms layer in the
            zip_yolox_hailo_coco_nms_fcn_hailo model to work at an unsupported precision mode.
            This function will be replaced in the future by a generic behavior for the nms
            (SDK-43564).

        The check applies to a HailoNMS layer of a network that contains a
        HailoFusedBboxDecoder layer. Its precision mode has to be a16_w16, a16_w16_a16, or unset.

        Args:
            lname (str): layer name.
            cfg (LayerPrecisionConfig): the precision config of the layer.

        Raises:
            AccelerasImplementationError: when the checked layer has any other precision mode.

        """
        if not isinstance(self._model.layers[lname], HailoNMS):
            return
        # Look for HailoFusedBboxDecoder anywhere in the network
        is_fused_nms_model = any(
            isinstance(self._model.layers[layer_name], HailoFusedBboxDecoder)
            for layer_name in reversed(list(self._model.flow.toposort()))
        )
        if not is_fused_nms_model:
            return
        precision_mode = cfg.get("precision_mode")
        allowed_modes = (None, PrecisionMode.a16_w16, PrecisionMode.a16_w16_a16)
        if precision_mode not in allowed_modes:
            raise AccelerasImplementationError(
                f"{lname} must operate at 16 bits, got {precision_mode}.",
            )
        return

    def finalize_global_cfg(self, algo_config):
        compress_cfg = self._model_config.compression_params
        if compress_cfg.auto_4bit_weights_ratio > 0 and compress_cfg.auto_16bit_weights_ratio > 0:
            raise RuntimeError("auto4bit and auto16bit are mutually exclusive. should've been validated in config")
        if compress_cfg.auto_16bit_weights_ratio not in {0, 1}:
            raise RuntimeError("auto16bit ratio can either be 0 or 1. should've been validated in config")

    def _filter_empty_layer_cfg(self, cfg):
        """
        Remove from cfg every LayerPrecisionConfig key whose value in cfg is None.
        The config is modified in place and returned.
        """
        empty_keys = [key for key in LayerPrecisionConfig.keys() if cfg.get(key, None) is None]
        for key in empty_keys:
            cfg.pop(key, None)
        return cfg

    # region auto4bit

    def greedy_auto4bit(self):
        skipped_layers = self._get_skipped_layers()
        layers_by_weight = self._get_layers_weights()
        total_weight = sum(weight for (_, weight) in layers_by_weight)
        compression_goal = total_weight * self._model_config.compression_params.auto_4bit_weights_ratio
        weights_in_4bit = self._get_weights_with_n_bits(skipped_layers, bits=4)
        layers_to_compress = []
        for layer, weight in layers_by_weight:
            if compression_goal < weights_in_4bit:
                break
            if layer in skipped_layers:
                continue
            weights_in_4bit += weight
            layers_to_compress.append(layer)
        ratio_4bit = weights_in_4bit / total_weight
        cfg = self.get_algo_config()
        for layer in layers_to_compress:
            current_mode = cfg.layers[layer].precision_mode
            prec_elements = current_mode.value.split("_")
            prec_elements[1] = "w4"
            cfg.layers[layer].precision_mode = PrecisionMode("_".join(prec_elements))
            layer_weight = self._get_layer_weight(layer)
            self._logger.info(f"Assigning 4bit weights to layer {layer} with {(layer_weight / 1e3):.2f}k parameters")
        self._logger.info(f"Ratio of weights in 4bit is {ratio_4bit:.2f}")

    def _get_layers_weights(self):
        weight_per_layer = [
            (layer, self._get_layer_weight(layer)) for layer in self._model.layers if self._is_supported_layer(layer)
        ]
        weight_per_layer.sort(key=lambda x: x[1], reverse=True)
        return weight_per_layer

    def _get_skipped_layers(self):
        """
        Get the set of layers that should be skipped by auto precision algorithm.
        The skipped layers are:
            the first & last weighted layers in the model
            layers with explicit config from the model script
        """
        first_layers = self._search_layers(self._model.flow.input_nodes, self._is_supported_layer)
        last_layers = self._search_layers(self._model.flow.output_nodes, self._is_supported_layer, reverse=True)
        layers_with_explicit_cfg = self._layers_with_explicit_cfg()
        skipped_layers = set(first_layers) | set(last_layers) | set(layers_with_explicit_cfg)
        return skipped_layers

    def _search_layers(self, start_layers, stop_cond: callable, reverse=False):
        """
        Search layers that satisfy the stop condition using a BFS search

        Args:
            start_layers: the entry points of the bfs search
            stop_cond: a boolean function that checks if a layer matches the search
            reverse: boolean, if true the BFS scans the graph backwards

        Return:
            list of layers that satisfied the stop condition

        """
        bfs_queue = list(start_layers)
        cond_layers = list()
        bfs_handled = set()
        model_flow = self._model.flow
        while len(bfs_queue) != 0:
            curr_layer = bfs_queue.pop(0)
            if curr_layer in bfs_handled:
                continue
            bfs_handled.add(curr_layer)
            if stop_cond(curr_layer):
                cond_layers.append(curr_layer)
                continue
            if not reverse:
                curr_extension = model_flow.successors_sorted(curr_layer)
            else:
                curr_extension = model_flow.predecessors_sorted(curr_layer)
            bfs_queue.extend(curr_extension)
        return cond_layers

    def _get_weights_with_n_bits(self, layers, bits):
        weights_in_4bit = 0
        cfg = self.get_algo_config()
        for layer in layers:
            prec_mode = cfg.layers[layer].precision_mode.value
            layer_bits = int(prec_mode.split("_")[1][1:])
            if layer_bits == bits and self._is_supported_layer(layer):
                weights_in_4bit += self._get_layer_weight(layer)
        return weights_in_4bit

    def _layers_with_explicit_cfg(self):
        prec_cfg_by_layer = self._model_config.precision_config.layers
        layers_with_explicit_cfg = []
        for lname, layer_cfg in prec_cfg_by_layer.items():
            if layer_cfg.meta is not None and layer_cfg.meta.get("precision_mode"):
                layers_with_explicit_cfg.append(lname)
        return layers_with_explicit_cfg

    def _get_layer_weight(self, lname):
        return np.prod(self._model.layers[lname].kernel.shape)

    def _is_supported_layer(self, lname):
        return type(self._model.layers[lname]) in self.SUPPORTED_AUTO4BIT

    # endregion

    # region auto16bit

    def auto16bit(self):
        """Set the precision mode of every layer in the model to a16_w16."""
        layers_cfg = self.get_algo_config().layers
        for lname in self._model.layers:
            layers_cfg[lname].precision_mode = PrecisionMode.a16_w16

    # endregion

    def fix_precision_config_hw_support_and_partial_verify(self):
        """
        Verify that the precision config of every BaseHailoLayer is supported by the HW, and
        try to fix the bias mode of unsupported layers with 16 bit or 4 bit weights.

        Raises:
            AccelerasImplementationError: when one or more layers remain unsupported. The
                message lists, per layer, its config and the source of every config key
                (the command found in the meta, or the default value).

        """
        cfg = self.get_algo_config()
        # The bias mode fix is chosen by the weights element of the precision mode.
        bias_mode_fixers = {
            "w16": self._fix_16bit_kernel_bias_mode,
            "w4": self._fix_4bit_kernel_bias_mode,
        }
        unsupported_config = []
        for lname, layer in self._model.layers.items():
            if not isinstance(layer, BaseHailoLayer):
                continue
            lcfg = cfg.layers[lname]
            if layer.is_supported_by_hw(cfg.target, lcfg):
                continue
            fixer = bias_mode_fixers.get(lcfg.precision_mode.value.split("_")[1])
            if fixer is None or not fixer(lcfg, lname):
                unsupported_config.append((lname, lcfg))
        if not unsupported_config:
            return
        msg = ["Precision config is not supported in the following cases: "]
        for lname, lcfg in unsupported_config:
            msg.append(f"    Layer: {lname}, Config: {lcfg.raw_dict()}")
            for key in lcfg.keys():
                from_command = lcfg.meta is not None and lcfg.meta.get(key) is not None
                source = lcfg.meta[key].command if from_command else "default value"
                msg.append(f"        {key} - {source}")
        raise AccelerasImplementationError("\n".join(msg))

    def _fix_16bit_kernel_bias_mode(self, layer_cfg: LayerPrecisionConfig, lname):
        """
        Try to make a layer with 16 bit weights supported by the HW by toggling its bias mode
        between double_scale_initialization and single_scale_decomposition.
        A bias mode that was set explicitly for the layer (non glob command) is never changed.

        Args:
            layer_cfg: the precision config of the layer; the bias mode is modified in place
                when the fix succeeds, and left as it was otherwise.
            lname: the layer name.

        Returns:
            True if the config is supported by the HW after the fix, False otherwise.

        """
        meta = layer_cfg.meta
        has_explicit_bias = (
            meta is not None and meta.get("bias_mode") is not None and not meta["bias_mode"].is_glob
        )
        if has_explicit_bias:
            return False
        original_bias_mode = layer_cfg.bias_mode
        if original_bias_mode == BiasMode.double_scale_initialization:
            layer_cfg.bias_mode = BiasMode.single_scale_decomposition
        else:
            layer_cfg.bias_mode = BiasMode.double_scale_initialization
        # Re-check against the configured target, the same one the caller found unsupported.
        target = self.get_algo_config().target
        if self._model.layers[lname].is_supported_by_hw(target, layer_cfg):
            return True
        # The fix did not help: restore the bias mode so the reported config is the original.
        layer_cfg.bias_mode = original_bias_mode
        return False

    def _fix_4bit_kernel_bias_mode(self, layer_cfg: LayerPrecisionConfig, lname):
        """
        Try to make a layer with 4 bit weights supported by the HW by switching its bias mode
        to double_scale_initialization.
        A bias mode that was set explicitly for the layer (non glob command) is never changed.

        Args:
            layer_cfg: the precision config of the layer; the bias mode is modified in place
                when the fix succeeds, and left as it was otherwise.
            lname: the layer name.

        Returns:
            True if the config is supported by the HW after the fix, False otherwise.

        """
        meta = layer_cfg.meta
        has_explicit_bias = (
            meta is not None and meta.get("bias_mode") is not None and not meta["bias_mode"].is_glob
        )
        if has_explicit_bias:
            return False
        original_bias_mode = layer_cfg.bias_mode
        layer_cfg.bias_mode = BiasMode.double_scale_initialization
        # Re-check against the configured target, the same one the caller found unsupported.
        target = self.get_algo_config().target
        if self._model.layers[lname].is_supported_by_hw(target, layer_cfg):
            return True
        # The fix did not help: restore the bias mode so the reported config is the original.
        layer_cfg.bias_mode = original_bias_mode
        return False
