from hailo_model_optimization.acceleras.hailo_layers.hailo_depth_to_space import HailoDepthToSpace
from hailo_model_optimization.acceleras.hailo_layers.op_factories import gen_acceleras_layers_from_hn
from hailo_model_optimization.acceleras.lossy_elements.quant_element import APUOutputQuantElement
from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerPrecisionConfig,
    LayerTranslationConfig,
)
from hailo_model_optimization.acceleras.utils.acceleras_definitions import FeaturePolicy, LayerFeaturePolicy
from hailo_model_optimization.algorithms.mixed_precision.create_mixed_precision import CreateMixedPrecision
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm
from hailo_model_optimization.algorithms.stats_collection.stats_collection import StatsCollector


class Output16BitAs8Bit(OptimizationAlgorithm):
    """
    Re-express every 16-bit model output as an 8-bit output built from an LSB part and an MSB part.

    For each output layer ``O`` whose first sorted predecessor is ``P``, the edge ``P -> O`` is replaced by::

        P -> precision_split_<P>                          (a16_w16)
               out 0 -> precision_change_lsb_<P>          (linear activation, a8_w8_a16)
                          -> precision_split_lsb_<P>      (a16_w16)
                               out 0 -> concat_out_<P>    (concat input 0)
                               out 1 -> ew_add_msb_<P>    (ew_add input 1)
               out 1 -> ew_add_msb_<P>                    (ew_add input 0, a8_w8_a8)
        ew_add_msb_<P> -> concat_out_<P>                  (concat input 1)
        concat_out_<P> -> O                               (O switched to a8_w8)

    The concat doubles the last axis of ``P``'s output shape. Afterwards, a depth-to-space layer that directly
    feeds a first-level split is moved behind the split, ``CreateMixedPrecision`` is run, and the split layers'
    ``create_splits`` is invoked after running a ``StatsCollector``.
    """

    def __init__(self, model: HailoModel, model_config, logger_level, dataset, **kwargs) -> None:
        """
        Change each 16-bit output to an 8-bit output whose last axis is doubled (LSB half followed by MSB half).

        Args:
            model: The model that is rewritten in place.
            model_config: Optimization config; precision/translation configs of the new layers are written into it.
            logger_level: Forwarded to the base class and to the sub-algorithms created by this one.
            dataset: Unbatched dataset handed to the ``StatsCollector`` used before creating the splits.
            **kwargs: Forwarded to ``OptimizationAlgorithm.__init__``.
        """
        super().__init__(
            model,
            model_config,
            logger_level=logger_level,
            name="Output 16bits as 8bits",
            **kwargs,
        )
        self._unbatched_dataset = dataset
        # Split layers created by _output_16it_as_8bit, grouped by depth in the inserted subgraph.
        self._first_level_splits = []
        self._second_level_splits = []

    def _setup(self) -> None:
        """Reset the split-layer bookkeeping before a run."""
        self._first_level_splits = []
        self._second_level_splits = []

    def should_skip_algo(self):
        """Return True when the global ``output_16bit_as_8bit`` policy is ``FeaturePolicy.disabled``."""
        return self.get_algo_config().output_16bit_as_8bit == FeaturePolicy.disabled

    def get_algo_config(self):
        """This algorithm is configured through the global section of the model config."""
        return self._model_config.globals

    def _run_int(self) -> None:
        """Rewrite the outputs, move depth-to-space layers, assign mixed precision and create the splits."""
        self._output_16it_as_8bit()
        self._split_d2s_before_precision_split()

        algo = CreateMixedPrecision(
            model=self._model,
            model_config=self._model_config,
            logger_level=self._logger_level,
            logger=self._logger,
        )
        algo.run()

        stats_collector = StatsCollector(
            self._model,
            self._model_config,
            self._logger_level,
            self._unbatched_dataset,
            logger=self._logger,
        )
        # First-level splits are created (and stats collected) before the second level is processed.
        self._create_splits(stats_collector, self._first_level_splits)
        self._create_splits(stats_collector, self._second_level_splits)

    def finalize_global_cfg(self, algo_config) -> None:
        """Nothing to finalize for this algorithm."""
        pass

    def _create_splits(self, stats_collector, split_layers) -> None:
        """
        Run the stats collector and call ``create_splits`` on every layer in ``split_layers``.

        Each split layer's translation config gets ``activation_symmetric_range`` disabled (a default config is
        created and stored when the layer has none). Nothing happens, not even stats collection, when
        ``split_layers`` is empty.
        """
        if split_layers:
            stats_collector.run()
            for split_layer in split_layers:
                cfg = self._model_config.translation_config.layers.get(
                    split_layer.full_name, LayerTranslationConfig.get_default()
                )
                cfg.activation_symmetric_range = LayerFeaturePolicy.disabled
                self._model_config.translation_config.layers[split_layer.full_name] = cfg
                split_layer.create_splits(translation_config=cfg)

    def _set_output_external_element(self, input_layer, output_layer, bits=12):
        """
        Share one ``APUOutputQuantElement(bits=bits)`` between ``input_layer``'s output and ``output_layer``'s input.

        Also disables ``activation_symmetric_range`` in ``output_layer``'s translation config.
        """
        lossy_element = APUOutputQuantElement(bits=bits)
        input_layer.output_lossy_element_external = lossy_element
        output_layer.input_lossy_element_external = lossy_element
        cfg = self._model_config.translation_config.layers.get(
            output_layer.full_name, LayerTranslationConfig.get_default()
        )
        cfg.activation_symmetric_range = LayerFeaturePolicy.disabled
        self._model_config.translation_config.layers[output_layer.full_name] = cfg

    def _add_precision_split(self, layer_name, succ_names, pred):
        """
        Add a ``precision_splitter`` layer after ``pred`` with two outputs of ``pred``'s output shape.

        Args:
            layer_name: Base name of the new layer (prefixed with ``pred``'s scope).
            succ_names: Base names of the two successors, in output-index order.
            pred: The layer feeding the splitter.

        Returns:
            The new layer, configured as ``a16_w16``.
        """
        scope_name = pred.full_name.split("/")[0]
        pred_hn_element = pred.to_hn()
        shape = pred_hn_element["output_shapes"][0]
        hn_element = {
            "type": "precision_splitter",
            "input": [pred.full_name],
            "output": [f"{scope_name}/{output_name}" for output_name in succ_names],
            "input_shapes": [shape],
            "output_shapes": [shape, shape],
        }
        new_layer = self._add_generic_layer(hn_element, scope_name, layer_name, [pred])
        self._add_precision_config(new_layer, bias_mode="single_scale_decomposition", precision_mode="a16_w16")
        return new_layer

    def _add_precision_change_8_to_16(self, layer_name, succ_names, pred, output_index=0, ltype="activation"):
        """
        Add a shape-preserving linear layer of type ``ltype`` fed by output ``output_index`` of ``pred``.

        Returns:
            The new layer, configured as ``a8_w8_a16``.
        """
        scope_name = pred.full_name.split("/")[0]
        pred_hn_element = pred.to_hn()
        shape = pred_hn_element["output_shapes"][0]
        hn_element = {
            "type": ltype,
            "input": [pred.full_name],
            "output": [f"{scope_name}/{output_name}" for output_name in succ_names],
            "input_shapes": [shape],
            "output_shapes": [shape],
            "params": {"activation": "linear"},
        }
        new_layer = self._add_generic_layer(hn_element, scope_name, layer_name, [pred], output_indices=[output_index])
        self._add_precision_config(new_layer, bias_mode="single_scale_decomposition", precision_mode="a8_w8_a16")
        return new_layer

    def _add_ew_add(
        self,
        new_name,
        succ_names,
        preds,
        bias_mode="double_scale_initialization",
        precision_mode="a8_w8_a16",
        ltype="ew_add",
        output_indices=None,
    ):
        """
        Add a two-input linear element-wise add (or ``normalization`` with ``elementwise_add``) of ``preds``.

        Args:
            new_name: Base name of the new layer (prefixed with ``preds[0]``'s scope).
            succ_names: Base names of the successors; one output shape is emitted per successor.
            preds: Exactly two predecessor layers; both inputs get ``preds[0]``'s output shape.
            bias_mode, precision_mode: Precision config applied to the new layer.
            ltype: HN layer type; ``"normalization"`` additionally sets ``params["elementwise_add"]``.
            output_indices: Output index used on each predecessor edge; defaults to ``[0, 0]``.

        Returns:
            The new layer.
        """
        output_indices = [0, 0] if output_indices is None else output_indices
        scope_name = preds[0].full_name.split("/")[0]
        pred_hn_element = preds[0].to_hn()
        shape = pred_hn_element["output_shapes"][0]
        hn_element = {
            "type": ltype,
            "input": [preds[0].full_name, preds[1].full_name],
            "output": [f"{scope_name}/{output_name}" for output_name in succ_names],
            "input_shapes": [shape, shape],
            "output_shapes": [shape] * len(succ_names),
            "compilation_params": {},
            "quantization_params": {},
            "params": {"activation": "linear"},
        }
        if ltype == "normalization":
            hn_element["params"]["elementwise_add"] = True

        new_layer = self._add_generic_layer(hn_element, scope_name, new_name, preds, output_indices=output_indices)
        self._add_precision_config(new_layer, bias_mode=bias_mode, precision_mode=precision_mode)
        return new_layer

    def _add_concat(self, new_name, succ_names, preds):
        """
        Add a concat of two same-shaped ``preds`` along the last axis (its size is doubled).

        No precision config is attached to the new layer here.
        """
        scope_name = preds[0].full_name.split("/")[0]
        pred_hn_element = preds[0].to_hn()
        shape = pred_hn_element["output_shapes"][0]
        output_shape = [*shape[:-1], shape[-1] * 2]
        hn_element = {
            "type": "concat",
            "input": [preds[0].full_name, preds[1].full_name],
            "output": [f"{scope_name}/{output_name}" for output_name in succ_names],
            "input_shapes": [shape, shape],
            "output_shapes": [output_shape] * len(succ_names),
        }

        return self._add_generic_layer(hn_element, scope_name, new_name, preds)

    def _add_generic_layer(self, hn_element, scope_name, new_name, preds, weights=None, output_indices=None):
        """
        Build an acceleras layer named ``<scope_name>/<new_name>`` from ``hn_element`` and connect it to ``preds``.

        ``weights`` (default: empty dict) is imported for every layer type except ``"normalization"``.
        """
        layer_name = f"{scope_name}/{new_name}"
        new_layer = gen_acceleras_layers_from_hn(layer_name, hn_element, self.optimization_target)[layer_name]

        if weights is None:
            weights = {}
        if hn_element["type"] != "normalization":
            new_layer.import_weights(weights)

        self._add_preds(new_layer, predecessors=preds, output_indices=output_indices)

        return new_layer

    def _add_preds(self, new_layer, predecessors, output_indices=None) -> None:
        """
        Register ``new_layer`` in the model and add an edge from each predecessor.

        The i-th predecessor feeds input index ``i``; its output index is ``output_indices[i]`` (default 0).
        """
        self._model.layers[new_layer.full_name] = new_layer
        node = new_layer.full_name
        self._model.flow.add_node(node)
        output_indices = output_indices if output_indices else [0] * len(predecessors)

        for i, predecessor in enumerate(predecessors):
            self._model.flow.add_edge(predecessor.full_name, node, input_index=i, output_index=output_indices[i])

    def _add_precision_config(
        self, new_layer, bias_mode="single_scale_decomposition", precision_mode="a16_w16"
    ) -> None:
        """Store a precision config for ``new_layer`` in the model config and import it into the layer."""
        layer_name = new_layer.full_name
        cfg = LayerPrecisionConfig(precision_mode=precision_mode, bias_mode=bias_mode, quantization_groups=1)
        self._model_config.precision_config.layers[layer_name] = cfg
        new_layer.import_precision_config(cfg, self.optimization_target)

    def _output_16it_as_8bit(self):
        """Insert the LSB/MSB split-and-concat subgraph in front of every output node of the model."""
        # Snapshot the output nodes: the loop body adds/removes nodes and edges of the same flow graph.
        for output_lname in list(self._model.flow.output_nodes):
            self._convert_output_to_8bit(output_lname)

    def _convert_output_to_8bit(self, output_lname) -> None:
        """Replace the edge ``P -> output_lname`` (``P`` = first sorted predecessor) by the split/concat subgraph."""
        real_output_lname = self._model.flow.predecessors_sorted(output_lname)[0]
        real_output_layer = self._model.layers[real_output_lname]
        real_output_base_name = real_output_layer.full_name.split("/")[1]
        output_base_name = output_lname.split("/")[-1]

        split_name = f"precision_split_{real_output_base_name}"
        lsb_change_name = f"precision_change_lsb_{real_output_base_name}"
        lsb_split_name = f"precision_split_lsb_{real_output_base_name}"
        msb_add_name = f"ew_add_msb_{real_output_base_name}"
        # Suffixed like the other inserted layers so that several outputs in one scope do not collide on a
        # single shared "concat_out" layer/node.
        concat_name = f"concat_out_{real_output_base_name}"

        precision_split_layer = self._add_precision_split(
            split_name,
            [lsb_change_name, msb_add_name],
            real_output_layer,
        )
        self._first_level_splits.append(precision_split_layer)

        precision_change_lsb = self._add_precision_change_8_to_16(
            lsb_change_name,
            [lsb_split_name],
            precision_split_layer,
            output_index=0,
        )

        precision_split_lsb_layer = self._add_precision_split(
            lsb_split_name,
            [concat_name, msb_add_name],
            precision_change_lsb,
        )
        self._second_level_splits.append(precision_split_lsb_layer)

        self._set_output_external_element(precision_change_lsb, precision_split_lsb_layer, bits=9)

        # Both ew_add inputs come from output index 1 of their splitters (see _add_precision_split succ order).
        ew_add_msb = self._add_ew_add(
            msb_add_name,
            [concat_name],
            [precision_split_layer, precision_split_lsb_layer],
            precision_mode="a8_w8_a8",
            output_indices=[1, 1],
        )
        self._set_output_external_element(precision_split_layer, ew_add_msb, bits=1)

        concat = self._add_concat(
            concat_name,
            [output_base_name],
            [precision_split_lsb_layer, ew_add_msb],
        )

        output_layer = self._model.layers[output_lname]
        self._add_precision_config(output_layer, bias_mode="single_scale_decomposition", precision_mode="a8_w8")

        # Re-attach the output layer to the concat, keeping the original edge's input/output indices.
        input_index = self._model.flow.get_edge_input_index(real_output_lname, output_lname)
        output_index = self._model.flow.get_edge_output_index(real_output_lname, output_lname)
        self._model.flow.add_edge(concat.full_name, output_lname, input_index=input_index, output_index=output_index)
        self._model.flow.remove_edge(real_output_lname, output_lname)

    def _split_d2s_before_precision_split(self):
        """
        Move a ``HailoDepthToSpace`` that directly feeds a first-level split to after the split.

        The split's input/output shapes are set to the depth-to-space input shape, the original depth-to-space
        layer is removed from the model and from all configs, and a fresh copy named ``<d2s>_<i>`` (1-based) is
        added via ``add_layer`` for each outgoing edge of the split.
        """
        for split_layer in self._first_level_splits:
            split_layer_name = split_layer.full_name
            pred = self._model.layers[self._model.flow.predecessors_sorted(split_layer_name)[0]]
            if not isinstance(pred, HailoDepthToSpace):
                continue

            d2s_layer = pred

            split_layer_hn = split_layer.to_hn()
            d2s_hn = d2s_layer.to_hn()
            split_layer_hn["input_shapes"] = d2s_hn["input_shapes"][:]
            split_layer_hn["output_shapes"] = [d2s_hn["input_shapes"][0]] * 2
            # NOTE(review): writes a private attribute of the layer; there is no public setter visible here.
            split_layer._hn_element = split_layer_hn

            self._model.remove_layer(d2s_layer)
            self._model_config.remove_layer_from_all_configs(d2s_layer.full_name)

            # Snapshot the successors: add_layer mutates the split's outgoing edges.
            successors = list(self._model.flow.successors_sorted(split_layer_name))
            for i, succ_name in enumerate(successors):
                layer_name = f"{d2s_layer.full_name}_{i+1}"
                new_layer = gen_acceleras_layers_from_hn(layer_name, d2s_hn, self.optimization_target)[layer_name]
                # Presumably inserts new_layer on the (split, succ) edge — confirm against HailoModel.add_layer.
                self._model.add_layer(new_layer, [(split_layer_name, succ_name)])
