from hailo_model_optimization.acceleras.hailo_layers.hailo_concat import HailoConcat
from hailo_model_optimization.acceleras.hailo_layers.hailo_crosscorrelation_dw import HailoCrossCorrelationDW
from hailo_model_optimization.acceleras.hailo_layers.hailo_element_wise_mult import HailoElementwiseMult
from hailo_model_optimization.acceleras.hailo_layers.hailo_feature_splitter import HailoFeatureSplitter
from hailo_model_optimization.acceleras.hailo_layers.hailo_width_splitter import HailoWidthSplitter
from hailo_model_optimization.acceleras.hailo_layers.op_factories import gen_acceleras_layers_from_hn
from hailo_model_optimization.acceleras.lossy_elements.quant_element import APUOutputQuantElement
from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerPrecisionConfig,
    LayerTranslationConfig,
)
from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerFeaturePolicy
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import SplitEWMultByBitSignificanceError
from hailo_model_optimization.algorithms.mixed_precision.create_mixed_precision import CreateMixedPrecision
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm
from hailo_model_optimization.algorithms.stats_collection.stats_collection import StatsCollector


class SplitEWMultByBitSignificance(OptimizationAlgorithm):
    """Split configured multiply layers into partial products by bit significance.

    For every layer listed in ``split_ew_mult_by_bit_significance.layers``, the graph is
    rewritten as follows. Only ``HailoElementwiseMult`` / ``HailoCrossCorrelationDW``
    layers keep a config; see ``_get_valid_layer_cfg``.

    * The input that reaches the layer through a ``precision_change`` layer (or through
      the ``ew_add_out_`` layer of an earlier split) is fed to a ``precision_splitter``.
    * Each splitter output is multiplied by the other input in its own copy of the
      original layer.
    * The partial products are summed back with ``ew_add`` layers, ending in
      ``ew_add_out_<name>``.
    * ``num_splits == 2`` produces low/high parts; ``num_splits == 3`` produces
      low/mid/high parts.

    After the rewrite, ``CreateMixedPrecision`` runs. Statistics are then collected and
    the splitters create their splits: first-level splitters first, then the
    second-level ones.

    NOTE(review): the splitter presumably separates low and high bits of a 16-bit tensor
    (hence the 8->16 precision changes and the fixed output-scale factors). Confirm
    against the ``precision_splitter`` layer implementation.
    """

    def __init__(self, model: HailoModel, model_config, logger_level, dataset, **kwargs) -> None:
        super().__init__(
            model,
            model_config,
            logger_level=logger_level,
            name="Split Element-Wise Mult By Bit Significance",
            **kwargs,
        )
        # Dataset given to the StatsCollector before the splitters create their splits.
        self._unbatched_dataset = dataset
        # Layers scheduled for removal once a single multiply has been rewired.
        self._layers_to_remove = []
        # Splitters fed by the original producer; their splits are created first.
        self._first_level_splits = []
        # Splitters fed by another splitter (3-way split only); created second.
        self._second_level_splits = []

    def _setup(self) -> None:
        """Reset the per-run bookkeeping lists."""
        self._layers_to_remove = []
        self._first_level_splits = []
        self._second_level_splits = []

    def should_skip_algo(self):
        """Skip the algorithm when no layer is configured for splitting."""
        return not self.get_algo_config().layers

    def get_algo_config(self):
        """Return the ``split_ew_mult_by_bit_significance`` section of the model config."""
        return self._model_config.split_ew_mult_by_bit_significance

    def _run_int(self) -> None:
        """Rewrite every configured layer, then apply mixed precision and create the splits."""
        algo_cfg = self.get_algo_config()
        # Snapshot the order first: the loop body adds and removes graph nodes.
        for layer_name in list(self._model.flow.toposort()):
            if layer_name not in algo_cfg.layers:
                continue

            num_splits = algo_cfg.layers[layer_name].num_splits
            ew_mult = self._model.layers[layer_name]
            self._split_single_ew_mult(ew_mult, num_splits)

        algo = CreateMixedPrecision(
            model=self._model,
            model_config=self._model_config,
            logger_level=self._logger_level,
            logger=self._logger,
        )
        algo.run()

        stats_collector = StatsCollector(
            self._model,
            self._model_config,
            self._logger_level,
            self._unbatched_dataset,
            logger=self._logger,
        )
        # Second-level splitters consume first-level outputs, so they go last.
        self._create_splits(stats_collector, self._first_level_splits)
        self._create_splits(stats_collector, self._second_level_splits)

    def finalize_global_cfg(self, algo_config) -> None:
        """Do nothing: this algorithm has no global configuration to finalize."""
        pass

    def _get_valid_layer_cfg(self, lname, cfg):
        """Return ``cfg`` for element-wise mult / depthwise cross-correlation layers, else ``{}``."""
        layer = self._model.layers[lname]
        if not isinstance(layer, (HailoElementwiseMult, HailoCrossCorrelationDW)):
            cfg = {}
        return cfg

    def _split_single_ew_mult(self, ew_mult, num_splits) -> None:
        """Replace ``ew_mult`` with its bit-significance split and rewire its consumers.

        Args:
            ew_mult: The multiply / cross-correlation layer to split.
            num_splits: 2 (low/high) or 3 (low/mid/high).

        Raises:
            SplitEWMultByBitSignificanceError: If the layer lacks the expected
                predecessors, or if ``num_splits`` is neither 2 nor 3.
        """
        other_pred, precision_change, ew_mult_pred = None, None, None
        for pred in self._model.flow.predecessors(ew_mult.full_name):
            pred_short_name = pred.split("/")[1]
            if pred_short_name.startswith(("precision_change", "ew_add_out_")):
                precision_change = self._model.layers[pred]
                if pred_short_name.startswith("precision_change"):
                    # Split the tensor before its precision change, i.e. at the change's producer.
                    ew_mult_pred = self._model.layers[next(iter(self._model.flow.predecessors(pred)))]
                else:
                    # Output of an earlier split (ew_add_out_*): split it directly.
                    ew_mult_pred = precision_change
            else:
                other_pred = self._model.layers[pred]

        if any(layer is None for layer in [other_pred, precision_change, ew_mult_pred]):
            msg = f"Invalid predecessors for layer {ew_mult.full_name}"
            raise SplitEWMultByBitSignificanceError(msg)

        if num_splits == 2:
            ew_add_out = self._split_to_two(ew_mult, ew_mult_pred, other_pred, precision_change)
        elif num_splits == 3:
            ew_add_out = self._split_to_three(ew_mult, ew_mult_pred, other_pred, precision_change)
        else:
            msg = f"Invalid num_splits={num_splits} for layer {ew_mult.full_name}"
            raise SplitEWMultByBitSignificanceError(msg)

        for succ_name in self._model.flow.successors_sorted(ew_mult.full_name):
            succ = self._model.layers[succ_name]
            if (
                succ_name.split("/")[1].startswith("precision_change")
                and len(list(self._model.flow.successors(succ_name))) == 1
            ):
                # A lone precision change after the multiply is dropped; connect past it.
                self._layers_to_remove.append(succ)
                precision_change_succ = next(iter(self._model.flow.successors(succ_name)))
                output_layer = self._model.layers[precision_change_succ]
                input_index = self._model.flow.get_edge_input_index(succ_name, precision_change_succ)
            else:
                output_layer = succ
                input_index = self._model.flow.get_edge_input_index(ew_mult.full_name, output_layer.full_name)
            output_index = self._model.flow.get_edge_output_index(ew_mult.full_name, succ_name)
            self._model.flow.add_edge(
                ew_add_out.full_name,
                output_layer.full_name,
                input_index=input_index,
                output_index=output_index,
            )

        self._layers_to_remove.append(ew_mult)
        # The input precision change is only removed if the multiply was its sole consumer.
        if len(list(self._model.flow.successors(precision_change.full_name))) == 1 and precision_change.name.startswith(
            "precision_change"
        ):
            self._layers_to_remove.append(precision_change)

        for layer in self._layers_to_remove:
            self._remove_layer(layer)

        self._layers_to_remove = []

    def _create_splits(self, stats_collector, split_layers) -> None:
        """Collect stats once, then let each splitter in ``split_layers`` create its splits.

        Each splitter gets a default translation config with
        ``activation_symmetric_range`` disabled. The config is registered in the
        model config and passed to ``create_splits``. Does nothing for an empty list.
        """
        if split_layers:
            stats_collector.run()
            for split_layer in split_layers:
                cfg = LayerTranslationConfig.get_default()
                cfg.activation_symmetric_range = LayerFeaturePolicy.disabled
                self._model_config.translation_config.layers[split_layer.full_name] = cfg
                split_layer.create_splits(translation_config=cfg)

    def _get_split_output_index(self, ew_mult, precision_change_pred, precision_change):
        """Return the output index of ``precision_change_pred`` that the new splitter reads.

        Normally this is the index on the ``precision_change_pred -> precision_change``
        edge. When the multiply is fed directly by an ``ew_add_out_`` layer of an earlier
        split, ``precision_change`` *is* ``precision_change_pred``. That edge would be a
        self-loop this class never creates, so the ``precision_change_pred -> ew_mult``
        edge is used instead.
        """
        consumer = ew_mult if precision_change is precision_change_pred else precision_change
        return self._model.flow.get_edge_output_index(precision_change_pred.full_name, consumer.full_name)

    def _split_to_two(self, ew_mult, precision_change_pred, other_pred, precision_change):
        """Build the low/high split of ``ew_mult`` and return its ``ew_add_out_`` layer.

        Args:
            ew_mult: Layer being split.
            precision_change_pred: Layer whose output is split (producer of the precision
                change, or an ``ew_add_out_`` layer).
            other_pred: The multiply's other input.
            precision_change: The layer through which ``ew_mult`` received the split input.
        """
        ew_mult_scope_name, ew_mult_name = ew_mult.full_name.split("/")
        ew_mult_hn = self._get_ew_mult_hn(ew_mult)
        layer_type = ew_mult_hn["type"]

        output_index = self._get_split_output_index(ew_mult, precision_change_pred, precision_change)
        precision_split_layer = self._add_precision_split(
            f"precision_split_{ew_mult_name}",
            [f"{layer_type}_low_{ew_mult_name}", f"{layer_type}_high_{ew_mult_name}"],
            precision_change_pred,
            output_indices=[output_index],
        )
        self._first_level_splits.append(precision_split_layer)

        weights = ew_mult.export_weights()

        ew_mult_low = self._add_ew_mult_low(
            other_pred,
            ew_mult_scope_name,
            f"{layer_type}_low_{ew_mult_name}",
            ew_mult_hn,
            precision_split_layer,
            weights=weights,
        )

        if isinstance(ew_mult, HailoCrossCorrelationDW):
            # Cross-correlation path: the high part passes through an extra activation
            # whose output scale is forced to 1/8 before the high multiply.
            mul_high_by_8 = self._add_precision_change_8_to_16(
                f"mul_high_by_8_{ew_mult_name}",
                [f"{layer_type}_high_{ew_mult_name}"],
                precision_split_layer,
                output_index=1,
                precision_mode="a8_w8_a8",
            )
            mul_high_by_8.forced_output_scale_scalar_dof = 1 / 8
            ew_mult_high_pred = mul_high_by_8
        else:
            ew_mult_high_pred = precision_split_layer

        ew_mult_high = self._add_ew_mult_high(
            other_pred, ew_mult_scope_name, ew_mult_name, ew_mult_hn, ew_mult_high_pred, weights=weights
        )

        self._set_output_external_element(precision_change_pred, precision_split_layer)

        # When the split input comes out of a feature/width splitter, the splitter's input
        # (and, for a concat, every concat input) also gets a shared external quant element.
        if isinstance(precision_change_pred, (HailoFeatureSplitter, HailoWidthSplitter)):
            pred = self._model.layers[next(iter(self._model.flow.predecessors(precision_change_pred.full_name)))]
            if isinstance(pred, HailoConcat):
                for pred_name in self._model.flow.predecessors(pred.full_name):
                    self._set_output_external_element(self._model.layers[pred_name], pred)
            self._set_output_external_element(pred, precision_change_pred)

        precision_mode = "a16_w16_a16" if isinstance(ew_mult, HailoCrossCorrelationDW) else "a8_w8_a16"

        return self._add_ew_add(
            f"ew_add_out_{ew_mult_name}",
            list(self._model.flow.successors_sorted(ew_mult.full_name)),
            [ew_mult_high, ew_mult_low],
            weights=weights,
            activation=ew_mult_hn["params"]["activation"],
            precision_mode=precision_mode,
        )

    def _split_to_three(self, ew_mult, precision_change_pred, other_pred, precision_change=None):
        """Build the low/mid/high split of ``ew_mult`` and return its ``ew_add_out_`` layer.

        A first splitter separates the low part from the rest. The rest goes through a
        precision change into a second splitter producing mid/high. High and mid are
        summed first; the low product (after its own precision change) is then added to
        form ``ew_add_out_<name>``.

        Args:
            ew_mult: Layer being split.
            precision_change_pred: Layer whose output is split.
            other_pred: The multiply's other input.
            precision_change: Optional layer through which ``ew_mult`` received the split
                input. If given, the first splitter reads the matching output index of
                ``precision_change_pred``. Otherwise it reads output 0.
        """
        ew_mult_scope_name, ew_mult_name = ew_mult.full_name.split("/")
        ew_mult_hn = self._get_ew_mult_hn(ew_mult)
        layer_type = ew_mult_hn["type"]

        output_indices = (
            None
            if precision_change is None
            else [self._get_split_output_index(ew_mult, precision_change_pred, precision_change)]
        )
        high_low_split_layer = self._add_precision_split(
            f"precision_split1_{ew_mult_name}",
            [f"{layer_type}_low_{ew_mult_name}", f"precision_change1_{ew_mult_name}"],
            precision_change_pred,
            output_indices=output_indices,
        )
        self._first_level_splits.append(high_low_split_layer)

        # One export is enough: ew_mult is not modified below, and the same dict is
        # shared by every new layer (as in _split_to_two).
        weights = ew_mult.export_weights()

        ew_mult_low = self._add_ew_mult_low(
            other_pred,
            ew_mult_scope_name,
            f"{layer_type}_low_{ew_mult_name}",
            ew_mult_hn,
            high_low_split_layer,
            weights=weights,
        )

        precision_change_before_high_mid_split = self._add_precision_change_8_to_16(
            f"precision_change1_{ew_mult_name}",
            [f"precision_split2_{ew_mult_name}"],
            high_low_split_layer,
            output_index=1,
        )

        high_mid_split_layer = self._add_precision_split(
            f"precision_split2_{ew_mult_name}",
            [f"{layer_type}_mid_{ew_mult_name}", f"{layer_type}_high_{ew_mult_name}"],
            precision_change_before_high_mid_split,
            output_indices=[],
        )
        self._second_level_splits.append(high_mid_split_layer)

        self._set_output_external_element(precision_change_before_high_mid_split, high_mid_split_layer)

        # The mid part is built like a low part, but fed by the second splitter.
        ew_mult_mid = self._add_ew_mult_low(
            other_pred,
            ew_mult_scope_name,
            f"{layer_type}_mid_{ew_mult_name}",
            ew_mult_hn,
            high_mid_split_layer,
            weights=weights,
        )

        ew_mult_high = self._add_ew_mult_high(
            other_pred, ew_mult_scope_name, ew_mult_name, ew_mult_hn, high_mid_split_layer, weights=weights
        )

        ew_add_high_mid = self._add_ew_add(
            f"ew_add_high_mid_{ew_mult_name}",
            [f"ew_add_out_{ew_mult_name}"],
            [ew_mult_high, ew_mult_mid],
        )

        precision_change_before_add = self._add_precision_change_8_to_16(
            f"precision_change2_{ew_mult_name}",
            [f"ew_add_out_{ew_mult_name}"],
            ew_mult_low,
        )

        return self._add_ew_add(
            f"ew_add_out_{ew_mult_name}",
            list(self._model.flow.successors_sorted(ew_mult.full_name)),
            [ew_add_high_mid, precision_change_before_add],
            precision_mode="a16_w16_a16",
            bias_mode="double_scale_initialization",
            weights=weights,
            activation=ew_mult_hn["params"]["activation"],
        )

    def _set_output_external_element(self, input_layer, output_layer, bits=12):
        """Share one ``APUOutputQuantElement(bits)`` between two layers.

        The element is set as ``input_layer``'s external output lossy element and as
        ``output_layer``'s external input lossy element.
        """
        lossy_element = APUOutputQuantElement(bits=bits)
        input_layer.output_lossy_element_external = lossy_element
        output_layer.input_lossy_element_external = lossy_element

    def _add_ew_mult_high(self, other_pred, ew_mult_scope_name, ew_mult_name, ew_mult_hn, split_layer, weights):
        """Add ``<type>_high_<name>``, a copy of the multiply fed by ``split_layer`` and ``other_pred``.

        Both input edges use output index 1. ``ew_mult`` layers get
        ``mock_kernel_values=[16, 16]`` and a forced output-scale DOF of 1; other types
        (the cross-correlation layer) get a forced output-scale DOF of 8.

        NOTE(review): the ``other_pred`` edge also uses output index 1, and in the
        cross-correlation path of ``_split_to_two`` ``split_layer`` is a single-output
        activation. Confirm how acceleras interprets ``output_index`` for single-output
        predecessors.
        """
        layer_type = ew_mult_hn["type"]
        ew_mult_high = self._add_generic_layer(
            ew_mult_hn,
            ew_mult_scope_name,
            f"{layer_type}_high_{ew_mult_name}",
            [split_layer, other_pred],
            output_indices=[1, 1],
            weights=weights,
        )
        if layer_type == "ew_mult":
            ew_mult_high.mock_kernel_values = [16, 16]
            ew_mult_high.forced_output_scale_scalar_dof = 1
        else:
            ew_mult_high.forced_output_scale_scalar_dof = 8
        self._add_precision_config(ew_mult_high, bias_mode="double_scale_initialization", precision_mode="a8_w8")
        return ew_mult_high

    def _add_ew_mult_low(self, other_pred, ew_mult_scope_name, layer_name, ew_mult_hn, split_layer, weights):
        """Add ``layer_name``, a copy of the multiply fed by output 0 of ``split_layer`` and ``other_pred``.

        ``ew_mult`` layers get ``mock_kernel_values=[2, 16]`` and a forced output-scale
        DOF of 2**6; other types (the cross-correlation layer) get 8. This is used for
        both the low and the mid parts.
        """
        layer_type = ew_mult_hn["type"]
        ew_mult_split = self._add_generic_layer(
            ew_mult_hn,
            ew_mult_scope_name,
            layer_name,
            [split_layer, other_pred],
            output_indices=[0, 0],
            weights=weights,
        )
        if layer_type == "ew_mult":
            ew_mult_split.mock_kernel_values = [2, 16]
            ew_mult_split.forced_output_scale_scalar_dof = 2**6
        else:
            ew_mult_split.forced_output_scale_scalar_dof = 8
        self._add_precision_config(ew_mult_split, bias_mode="double_scale_initialization", precision_mode="a8_w8")
        return ew_mult_split

    def _add_precision_split(self, layer_name, succ_names, pred, output_indices=None):
        """Add a two-output ``precision_splitter`` layer fed by ``pred``.

        Both outputs take the shape of ``pred``'s first output. The layer is configured
        with ``single_scale_decomposition`` / ``a16_w16``. A falsy ``output_indices``
        makes the edge from ``pred`` use output index 0.
        """
        scope_name = pred.full_name.split("/")[0]
        pred_hn_element = pred.to_hn()
        shape = pred_hn_element["output_shapes"][0]
        hn_element = {
            "type": "precision_splitter",
            "input": [pred.full_name],
            "output": [f"{scope_name}/{output_name}" for output_name in succ_names],
            "input_shapes": [shape],
            "output_shapes": [shape, shape],
        }

        new_layer = self._add_generic_layer(hn_element, scope_name, layer_name, [pred], output_indices=output_indices)
        self._add_precision_config(new_layer, bias_mode="single_scale_decomposition", precision_mode="a16_w16")
        return new_layer

    def _add_precision_change_8_to_16(
        self, layer_name, succ_names, pred, output_index=0, ltype="activation", precision_mode="a8_w8_a16"
    ):
        """Add a single-input linear layer (default type ``activation``) reading ``pred`` at ``output_index``.

        The layer is configured with ``single_scale_decomposition`` and
        ``precision_mode``. The default ``a8_w8_a16`` gives the 8->16 change the name
        refers to; callers may pass other modes.
        """
        scope_name = pred.full_name.split("/")[0]
        pred_hn_element = pred.to_hn()
        shape = pred_hn_element["output_shapes"][0]
        hn_element = {
            "type": ltype,
            "input": [pred.full_name],
            "output": [f"{scope_name}/{output_name}" for output_name in succ_names],
            "input_shapes": [shape],
            "output_shapes": [shape],
            "params": {"activation": "linear"},
        }
        new_layer = self._add_generic_layer(hn_element, scope_name, layer_name, [pred], output_indices=[output_index])
        self._add_precision_config(new_layer, bias_mode="single_scale_decomposition", precision_mode=precision_mode)
        return new_layer

    def _get_ew_mult_hn(self, ew_mult):
        """Return ``ew_mult``'s HN element with its ``quantization_params`` cleared."""
        hn_element = ew_mult.to_hn()
        hn_element["quantization_params"] = {}
        return hn_element

    def _add_ew_add(
        self,
        new_name,
        succ_names,
        preds,
        bias_mode="double_scale_initialization",
        precision_mode="a8_w8_a16",
        ltype="ew_add",
        output_indices=None,
        activation="linear",
        weights=None,
    ):
        """Add a two-input ``ltype`` layer summing ``preds[0]`` and ``preds[1]``.

        Input shapes, and one output shape per entry in ``succ_names``, are taken from
        ``preds[0]``'s first output shape. For ``ltype == "normalization"`` the
        ``elementwise_add`` param is set. ``output_indices`` defaults to ``[0, 0]``.
        """
        output_indices = [0, 0] if output_indices is None else output_indices
        scope_name = preds[0].full_name.split("/")[0]
        shape = preds[0].to_hn()["output_shapes"][0]
        hn_element = {
            "type": ltype,
            "input": [preds[0].full_name, preds[1].full_name],
            "output": [f"{scope_name}/{output_name}" for output_name in succ_names],
            "input_shapes": [shape, shape],
            "output_shapes": [shape] * len(succ_names),
            "compilation_params": {},
            "quantization_params": {},
            "params": {"activation": activation},
        }
        if ltype == "normalization":
            hn_element["params"]["elementwise_add"] = True

        new_layer = self._add_generic_layer(
            hn_element, scope_name, new_name, preds, output_indices=output_indices, weights=weights
        )
        self._add_precision_config(new_layer, bias_mode=bias_mode, precision_mode=precision_mode)
        return new_layer

    def _add_generic_layer(self, hn_element, scope_name, new_name, preds, weights=None, output_indices=None):
        """Create ``<scope_name>/<new_name>`` from ``hn_element`` and connect it after ``preds``.

        Weights (an empty dict when ``None``) are imported unless the HN type is
        ``normalization``.
        """
        layer_name = f"{scope_name}/{new_name}"
        new_layer = gen_acceleras_layers_from_hn(layer_name, hn_element, self.optimization_target)[layer_name]

        if weights is None:
            weights = {}
        if hn_element["type"] != "normalization":
            new_layer.import_weights(weights)

        self._add_preds(new_layer, preds, output_indices=output_indices)

        return new_layer

    def _add_preds(self, new_layer, predecessors, output_indices=None) -> None:
        """Register ``new_layer`` in the model and add one edge from each predecessor.

        Edge ``i`` gets ``input_index=i``. A falsy ``output_indices`` (``None`` or
        ``[]``) means output index 0 for every predecessor.
        """
        output_indices = output_indices if output_indices else [0] * len(predecessors)
        self._model.layers[new_layer.full_name] = new_layer
        node = new_layer.full_name
        self._model.flow.add_node(node)

        for i, predecessor in enumerate(predecessors):
            self._model.flow.add_edge(predecessor.full_name, node, input_index=i, output_index=output_indices[i])

    def _remove_layer(self, layer) -> None:
        """Drop ``layer`` from the layer map, the flow graph and every model config."""
        self._model.layers.pop(layer.full_name, None)
        self._model.flow.remove_node(layer.full_name)
        self._model_config.remove_layer_from_all_configs(layer.full_name)

    def _add_precision_config(
        self, new_layer, bias_mode="single_scale_decomposition", precision_mode="a16_w16"
    ) -> None:
        """Register a precision config (one quantization group) for ``new_layer`` and import it."""
        layer_name = new_layer.full_name
        cfg = LayerPrecisionConfig(precision_mode=precision_mode, bias_mode=bias_mode, quantization_groups=1)
        self._model_config.precision_config.layers[layer_name] = cfg
        new_layer.import_precision_config(cfg, self.optimization_target)