import copy

import numpy as np

from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.exceptions import UnsupportedModelError
from hailo_sdk_common.hailo_nn.hn_definitions import (
    ActivationType,
    DefuseType,
    HnStage,
    LayerType,
    ResizeBilinearPixelsMode,
    ResizeBilinearPixelsModes,
    ResizeMethod,
    ResizeMethods,
)
from hailo_sdk_common.hailo_nn.hn_layers.layer_with_params import LayerWithParams
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class ResizeLayer(LayerWithParams):
    """HN layer that resizes its input spatially and can optionally broadcast over features.

    The resize is stored as per-axis ratio lists: ``h_ratios``, ``w_ratios``, ``f_ratios`` and the internal
    ``d_ratios``. The output shape is the input shape multiplied by the product of each list
    (see ``_calc_output_shape``). Until the ratios are known, they can be derived from any of:

    - absolute target sizes (``h_sizes`` / ``w_sizes`` / ``d_sizes``);
    - ``upscale_factors``;
    - an already known output shape;
    - the output shapes of the layers referenced by ``resize_layers``.
    """

    _REQUIRES_NATIVE_WEIGHTS = False
    _REQUIRES_QUANTIZED_WEIGHTS = False
    _IS_REAL_LAYER = True
    # Feature ratio used when none was provided (i.e. no broadcast over features).
    _DEFAULT_F_RATIO = 1.0

    def __init__(self):
        super().__init__()
        self._op = LayerType.resize
        # Per-axis resize ratios; _calc_output_shape multiplies the input dims by the product of each list.
        self._h_ratios = []
        self._w_ratios = []
        self._f_ratios = []
        self._d_ratios = []
        self._method = None
        self._input_disparity = None
        # Absolute target sizes / upscale factors; only used to derive the ratios (see _update_resize_ratios).
        self._h_sizes = None
        self._w_sizes = None
        self._d_sizes = None
        self._upscale_factors = None
        self._resize_bilinear_pixels_mode = ResizeBilinearPixelsMode.disabled
        self._activation = ActivationType.linear

        # Contains pointers to HN layers from which the shapes need to be taken as resize_sizes.
        # Index 0 supplies the target height and index 1 the target width (see update_output_shapes).
        self._resize_layers = None

    @classmethod
    def create(
        cls,
        original_name,
        input_vertex_order,
        resize_method,
        input_disparity=1,
        output_shapes=None,
        h_sizes=None,
        w_sizes=None,
        d_sizes=None,
        upscale_factors=None,
        resize_bilinear_pixels_mode=ResizeBilinearPixelsMode.disabled,
    ):
        """Create a resize layer from parser-level parameters.

        Args:
            original_name: Name of the layer in the original model.
            input_vertex_order: Forwarded to the base ``create``.
            resize_method: A ``ResizeMethod`` value.
            input_disparity: Divisor applied to ``d_sizes`` when deriving the disparity ratio.
            output_shapes: Optional known output shapes, forwarded to the base ``create``.
            h_sizes: Optional absolute target height, used to derive the height ratio.
            w_sizes: Optional absolute target width, used to derive the width ratio.
            d_sizes: Optional absolute target disparity, used to derive the disparity ratio.
            upscale_factors: Optional ``(h_factor, w_factor)``, used when the sizes are not available.
            resize_bilinear_pixels_mode: A ``ResizeBilinearPixelsMode`` value.

        Returns:
            The new ``ResizeLayer``.
        """
        layer = super().create(original_name, input_vertex_order, output_shapes)
        layer.resize_method = resize_method
        layer._resize_bilinear_pixels_mode = resize_bilinear_pixels_mode
        # NOTE(review): this consults the hw_layer_type_list already present in the (default) compilation params.
        # It therefore selects "ppu" only if those params already start with "ppu"; otherwise it selects "lcu".
        hw_type_list = ["ppu"] if (layer.is_bilinear_align_corners_not_quantized) else ["lcu"]
        layer.set_compilation_params(hw_layer_type_list=hw_type_list)
        layer._input_disparity = input_disparity
        layer._h_sizes = h_sizes
        layer._w_sizes = w_sizes
        layer._d_sizes = d_sizes
        layer._upscale_factors = upscale_factors
        return layer

    @property
    def activation(self):
        """The layer activation (initialized to ``ActivationType.linear``)."""
        return self._activation

    @property
    def is_bilinear_half_pixels(self):
        """True for a bilinear resize in ``half_pixels`` mode."""
        return (
            self._method == ResizeMethod.bilinear
            and self._resize_bilinear_pixels_mode == ResizeBilinearPixelsMode.half_pixels
        )

    @property
    def is_bilinear_align_corners(self):
        """True for a bilinear resize in ``align_corners`` mode."""
        return (
            self._method == ResizeMethod.bilinear
            and self._resize_bilinear_pixels_mode == ResizeBilinearPixelsMode.align_corners
        )

    @property
    def is_bilinear_align_corners_not_quantized(self):
        """True for an align-corners bilinear resize whose first HW layer type is ``ppu``.

        Returns False when ``hw_layer_type_list`` is missing from the compilation params or is empty. This can
        happen before ``set_compilation_params`` has been called, e.g. from ``create`` / ``from_pb``.
        """
        hw_layer_types = self._compilation_params.get("hw_layer_type_list")
        if hw_layer_types is None or len(hw_layer_types) == 0:
            return False
        return self.is_bilinear_align_corners and hw_layer_types[0].value == "ppu"

    @property
    def is_nearest_half_pixels(self):
        """True for a nearest-neighbor resize in ``half_pixels`` mode."""
        return (
            self._method == ResizeMethod.nearest_neighbor
            and self._resize_bilinear_pixels_mode == ResizeBilinearPixelsMode.half_pixels
        )

    @property
    def should_update_resize_ratios(self):
        """Truthy when there is no output shape yet but target sizes or upscale factors are available."""
        return not self.output_shape and (self._h_sizes or self._w_sizes or self._d_sizes or self._upscale_factors)

    def update_output_shapes(self, **kwargs):
        """Resolve pending ``resize_layers`` into target sizes/ratios, recompute the output shapes and validate.

        Args:
            **kwargs: May contain ``hn_stage``. Validation is skipped only when it equals ``HnStage.PRE_FUSED``.
                It is optional because the lazy ``output_shapes`` getter calls this method without arguments.

        Raises:
            UnsupportedModelError: If a referenced resize layer has no output shape yet, or if validation fails.
        """
        if self._resize_layers:
            if (not self._resize_layers[0] or self._resize_layers[0].output_shape) and (
                not self._resize_layers[1] or self._resize_layers[1].output_shape
            ):
                # A missing reference layer means "keep the corresponding input dimension".
                self._h_sizes = (
                    self._resize_layers[0].output_shape[1] if self._resize_layers[0] else self.input_shape[1]
                )
                self._w_sizes = (
                    self._resize_layers[1].output_shape[2] if self._resize_layers[1] else self.input_shape[2]
                )
                self._resize_layers = None
            else:
                raise UnsupportedModelError(
                    f"{self.full_name_msg} is missing information from one/more of the layers "
                    f"it depends on for output shape calculation.",
                )

            self._update_resize_ratios()

        super().update_output_shapes()

        # Validate only once the spatial/feature dims are concrete (no -1), and not in the pre-fused stage.
        if (
            self.output_shapes
            and -1 not in self.output_shapes[0][1:]
            and kwargs.get("hn_stage") != HnStage.PRE_FUSED
        ):
            type(self)._validate(self)

    @property
    def output_shapes(self):
        """Output shapes, lazily resolving pending ``resize_layers`` on first access."""
        if self._resize_layers:
            self._output_shapes = []
            self.update_output_shapes()
        return self._output_shapes

    @output_shapes.setter
    def output_shapes(self, output_shapes):
        # Delegate to the parent class's property setter (the getter above is overridden).
        super(ResizeLayer, self.__class__).output_shapes.fset(self, output_shapes)

    def _calc_output_shape(self):
        """Return the ``[-1, height, width, features]`` output shape implied by the input shape and ratios.

        For ``DefuseType.spatial_w``, the shape is instead taken verbatim from
        ``defuse_params["defuse_output_shapes"][0]``.
        """
        if self.defuse_type is DefuseType.spatial_w:
            assert "defuse_output_shapes" in self.defuse_params
            assert self.defuse_params["defuse_output_shapes"] is not None
            assert len(self.defuse_params["defuse_output_shapes"]) == 1
            assert len(self.defuse_params["defuse_output_shapes"][0]) == 3
            return [-1] + self.defuse_params["defuse_output_shapes"][0]

        # Rank-2 inputs (batch, features) are treated as 1x1 spatial.
        input_height = self.input_shape[1] if len(self.input_shape) == 4 else 1
        input_width = self.input_shape[2] if len(self.input_shape) == 4 else 1

        if self.defuse_type == DefuseType.resize_transpose:
            input_features = self.defuse_features
            input_width = self.input_shape[3]
        elif "defuse_features" in self.defuse_params and self.defuse_type not in [
            DefuseType.none,
            DefuseType.resize,
            DefuseType.spatial_w,
            DefuseType.nv,
            DefuseType.i420,
        ]:
            input_features = self.defuse_features
        else:
            input_features = self.input_shape[3] if len(self.input_shape) == 4 else self.input_shape[1]

        return [
            -1,
            int(np.round(np.prod(self._h_ratios, initial=input_height))),
            int(np.round(np.prod(self._w_ratios, initial=input_width))),
            int(np.round(np.prod(np.array(self.f_ratios) * np.array(self.d_ratios), initial=input_features))),
        ]

    def to_hn(self, should_get_default_params=False):
        """Serialize to an HN dict; an empty feature-ratio list is written as ``[_DEFAULT_F_RATIO]``."""
        result = copy.deepcopy(super().to_hn(should_get_default_params))
        result["params"]["resize_h_ratio_list"] = list(self._h_ratios)
        result["params"]["resize_w_ratio_list"] = list(self._w_ratios)
        result["params"]["resize_f_ratio_list"] = (
            list(self._f_ratios) if self._f_ratios else [type(self)._DEFAULT_F_RATIO]
        )
        result["params"]["method"] = self._method.value
        result["params"]["resize_bilinear_pixels_mode"] = self._resize_bilinear_pixels_mode.value
        return result

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize to a protobuf node.

        Like ``to_hn``, an empty feature-ratio list is written as ``[_DEFAULT_F_RATIO]``. This keeps the result
        valid for ``from_pb``, whose ``_validate`` requires exactly one feature ratio.
        """
        node = super().to_pb(pb_wrapper, is_multi_scope)
        node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_RESIZE
        node.resize_h_ratio_list.extend(self._h_ratios)
        node.resize_w_ratio_list.extend(self._w_ratios)
        node.resize_f_ratio_list.extend(self._f_ratios if self._f_ratios else [type(self)._DEFAULT_F_RATIO])
        node.resize_method = pb_wrapper.RESIZE_METHOD_TYPE_TO_PB[self._method]
        node.resize_bilinear_pixels_mode = pb_wrapper.RESIZE_BILINEAR_PIXELS_MODE_TYPE_TO_PB[
            self._resize_bilinear_pixels_mode
        ]
        return node

    @property
    def h_ratios(self):
        """List of height resize ratios."""
        return self._h_ratios

    @h_ratios.setter
    def h_ratios(self, h_ratios):
        self._h_ratios = h_ratios

    @property
    def w_ratios(self):
        """List of width resize ratios."""
        return self._w_ratios

    @w_ratios.setter
    def w_ratios(self, w_ratios):
        self._w_ratios = w_ratios

    @property
    def f_ratios(self):
        """List of feature broadcast ratios."""
        return self._f_ratios

    @f_ratios.setter
    def f_ratios(self, f_ratios):
        self._f_ratios = f_ratios

    @property
    def d_ratios(self):
        """List of disparity ratios; ``[1.0]`` when none have been derived."""
        return self._d_ratios if self._d_ratios else [1.0]

    @property
    def resize_method(self):
        """The ``ResizeMethod`` of this layer."""
        return self._method

    @resize_method.setter
    def resize_method(self, resize_method):
        self._method = resize_method

    @property
    def resize_layers(self):
        """Pending ``[height_source, width_source]`` layers whose output shapes define the target size, or None."""
        return self._resize_layers

    @resize_layers.setter
    def resize_layers(self, resize_layers):
        self._resize_layers = resize_layers

    @property
    def resize_bilinear_pixels_mode(self):
        """The ``ResizeBilinearPixelsMode`` of this layer."""
        return self._resize_bilinear_pixels_mode

    @resize_bilinear_pixels_mode.setter
    def resize_bilinear_pixels_mode(self, resize_bilinear_pixels_mode):
        self._resize_bilinear_pixels_mode = resize_bilinear_pixels_mode

    @property
    def input_disparity(self):
        """Input disparity given at creation time."""
        return self._input_disparity

    def _update_resize_ratios(self):
        """Derive one ratio entry per axis from target sizes / upscale factors, if no ratios exist yet.

        Sizes take precedence. ``upscale_factors`` is used only when the height or width size is missing.

        Raises:
            UnsupportedModelError: If neither sizes nor upscale factors yield non-zero height and width ratios.
        """
        if self.should_update_resize_ratios and not self._h_ratios and not self._w_ratios and not self._resize_layers:
            h_ratio = float(self._h_sizes) / float(self.input_shape[1]) if self._h_sizes else None
            w_ratio = float(self._w_sizes) / float(self.input_shape[2]) if self._w_sizes else None
            d_ratio = float(self._d_sizes) / float(self._input_disparity) if self._d_sizes else 1.0

            if None in [h_ratio, w_ratio] and self._upscale_factors:
                h_ratio = self._upscale_factors[0]
                w_ratio = self._upscale_factors[1]

            if not h_ratio or not w_ratio:
                raise UnsupportedModelError(f"{self.full_name_msg} was created without output shapes or resize values.")
            self._h_ratios.append(h_ratio)
            self._w_ratios.append(w_ratio)
            # default upscale factor for features before setting output shape
            self._f_ratios.append(type(self)._DEFAULT_F_RATIO)
            self._d_ratios.append(d_ratio)

    @staticmethod
    def reshape_inputs(input_shapes):
        """Return ``input_shapes`` with every rank-2 ``[batch, features]`` shape expanded to ``[batch, 1, 1, features]``."""
        new_shapes = []
        for shape in input_shapes:
            if len(shape) == 2:
                new_shapes.append([shape[0], 1, 1, shape[1]])
            else:
                new_shapes.append(shape)
        return new_shapes

    def set_input_shapes(self, input_shapes, validate=True):
        """Set the input shapes (rank-2 expanded to rank-4) and derive the resize ratios if still missing.

        Ratio sources, in order:

        1. sizes / upscale factors via ``_update_resize_ratios``;
        2. ``upscale_factors`` directly;
        3. the output-to-input shape ratio, when an output shape is known.

        Raises:
            UnsupportedModelError: For a feature broadcast that is not nearest-neighbor, or whose input does not
                have exactly 1 feature.
        """
        super().set_input_shapes(ResizeLayer.reshape_inputs(input_shapes), validate)
        self._update_resize_ratios()

        if self._upscale_factors and not self._h_ratios and not self._w_ratios:
            self._h_ratios.append(float(self._upscale_factors[0]))
            self._w_ratios.append(float(self._upscale_factors[1]))
            self._f_ratios.append(type(self)._DEFAULT_F_RATIO)

        elif self.output_shape is not None and not self._h_ratios and not self._w_ratios:
            h_ratio = float(self.output_shape[1]) / float(self.input_shape[1])
            w_ratio = float(self.output_shape[2]) / float(self.input_shape[2])
            f_ratio = float(self.output_shape[3]) / float(self.input_shape[3])

            # support features ratio only for nearest neighbor of the form (h,w,1)->(h,w,c)
            if f_ratio != 1.0:
                if self._method != ResizeMethod.nearest_neighbor:
                    raise UnsupportedModelError(
                        f"{self.full_name_msg} of method bilinear does not support broadcast "
                        f"over features, ratio={f_ratio}",
                    )
                elif self.input_shape[3] != 1:
                    raise UnsupportedModelError(
                        f"{self.full_name_msg} of method nearest_neighbor that broadcast over "
                        f"features only works with ratio equal to output features (resize 1 to "
                        f"output_features), ratio={f_ratio}, "
                        f"input_features={self.input_shape[3]}, "
                        f"output_features={self.output_shape[3]}",
                    )

            self._h_ratios.append(h_ratio)
            self._w_ratios.append(w_ratio)
            self._f_ratios.append(f_ratio)

    @classmethod
    def _validate(cls, resize_layer):
        """Check that the layer's ratio lists and HW layer types are mutually consistent.

        Raises:
            UnsupportedModelError: On any of the following:
                - mismatched list lengths;
                - more than one feature ratio;
                - a feature broadcast with a non-nearest-neighbor method;
                - a feature broadcast with an incompatible output feature count;
                - a feature broadcast combined with a spatial resize.
        """
        if len(resize_layer.h_ratios) != len(resize_layer.w_ratios):
            raise UnsupportedModelError(
                f"Different number of height and width ratios (# of "
                f"heights={len(resize_layer.h_ratios)}, # of "
                f"widths={len(resize_layer.w_ratios)}]) at {resize_layer.full_name_msg}",
            )

        if len(resize_layer.h_ratios) != len(resize_layer.compilation_params["hw_layer_type_list"]):
            raise UnsupportedModelError(
                f'Different number of ratios and hw_layer_types (# of '
                f'ratios={len(resize_layer.h_ratios)}, # of '
                f'hw_layer_types={len(resize_layer.compilation_params["hw_layer_type_list"])}])'
                f' at {resize_layer.full_name_msg}',
            )

        if len(resize_layer.f_ratios) != 1:
            raise UnsupportedModelError(
                f"{resize_layer.full_name_msg} only supports one value for features broadcast, "
                f"got {resize_layer.f_ratios}",
            )

        h_ratio = resize_layer.h_ratios[0]
        w_ratio = resize_layer.w_ratios[0]
        f_ratio = resize_layer.f_ratios[0]
        if f_ratio != 1:
            if resize_layer._method != ResizeMethod.nearest_neighbor:
                raise UnsupportedModelError(
                    f"{resize_layer.full_name_msg} does not support broadcast over features, ratio={f_ratio}",
                )
            # note: the reason for this validation, is that even-though we calculated an outputs shape based on the
            # ratio and not input/output shapes, it's still incorrect if the broadcast on the features isn't of the
            # form 1->f
            elif (
                resize_layer.output_shape
                and f_ratio != resize_layer.output_shape[3]
                and resize_layer.output_shape[3] % f_ratio != 0
            ):
                raise UnsupportedModelError(
                    f"{resize_layer.full_name_msg} of method nearest_neighbor that broadcast "
                    f"over features only works with ratio equal to output features (resize 1 to"
                    f" output_features), ratio={f_ratio}, "
                    f"output_features={resize_layer.output_shape[3]}",
                )
            elif h_ratio != 1 or w_ratio != 1:
                raise UnsupportedModelError(
                    f"Resize nearest_neighbor {resize_layer.name} that broadcast over features "
                    f"only works with spatial ratios equal to 1 (i.e. no resize on "
                    f"height/width). h_ratio={h_ratio}, w_ratio={w_ratio}",
                )

    @classmethod
    def from_hn(cls, hn):
        """Build a layer from an HN dict, applying defaults for the pixel mode and HW layer type, then validate.

        The ratio lists are copied so that later in-place appends do not mutate the caller's dict.
        """
        layer = super().from_hn(hn)
        layer._method = ResizeMethods[hn["params"]["method"]]
        layer._h_ratios = list(hn["params"]["resize_h_ratio_list"])
        layer._w_ratios = list(hn["params"]["resize_w_ratio_list"])
        layer._f_ratios = list(hn["params"].get("resize_f_ratio_list", [cls._DEFAULT_F_RATIO]))
        # Bilinear defaults to align_corners; every other method defaults to disabled. The default must always be
        # bound because it is passed to .get() below.
        if layer._method == ResizeMethod.bilinear:
            default_pixel_mode = ResizeBilinearPixelsMode.align_corners.value
        else:
            default_pixel_mode = ResizeBilinearPixelsMode.disabled.value
        default_hw_layer = ["lcu"]

        pixel_mode = hn["params"].get("resize_bilinear_pixels_mode", default_pixel_mode)
        if "compilation_params" in hn:
            hw_layer = hn["compilation_params"].get("hw_layer_type_list", default_hw_layer)
        else:
            hw_layer = default_hw_layer
        layer._resize_bilinear_pixels_mode = ResizeBilinearPixelsModes[pixel_mode]

        # TODO: A bit ugly, when a Layer is created from a hn, the compilation params are set to the default values
        # that don't include 'hw_layer_type_list'
        layer.set_compilation_params(hw_layer_type_list=hw_layer)
        cls._validate(layer)
        return layer

    @classmethod
    def from_pb(cls, pb, pb_wrapper):
        """Build a layer from a protobuf node, then validate.

        The repeated ratio fields are copied into plain lists. This keeps later in-place appends from mutating
        the message and lets list comparisons (e.g. in ``is_zippable``) behave as for HN-loaded layers.
        """
        layer = super().from_pb(pb, pb_wrapper)
        layer._h_ratios = list(pb.resize_h_ratio_list)
        layer._w_ratios = list(pb.resize_w_ratio_list)
        layer._f_ratios = list(pb.resize_f_ratio_list)
        layer._method = pb_wrapper.RESIZE_METHOD_PB_TO_TYPE[pb.resize_method]

        # Method-specific default pixel mode, overridden below if the message carries an explicit one.
        if layer._method == ResizeMethod.bilinear:
            layer._resize_bilinear_pixels_mode = ResizeBilinearPixelsMode.align_corners
        elif layer._method == ResizeMethod.nearest_neighbor:
            layer._resize_bilinear_pixels_mode = ResizeBilinearPixelsMode.disabled

        if pb.HasField("resize_bilinear_pixels_mode"):
            layer._resize_bilinear_pixels_mode = pb_wrapper.RESIZE_BILINEAR_PIXELS_MODE_PB_TO_TYPE[
                pb.resize_bilinear_pixels_mode
            ]
        hw_type_list = ["lcu"] if not layer.is_bilinear_align_corners_not_quantized else ["ppu"]
        layer.set_compilation_params(hw_layer_type_list=hw_type_list)
        cls._validate(layer)

        return layer

    @classmethod
    def from_layer(cls, old_layer):
        """Build a layer copying the resize parameters of ``old_layer``; ratio lists are copied, not shared."""
        layer = super().from_layer(old_layer)
        layer._h_ratios = list(old_layer.h_ratios)
        layer._w_ratios = list(old_layer.w_ratios)
        layer._f_ratios = list(old_layer.f_ratios)
        # NOTE(review): d_ratios, sizes and input_disparity are not carried over (nor serialized by to_hn);
        # confirm that output shapes are never recomputed from them after this point.
        layer._method = old_layer.resize_method
        layer._resize_bilinear_pixels_mode = old_layer.resize_bilinear_pixels_mode
        return layer

    @property
    def requires_quantized_weights(self):
        """Bilinear resizes need quantized weights unless they are align-corners bilinear on ``ppu``."""
        if self._method.value == "bilinear" and not self.is_bilinear_align_corners_not_quantized:
            return True

        return type(self)._REQUIRES_QUANTIZED_WEIGHTS

    def _get_equiv_handler(self):
        """Transparent handler unless the layer broadcasts over features (any f_ratio != 1)."""
        if (np.array(self._f_ratios) != 1).any():
            handler = LayerHandlerType.unsupported
        else:
            handler = LayerHandlerType.transparent
        return handler

    def get_equalization_handler_type(self, predecessor=None):
        handler = self._get_equiv_handler()
        return EquivClassification(handler, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        handler = self._get_equiv_handler()
        return EquivClassification(handler, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        handler = self._get_equiv_handler()
        return EquivClassification(handler, is_source=False)

    def ibc_supported(self):
        return LayerSupportStatus.unsupported

    @property
    def kernel_short_description(self):
        return f" ({self._method.value})"

    def is_zippable(self, other):
        """Allow zipping two resize layers as long as they share the same parameters"""
        if self.h_ratios != other.h_ratios:
            return False
        if self.w_ratios != other.w_ratios:
            return False
        if self.f_ratios != other.f_ratios:
            return False
        if self.resize_method != other.resize_method:
            return False
        if self.resize_bilinear_pixels_mode != other.resize_bilinear_pixels_mode:
            return False
        return super().is_zippable(other)
