import copy

from hailo_model_optimization.acceleras.hailo_layers.hailo_format_conversion import HailoFormatConversion
from hailo_model_optimization.acceleras.hailo_layers.hailo_io import HailoInputLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_resize_bilinear_mac import HailoResizeBilinearMac
from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerPrecisionConfig,
    LayerTranslationConfig,
)
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    LayerType,
    ResizeBilinearPixelsMode,
    ResizeMethod,
    ResolutionReductionInterpolationMode,
    ResolutionReductionStage,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import InvalidInputShape, ResolutionReductionError
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm


class ResolutionReduction(OptimizationAlgorithm):
    """
    Resolution reduction for models with high resolution that can't be quantized.

    This algorithm is allowed only for models that don't have shape-dependent weights (like a dense layer).

    The stage is chosen by the ``reduction_stage`` kwarg (default ``ResolutionReductionStage.apply``):

    * apply  - for each configured input, either rewrite the input layer's spatial shape
      (interpolation disabled) or insert a bilinear resize layer after the input
      (or after its non-emulated format conversion).
    * revert - undo the above: restore the shapes given in the ``original_input_shapes`` kwarg,
      or remove the inserted resize layers.
    """

    def __init__(self, model: HailoModel, model_config, logger_level, **kwargs):
        super().__init__(model, model_config, "Resolution Reduction", logger_level, **kwargs)
        # Overwritten in _setup (apply stage) by the model's resolution_reduction_prepare() result.
        self._should_run = True
        # Input layer name -> {"shape": <target spatial shape>, "interpolation": <mode>}; filled in _setup.
        self._inputs_to_cfg_params = {}
        self._reduction_stage = kwargs.get("reduction_stage", ResolutionReductionStage.apply)
        # Input layer name -> original hn "input_shapes"; read only by the revert stage.
        self._original_input_shapes = kwargs.get("original_input_shapes", {})

    def _setup(self):
        if self._reduction_stage == ResolutionReductionStage.apply:
            # TODO SDK-SDK-48985 why the logic is on the model.
            self._should_run, _ = self._model.resolution_reduction_prepare()
        else:
            # Non-apply stage: turn io shape verification back on for every emulated layer.
            for layer in self._model.layers.values():
                if layer.in_emulation_graph:
                    layer.ignore_io_shapes_verification = False
        if self._should_run:
            self._set_inputs_to_cfg_params()

    def should_skip_algo(self):
        """Skip when no target shape is configured (globally, or for any per-layer entry)."""
        cfg = self.get_algo_config()
        if not cfg.layers:
            return cfg.shape is None
        return any(layer_cfg.shape is None for layer_cfg in cfg.layers.values())

    def get_algo_config(self):
        return self._model_config.resolution_reduction

    def log_config(self):
        pass

    def _run_int(self):
        if not self._should_run:
            return
        if self._reduction_stage == ResolutionReductionStage.apply:
            self._apply_reduction()
        elif self._reduction_stage == ResolutionReductionStage.revert:
            self._revert_reduction()

    def _apply_reduction(self):
        """
        Reduce every configured input, then re-verify the model's shapes.

        Raises:
            ResolutionReductionError: if shape verification raises ``InvalidInputShape``
                (the original error is chained as ``__cause__``).
        """
        try:
            for idx, (input_lname, input_cfg) in enumerate(self._inputs_to_cfg_params.items()):
                if input_cfg["interpolation"] == ResolutionReductionInterpolationMode.disabled:
                    self._apply_model_input_shapes_to_layer(input_lname)
                else:
                    # 1-based index, used only to give each inserted resize layer a unique name.
                    self._add_resize_to_input_layer(input_lname, idx + 1)
            applied_shapes = self._get_current_input_shapes()
            self._model.compute_and_verify_output_shape(applied_shapes)
        except InvalidInputShape as err:
            cfg = self.get_algo_config()
            cfg_shape = cfg.shape if not cfg.layers else [x.shape for x in cfg.layers.values()]
            raise ResolutionReductionError(
                f"Resolution reduction to {cfg_shape} failed while validating "
                f"model shapes. Encountered layer {err.lname} with unfeasible input shapes. "
                "Please try again with a different input shape.",
            ) from err

    def _revert_reduction(self):
        """Undo ``_apply_reduction`` for every configured input, then re-verify the model's shapes."""
        for lname, input_cfg in self._inputs_to_cfg_params.items():
            if input_cfg["interpolation"] == ResolutionReductionInterpolationMode.disabled:
                self._revert_model_input_shapes_from_layer(lname)
            else:
                self._remove_resize_from_input_layer(lname)
        self._model.compute_and_verify_output_shape(self._get_current_input_shapes())

    def _apply_model_input_shapes_to_layer(self, lname):
        """Overwrite dims 1:3 of the input layer's input and output shapes with the configured shape."""
        cfg_shape = self._inputs_to_cfg_params[lname]["shape"]
        hn_element = self._model.layers[lname].hn_element
        hn_element["input_shapes"][0][1:3] = cfg_shape
        # Update every output shape entry, not just the first, so no output is left at the old resolution.
        for out_shape in hn_element["output_shapes"]:
            out_shape[1:3] = cfg_shape
        self._model.layers[lname].set_input_spec(hn_element["input_shapes"])

    def _add_resize_to_input_layer(self, input_lname, index):
        """Insert a resize layer between the input's resize source and all of that source's successors."""
        source_lname = self._get_resize_source(input_lname)
        reduced_spatial = self._inputs_to_cfg_params[input_lname]["shape"]
        targets = self._model.flow.successors_sorted(source_lname)
        self._add_resize_layer(source_lname, targets, reduced_spatial, index)

    def _get_current_input_shapes(self):
        """Return each model input's first input shape with the batch dimension replaced by ``None``."""
        hn_elements = [self._model.layers[lname].hn_element for lname in self._model.flow.input_nodes]
        return [(None,) + tuple(hn_element["input_shapes"][0][1:]) for hn_element in hn_elements]

    def _revert_model_input_shapes_from_layer(self, lname):
        """
        Restore the input layer's shapes from ``original_input_shapes``.

        Raises:
            KeyError: if no original shapes were provided for ``lname``.
        """
        original_shapes = self._original_input_shapes[lname]
        hn_element = self._model.layers[lname].hn_element
        # Separate deep copies: _apply_model_input_shapes_to_layer mutates these lists in place, which
        # would otherwise corrupt the stored originals and couple input/output shapes to each other.
        # NOTE(review): output_shapes is restored from the original *input* shapes, which assumes the
        # input layer's output shapes mirror its input shapes - confirm for inputs with several consumers.
        hn_element["input_shapes"] = copy.deepcopy(original_shapes)
        hn_element["output_shapes"] = copy.deepcopy(original_shapes)
        self._model.layers[lname].set_input_spec(hn_element["input_shapes"])

    def _remove_resize_from_input_layer(self, input_lname):
        """Remove the resize layer inserted for ``input_lname`` by ``_add_resize_to_input_layer``."""
        source_lname = self._get_resize_source(input_lname)
        # Assumes the inserted resize is the source's first (sorted) successor.
        resize_lname = self._model.flow.successors_sorted(source_lname)[0]
        resize_layer = self._model.layers[resize_lname]
        self._model.remove_layer(resize_layer)

        # In case there's a non-emulated input conversion, we should ignore the input layer shapes verification
        # although the resolution reduction is done
        source_layer = self._model.layers[source_lname]
        if isinstance(source_layer, HailoFormatConversion) and source_layer.ignore_io_shapes_verification:
            conversion_input_lname = self._model.flow.predecessors_sorted(source_lname)[0]
            conversion_input_layer = self._model.layers[conversion_input_lname]
            conversion_input_layer.ignore_io_shapes_verification = True

    def _set_inputs_to_cfg_params(self):
        """Build ``_inputs_to_cfg_params`` from per-layer config, or apply the global config to every input."""
        cfg = self.get_algo_config()
        if cfg.layers:
            self._inputs_to_cfg_params = {
                lname: {"shape": layer_cfg.shape, "interpolation": layer_cfg.interpolation}
                for lname, layer_cfg in cfg.layers.items()
            }
        else:
            self._inputs_to_cfg_params = {
                lname: {"shape": cfg.shape, "interpolation": cfg.interpolation}
                for lname in self._model.flow.input_nodes
            }

    def _get_resize_source(self, source_lname):
        """
        Return the layer after which the resize is placed.

        This is the given input layer itself. The exception is when its only successor is a
        ``HailoFormatConversion`` outside the emulation graph; then that conversion layer is returned.
        """
        succs = list(self._model.flow.successors(source_lname))
        if len(succs) == 1:
            succ_layer = self._model.layers[succs[0]]
            if isinstance(succ_layer, HailoFormatConversion) and not succ_layer.in_emulation_graph:
                source_lname = succs[0]
        return source_lname

    def _add_resize_layer(self, source, targets, spatial_resize, resize_idx):
        """
        Create a bilinear resize layer from ``source`` to ``targets`` and register its configs.

        The h/w ratios are ``spatial_resize[0] / in_shape[1]`` and ``spatial_resize[1] / in_shape[2]``,
        where ``in_shape`` is the source's first output shape. The layer gets an a8_w8_a8 precision
        config and a default translation config.
        """
        hn, params = {"compilation_params": {}}, {}
        hn["type"] = LayerType.RESIZE.value
        hn["input"] = [source]
        hn["output"] = targets
        in_shape = list(self._model.layers[source].output_shapes[0])
        hn["input_shapes"] = [[-1] + in_shape[1:]]
        hn["output_shapes"] = [[-1, *spatial_resize, in_shape[-1]] for _ in targets]
        params["resize_h_ratio_list"] = [float(spatial_resize[0]) / float(in_shape[1])]
        params["resize_w_ratio_list"] = [float(spatial_resize[1]) / float(in_shape[2])]
        params["method"] = ResizeMethod.BILINEAR.value
        params["resize_bilinear_pixels_mode"] = ResizeBilinearPixelsMode.DISABLED.value
        hn["params"] = params
        hn["compilation_params"]["hw_layer_type_list"] = ["lcu"]  # just to pass layer validation, won't be used
        # add layer to model
        resize_name = f"{self._model.model_name}/resolution_reduction_{LayerType.RESIZE.value}{resize_idx}"
        resize_layer = HailoResizeBilinearMac.from_hn(lname=resize_name, hn_element=hn)
        self._model.add_layer(resize_layer, [(source, target) for target in targets])
        resize_cfg = LayerPrecisionConfig(
            precision_mode="a8_w8_a8",
            bias_mode="double_scale_initialization",
            quantization_groups=1,
        )
        self._model_config.precision_config.layers[resize_name] = resize_cfg
        resize_layer.import_precision_config(resize_cfg, self.optimization_target)
        self._model_config.translation_config.layers[resize_name] = LayerTranslationConfig()

    def finalize_global_cfg(self, algo_config):
        pass

    def _get_valid_layer_cfg(self, lname, cfg):
        """Keep ``cfg`` only for input layers; any other layer gets an empty config."""
        layer = self._model.layers[lname]
        if not isinstance(layer, HailoInputLayer):
            cfg = {}
        return cfg
