from enum import Enum

from hailo_model_optimization.acceleras.utils.acceleras_definitions import PostprocessTarget, PostprocessType
from hailo_sdk_common.hailo_nn.hn_definitions import HWLayerType, LayerType, ResizeBilinearPixelsMode, ResizeMethod
from hailo_sdk_common.hailo_nn.hn_layers import PostprocessLayer, ResizeLayer
from hailo_sdk_common.hailo_nn.hn_layers.layer_common import MODEL_SCRIPT_LAYER_PREFIX
from hailo_sdk_common.logger.logger import default_logger


class ResizeLayersAdditionException(Exception):
    """Raised when resize layers cannot be added (invalid arguments or an unsupported placement of the layers)."""


class ResizePlacement(Enum):
    """Where the new resize layers are attached: after the network's input layers or after its output layers."""

    INPUT_LAYER = "input_layer"
    OUTPUT_LAYER = "output_layer"


class ResizeLayersAdder:
    """
    Add resize layers to the input layers or to the output layers of a Hailo NN.

    When ``resize_shapes`` is a plain ``(h, w)`` tuple, or a dict whose keys are all input layers, a resize layer is
    pushed after each of those input layers: the input layer receives the new shape and the resize layer scales it
    back to the original input shape. When the dict keys are all real output layers, a resize layer is pushed after
    each of them (scaling its output to the requested shape) and takes its place in the network's output order.

    Args:
        hailo_nn (:class:`~hailo_sdk_common.hailo_nn.hailo_nn.HailoNN`): The Hailo NN to add resize layers to.
        resize_layers (list of str, optional): Names for the new layers to add, one per resized layer. When empty,
            the names ``resize_input1``, ``resize_input2``, ... are used.
        resize_shapes (tuple or dict): The new shape for each resized layer.
            In case all input layers should have the same input shape use tuple of (h, w) and a resize
            layer from the given shape will be added to all input layers. Otherwise, use dict of shape tuple (h, w)
            by layer names; the named layers must be either all input layers or all real output layers.
        resize_method (str or ResizeMethod, optional): Resize method. Defaults to ``ResizeMethod.bilinear``.
        pixels_mode (str or ResizeBilinearPixelsMode, optional): Pixels mode. Defaults to
            ``ResizeBilinearPixelsMode.align_corners``.
        hw_layer_type (str or HWLayerType, optional): HW layer type for NN-core resize layers. Defaults to
            ``HWLayerType.ppu``.
        engine (str or PostprocessTarget, optional): Defaults to ``PostprocessTarget.NN_CORE``, which adds
            :class:`ResizeLayer` layers. Any other engine adds CPU postprocess resize layers instead, which are only
            allowed on real output layers.

    Raises:
        ResizeLayersAdditionException: If an argument has an invalid value, the number of names does not match the
            number of resized layers, or the resized layers are neither all inputs nor all outputs.

    """

    def __init__(
        self,
        hailo_nn,
        resize_layers,
        resize_shapes,
        resize_method,
        pixels_mode=None,
        hw_layer_type=None,
        engine=None,
    ):
        self._hailo_nn = hailo_nn
        # Placement must be resolved first: the per-layer target shapes use input or output features based on it.
        self._resize_placement = self._get_resize_placement(resize_shapes)
        self._layers_to_shapes = self._get_layers_to_shapes(resize_shapes)
        self._get_resize_names(resize_layers)
        self._get_resize_method(resize_method)
        self._get_pixels_mode(pixels_mode)
        self._get_hw_layer_type(hw_layer_type)
        self._get_engine(engine)
        # New layers are indexed consecutively starting at this value.
        self._graph_next_index = self._hailo_nn.get_next_index()
        self._logger = default_logger()

    @property
    def layers_to_shapes(self):
        return self._layers_to_shapes

    @property
    def pixels_mode(self):
        return self._pixels_mode

    @property
    def hw_layer_type(self):
        return self._hw_layer_type

    @property
    def resize_layers(self):
        return self._resize_layers

    @property
    def resize_method(self):
        return self._resize_method

    @staticmethod
    def _parse_enum_arg(value, enum_cls, default, arg_name):
        """Convert a user-supplied argument to a member of ``enum_cls``.

        Falsy values select ``default``; members of ``enum_cls`` are returned as-is; any other value must be one of
        the enum's values.

        Raises:
            ResizeLayersAdditionException: If ``value`` is not a valid value of ``enum_cls``.
        """
        if not value:
            return default
        if isinstance(value, enum_cls):
            return value
        if value not in [item.value for item in enum_cls]:
            raise ResizeLayersAdditionException(f"Wrong {arg_name} given: {value}")
        return enum_cls(value)

    def _get_resize_method(self, resize_method):
        self._resize_method = self._parse_enum_arg(resize_method, ResizeMethod, ResizeMethod.bilinear, "resize_method")

    def _get_pixels_mode(self, pixels_mode):
        self._pixels_mode = self._parse_enum_arg(
            pixels_mode,
            ResizeBilinearPixelsMode,
            ResizeBilinearPixelsMode.align_corners,
            "pixels_mode",
        )

    def _get_hw_layer_type(self, hw_layer_type):
        self._hw_layer_type = self._parse_enum_arg(hw_layer_type, HWLayerType, HWLayerType.ppu, "hw_layer_type")

    def _get_engine(self, engine):
        self._engine = self._parse_enum_arg(engine, PostprocessTarget, PostprocessTarget.NN_CORE, "engine")

    def _get_resize_names(self, resize_layers):
        """Set the names of the new layers; exactly one name is required per resized layer."""
        num_of_resize_to_add = len(self._layers_to_shapes)
        if not resize_layers:
            self._resize_layers = [f"resize_input{index + 1}" for index in range(num_of_resize_to_add)]
        elif len(resize_layers) == num_of_resize_to_add:
            self._resize_layers = resize_layers
        else:
            raise ResizeLayersAdditionException(
                f"Given {len(resize_layers)} names for the new layers when "
                f"{num_of_resize_to_add} names are required",
            )

    def add_resize_layers(self):
        """Insert the configured resize layers into the Hailo NN.

        Returns:
            The Hailo NN given to the constructor, modified in place.

        Raises:
            ResizeLayersAdditionException: If a postprocess (non NN-core) resize is requested on a layer that is not
                a real output layer.
        """
        self._logger.verbose("Adding resize input layers")
        for index, layer in enumerate(self._layers_to_shapes):
            if self._engine == PostprocessTarget.NN_CORE:
                h_ratios, w_ratios = self._get_resize_ratios(layer)
                resize_layer = self._add_resize_layer(index, layer, h_ratios=h_ratios, w_ratios=w_ratios)
            else:
                if layer not in self._hailo_nn.get_real_output_layers(False):
                    raise ResizeLayersAdditionException("Postprocess resize layers can be applied only on output layer")

                resize_layer = self._add_resize_postprocess_layer(index, layer)

            resize_layer.add_original_name(f"{MODEL_SCRIPT_LAYER_PREFIX}_{resize_layer.name_without_scope}")
            if self._resize_placement == ResizePlacement.OUTPUT_LAYER:
                # The new resize layer replaces the resized layer in the network's output order.
                output_order = self._hailo_nn.net_params.output_layers_order
                output_order[output_order.index(layer.name)] = resize_layer.name

        return self._hailo_nn

    def _get_resize_ratios(self, layer):
        """Return the ``(h_ratio, w_ratio)`` scale factors (output size / input size) of the resize for ``layer``."""
        new_shape = self._layers_to_shapes[layer]
        if self._resize_placement == ResizePlacement.OUTPUT_LAYER:
            # The resize consumes the layer's output and produces the requested shape.
            current_shape = layer.output_shape
            return new_shape[1] / float(current_shape[1]), new_shape[2] / float(current_shape[2])
        # The input now has the requested shape; the resize brings it back to the original input shape.
        current_shape = layer.input_shape
        return current_shape[1] / float(new_shape[1]), current_shape[2] / float(new_shape[2])

    def _add_resize_postprocess_layer(self, index, layer):
        """Create a CPU postprocess resize layer, push it after ``layer`` and mark a following output layer as CPU."""
        postprocess_layer = PostprocessLayer()
        postprocess_layer.engine = PostprocessTarget.CPU
        postprocess_layer.op = LayerType.postprocess
        postprocess_layer.postprocess_type = PostprocessType.RESIZE
        postprocess_layer.resize_shape = self._layers_to_shapes[layer][1:3]
        postprocess_layer.resize_method = self._resize_method.value
        postprocess_layer.pixels_mode = self._pixels_mode.value
        postprocess_layer.name = self._resize_layers[index]
        postprocess_layer.index = self._graph_next_index + index

        self._hailo_nn.push_layer(postprocess_layer, [layer])

        output_succ = [
            output for output in self._hailo_nn.successors(postprocess_layer) if output.op == LayerType.output_layer
        ]
        if len(output_succ) == 1:
            output_succ[0].engine = PostprocessTarget.CPU

        return postprocess_layer

    def _add_resize_layer(self, index, layer, h_ratios, w_ratios):
        """Create an NN-core :class:`ResizeLayer` and push it into the graph after ``layer``."""
        resize_layer = ResizeLayer()
        resize_layer.name = self._resize_layers[index]
        resize_layer.index = self._graph_next_index + index
        resize_layer.resize_method = self._resize_method
        resize_layer.resize_bilinear_pixels_mode = self._pixels_mode
        resize_layer.h_ratios = [h_ratios]
        resize_layer.w_ratios = [w_ratios]
        resize_layer.f_ratios = [1.0]
        resize_layer.compilation_params["hw_layer_type_list"] = [self._hw_layer_type]

        if self._resize_placement == ResizePlacement.OUTPUT_LAYER:
            # The resize consumes the output layer's existing output; the output layer's own shapes stay untouched.
            resize_layer.input_shape = list(layer.output_shape)
        else:
            # The input layer now receives the requested shape, which is also what the resize layer consumes.
            resize_layer.input_shape = self._layers_to_shapes[layer]
            layer.input_shape = self._layers_to_shapes[layer]
        self._hailo_nn.push_layer(resize_layer, [layer])

        return resize_layer

    def _get_resize_placement(self, resize_shapes):
        """Decide whether the resize layers go after input layers or after real output layers.

        A non-dict ``resize_shapes`` always means input placement.

        Raises:
            ResizeLayersAdditionException: If the named layers are neither all input layers nor all output layers.
        """
        if isinstance(resize_shapes, dict):
            if all(
                self._hailo_nn.get_layer_by_name(layer_name).op == LayerType.input_layer for layer_name in resize_shapes
            ):
                # all layers are inputs
                return ResizePlacement.INPUT_LAYER

            real_output_names = [
                layer.name for layer in self._hailo_nn.get_real_output_layers(remove_non_neural_core_layers=False)
            ]
            if all(layer in real_output_names for layer in resize_shapes):
                # all layers are outputs
                return ResizePlacement.OUTPUT_LAYER

            raise ResizeLayersAdditionException("Resize layers can be applied only on input layer or output layer")

        return ResizePlacement.INPUT_LAYER

    def _get_layers_to_shapes(self, resize_shapes):
        """Map each layer to resize to its target shape ``[-1, h, w, features]``.

        A non-dict ``resize_shapes`` applies the same ``(h, w)`` to every input layer. Features are kept unchanged:
        input features for input placement, output features for output placement.
        """
        if not isinstance(resize_shapes, dict):
            return {
                input_layer: [-1, *resize_shapes, input_layer.input_features]
                for input_layer in self._hailo_nn.get_input_layers()
            }

        layer_by_shape = {self._hailo_nn.get_layer_by_name(name): shape for name, shape in resize_shapes.items()}

        def get_features(layer):
            return (
                layer.input_features if self._resize_placement == ResizePlacement.INPUT_LAYER else layer.output_features
            )

        return {layer: [-1, *shape, get_features(layer)] for layer, shape in layer_by_shape.items()}
