from hailo_model_optimization.acceleras.hailo_layers.hailo_avgpool_v2 import HailoAvgPool
from hailo_model_optimization.acceleras.model.hailo_model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerPrecisionConfig,
    LayerTranslationConfig,
)
from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerType, PaddingType
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm
from hailo_model_optimization.algorithms.stats_collection.stats_collection import StatsCollector


class GlobalAvgpoolReduction(OptimizationAlgorithm):
    """Split selected global average-pool layers into two chained average pools.

    A qualifying global avgpool over an ``H x W`` input is replaced by:

    1. a "reducing" avgpool with kernel ``(H // f0, W // f1)`` that leaves an
       ``f0 x f1`` map, followed by
    2. an avgpool with kernel ``(f0, f1)`` that reuses the original layer name
       and produces the final ``1 x 1`` output.

    ``(f0, f1)`` are the division factors, taken from the per-layer algorithm
    configuration when present, otherwise chosen automatically for inputs whose
    spatial area exceeds ``MAX_ACCUMULATOR_BUFFER``. If at least one layer was
    split, statistics collection is re-run on the modified model.
    """

    NAME = "global_avgpool_reduction"
    # Spatial-area threshold (height * width of the layer input). Above it a
    # global avgpool is split by default even without an explicit configuration.
    # NOTE(review): presumably tied to a hardware accumulator limit - confirm.
    MAX_ACCUMULATOR_BUFFER = 2**16

    def __init__(self, model: HailoModel, model_config, logger_level, dataset, **kwargs):
        """Store the dataset used for re-collecting stats and reset the reduction flag."""
        super().__init__(model, model_config, "Global Avgpool Reduction", logger_level, **kwargs)
        self.dataset = dataset
        # Set to True once any layer has actually been split.
        self.has_reduced = False

    def get_algo_config(self):
        """Return this algorithm's section of the model configuration."""
        return self._model_config.global_avgpool_reduction

    def _setup(self):
        """No setup is required for this algorithm."""
        pass

    def _run_int(self):
        """Split eligible layers, then re-collect stats if the model was modified."""
        self._reduce_global_avgpool()
        if self.has_reduced:
            stats_collector = StatsCollector(
                self._model,
                self._model_config,
                self._logger_level,
                self.dataset,
                logger=self._logger,
            )
            stats_collector.run()

    def should_skip_algo(self):
        """This algorithm is never skipped as a whole; eligibility is decided per layer."""
        return False

    def log_config(self):
        """Nothing is logged for this algorithm's configuration."""
        pass

    def _reduce_global_avgpool(self):
        """Replace every eligible global avgpool with a reducing pool plus a final pool.

        Sets ``self.has_reduced`` to True when at least one layer is replaced.
        """
        # Snapshot the traversal order first: the loop body adds and removes
        # layers, which must not happen while a lazily evaluated ordering of the
        # same graph is still being consumed.
        for lname in list(self._model.flow.toposort()):
            layer = self._model.layers[lname]
            should_reduce, division_factors = self._should_reduce_layer(layer)
            if not should_reduce:
                continue

            in_shape = layer.input_shapes[0]
            orig_spatial = in_shape[1:3]
            # _should_reduce_layer guarantees exactly one predecessor and successor.
            succs = list(self._model.flow.successors(layer.full_name))
            preds = list(self._model.flow.predecessors(layer.full_name))

            # First pool: shrinks H x W down to f0 x f1.
            reducing_kernel = [1, *[spatial // fact for spatial, fact in zip(orig_spatial, division_factors)], 1]
            # Second pool: collapses the remaining f0 x f1 map to 1 x 1.
            reduced_kernel = [1, *division_factors, 1]
            new_in_shape = [-1, *division_factors, in_shape[-1]]
            reducing_name = f'{self._model.model_name}/reducing_{lname.split("/")[-1]}'

            # Insert the reducing pool in front of the original layer, drop the
            # original, then add the final pool under the original layer's name
            # (so the existing configuration entries keyed by that name still apply).
            self._add_avgpool_layer(preds[0], lname, reducing_kernel, in_shape, reducing_name, create_cfg=True)
            self._model.remove_layer(layer)
            self._add_avgpool_layer(reducing_name, succs[0], reduced_kernel, new_in_shape, lname, create_cfg=False)
            self.has_reduced = True

    def _should_reduce_layer(self, layer):
        """Decide whether ``layer`` should be split and with which division factors.

        Returns:
            ``(True, (f0, f1))`` when the layer is a global avgpool with a single
            predecessor and successor and valid, non-trivial division factors;
            ``(False, None)`` otherwise. A warning is logged when factors were
            selected but cannot be applied to the layer's spatial dimensions.
        """
        if not isinstance(layer, HailoAvgPool) or not layer.is_global_avgpool():
            return False, None

        algo_cfg = self.get_algo_config()
        orig_spatial = layer.input_shapes[0][1:3]

        division_factors = (1, 1)
        if orig_spatial[0] * orig_spatial[1] > self.MAX_ACCUMULATOR_BUFFER:
            # Default split for large inputs: pool along the width first, then
            # along the height.
            division_factors = (orig_spatial[0], 1)
        if layer.full_name in algo_cfg.layers:
            # An explicit per-layer configuration overrides the automatic choice.
            division_factors = algo_cfg.layers[layer.full_name].division_factors

        # Normalize so that a list coming from the configuration (e.g. [1, 1])
        # compares equal to the "no reduction" tuple below.
        division_factors = tuple(division_factors)
        if division_factors == (1, 1):
            return False, None

        succs = list(self._model.flow.successors(layer.full_name))
        preds = list(self._model.flow.predecessors(layer.full_name))
        if len(preds) != 1 or len(succs) != 1:
            return False, None

        # Factors must be one positive divisor per spatial dimension. Checking
        # ``factor <= 0`` first also prevents a ZeroDivisionError in the modulo.
        invalid_factors = len(division_factors) != 2 or any(
            factor <= 0 or spatial % factor != 0 for spatial, factor in zip(orig_spatial, division_factors)
        )
        if invalid_factors:
            self._logger.warning(
                f"Can't reduce {layer.full_name} with spatial dimensions {orig_spatial} by "
                f"{division_factors}. Please choose different factors.",
            )
            return False, None
        return True, division_factors

    def _add_avgpool_layer(self, source, target, kernel_shape, input_shape, avgpool_name, create_cfg):
        """Build an avgpool layer and insert it on the ``source -> target`` edge.

        Args:
            source: Name of the layer feeding the new avgpool.
            target: Name of the layer consuming the new avgpool's output.
            kernel_shape: 4-element kernel, also used as the strides.
            input_shape: Input shape of the new layer; the leading dimension is
                replaced by -1.
            avgpool_name: Name given to the new layer.
            create_cfg: When True, create precision and translation config
                entries for the new layer, derived from ``target``'s precision
                config. When False, no configuration entries are created here.
        """
        hn, params = {}, {}
        hn["type"] = LayerType.AVGPOOL.value
        hn["input"] = [source]
        hn["output"] = [target]
        hn["input_shapes"] = [[-1, *input_shape[1:]]]
        # Kernel equals stride with VALID padding, so each spatial dimension
        # shrinks by exactly its kernel size.
        spatial_out = [spatial // kernel for spatial, kernel in zip(input_shape[1:3], kernel_shape[1:3])]
        hn["output_shapes"] = [[-1, *spatial_out, input_shape[-1]]]
        params["strides"] = kernel_shape
        params["kernel_shape"] = kernel_shape
        params["padding"] = PaddingType.VALID.value
        params["elementwise_add"] = False
        hn["params"] = params
        avgpool_layer = HailoAvgPool.from_hn(avgpool_name, hn)
        self._model.add_layer(avgpool_layer, [(source, target)])

        if create_cfg:
            orig_cfg = self._model_config.precision_config.layers[target]
            orig_prec_mode = orig_cfg.precision_mode.value.split("_")
            # Keep the first two "_"-separated fields and repeat the first one as
            # the third field.
            # NOTE(review): presumably makes the new layer's output precision
            # match its input precision - confirm against the precision-mode naming.
            curr_precision_mode = "_".join([orig_prec_mode[0], orig_prec_mode[1], orig_prec_mode[0]])
            avgpool_config = LayerPrecisionConfig(
                precision_mode=curr_precision_mode,
                bias_mode=orig_cfg.bias_mode,
                quantization_groups=orig_cfg.quantization_groups,
            )
            self._model_config.precision_config.layers[avgpool_name] = avgpool_config
            avgpool_layer.import_precision_config(avgpool_config, self.optimization_target)
            self._model_config.translation_config.layers[avgpool_name] = LayerTranslationConfig()

    def finalize_global_cfg(self, algo_config):
        """No global configuration finalization is needed."""
        pass

    def _get_valid_layer_cfg(self, lname, cfg):
        """Return ``cfg`` for global avgpool layers and an empty dict for any other layer."""
        layer = self._model.layers[lname]
        if not isinstance(layer, HailoAvgPool) or not layer.is_global_avgpool():
            cfg = {}
        return cfg
