import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp
from hailo_model_optimization.acceleras.utils.acceleras_definitions import ResizeBilinearPixelsMode
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasException
from hailo_model_optimization.acceleras.utils.resize_common import CustomNearest


class ResizeNearestNeighborOp(BaseAtomicOp):
    """
    Atomic op that performs a nearest-neighbor resize of a single input tensor.

    The input is indexed as a 4-D tensor: axis 1 is scaled by ``h_ratios``, axis 2 by
    ``w_ratios`` and axis 3 (features) is repeated according to the first entry of
    ``f_ratios``. When several height/width ratios are given they are applied one after
    the other, as consecutive resize steps.

    Currently only supports fully native mode (``call_hw_sim`` simply delegates to
    ``call_native``).
    """

    # Maximal allowed distance between a scaled dimension and its nearest integer.
    RESIZE_RATIOS_THRESH = 0.005

    num_inputs = 1
    num_outputs = 1

    # NOTE: ``_new_height`` / ``_new_width`` / ``_new_feature`` are not set in ``__init__``;
    # they are derived from the run-time input shape in ``_build`` and
    # ``_calculate_new_output_shapes``.

    def __init__(
        self,
        name,
        h_ratios: list = None,
        w_ratios: list = None,
        f_ratios: list = None,
        resize_bilinear_pixels_mode: ResizeBilinearPixelsMode = ResizeBilinearPixelsMode.DISABLED,
        compilation_params: dict = None,
        logger=None,
        fully_native=None,
        **kwargs,
    ):
        """
        Args:
            name: op name, forwarded to ``BaseAtomicOp``.
            h_ratios: height ratios, applied sequentially. ``None`` means no ratios.
            w_ratios: width ratios, applied sequentially. ``None`` means no ratios.
            f_ratios: feature ratios. ``None`` / empty means the features are left untouched.
            resize_bilinear_pixels_mode: selects the coordinate mapping used by ``call_native``
                (``HALF_PIXELS``, ``ALIGN_CORNERS`` or neither).
            compilation_params: stored as-is on the op. ``None`` yields a fresh empty dict.
            logger: forwarded to ``BaseAtomicOp``.
            fully_native: forwarded to ``BaseAtomicOp``.
            **kwargs: forwarded to ``BaseAtomicOp``.

        """
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)

        self._h_ratios = [] if h_ratios is None else h_ratios
        self._w_ratios = [] if w_ratios is None else w_ratios
        self._f_ratios = [] if f_ratios is None else f_ratios
        self._resize_bilinear_pixels_mode = resize_bilinear_pixels_mode
        # A ``{}`` default in the signature would be one dict shared by every instance,
        # so the empty dict is created here, per instance.
        self._compilation_params = {} if compilation_params is None else compilation_params

    @staticmethod
    def _scaled_dim(ratios, dim):
        """Return ``dim`` multiplied by all of ``ratios``, rounded to the nearest int."""
        return int(np.round(np.prod(ratios, initial=dim)))

    def _get_first_f_ratio(self):
        """
        Return the first feature ratio, or 1 when no feature ratios were given.

        An empty ``f_ratios`` (the constructor default) is treated as "features unchanged",
        which matches ``_calculate_new_output_shapes`` where an empty product leaves the
        feature dimension as is.
        """
        return self._f_ratios[0] if self._f_ratios else 1

    def _validate_resize_ratios_threshold(self, input_shape):
        """
        Verify that every sequential resize step lands (almost) on integer dimensions.

        For some ratios we lose precision by multiplying, so the scaled height/width may be
        slightly off the integer value; a deviation of up to ``RESIZE_RATIOS_THRESH`` is
        tolerated.

        Height and width ratios are walked in lockstep (``zip``), so only the first
        ``min(len(h_ratios), len(w_ratios))`` steps are checked.

        Raises:
            AccelerasException: if a step yields a non-integer height or width.

        """
        cur_height = input_shape[1]
        cur_width = input_shape[2]
        for h_ratio, w_ratio in zip(self._h_ratios, self._w_ratios):
            exact_height = cur_height * h_ratio
            exact_width = cur_width * w_ratio

            rounded_height = int(np.round(exact_height))
            rounded_width = int(np.round(exact_width))
            height_off = np.abs(rounded_height - exact_height) > self.RESIZE_RATIOS_THRESH
            width_off = np.abs(rounded_width - exact_width) > self.RESIZE_RATIOS_THRESH
            if height_off or width_off:
                raise AccelerasException(
                    f"Resize requires integer dimensions. Recieved [{exact_height}, {exact_width}]",
                )
            # The next step starts from the rounded result of this one.
            cur_height = rounded_height
            cur_width = rounded_width

    def _build(self, input_shape):
        """Compute the output height/width for ``input_shape`` and validate the ratios."""
        self._new_height = self._scaled_dim(self._h_ratios, input_shape[1])
        self._new_width = self._scaled_dim(self._w_ratios, input_shape[2])
        self._validate_resize_ratios_threshold(input_shape)

    def _compute_output_shape(self, input_shape):
        """Return ``[batch, new_height, new_width, new_feature]`` for ``input_shape``."""
        self._calculate_new_output_shapes(input_shape)
        return [input_shape[0], self._new_height, self._new_width, self._new_feature]

    def _calculate_new_output_shapes(self, input_shape):
        """Store the scaled height, width and feature dimensions on the op."""
        self._new_height = self._scaled_dim(self._h_ratios, input_shape[1])
        self._new_width = self._scaled_dim(self._w_ratios, input_shape[2])
        self._new_feature = self._scaled_dim(self._f_ratios, input_shape[3])

    def call_native(self, inputs, **kwargs):
        """
        Resize ``inputs[0]`` and return the resized tensor.

        Three spatial strategies are used, depending on the pixels mode and the ratios:
        plain repetition for integer ratios, ``tf.image.resize`` for half-pixels mode and
        ``CustomNearest`` for everything else. Afterwards the feature axis is repeated when
        the first feature ratio is larger than 1.
        """
        inp = inputs[0]
        # Empty f_ratios means "features unchanged" instead of an IndexError.
        new_features = int(self._get_first_f_ratio())
        half_pixels = self._resize_bilinear_pixels_mode == ResizeBilinearPixelsMode.HALF_PIXELS
        align_corners = self._resize_bilinear_pixels_mode == ResizeBilinearPixelsMode.ALIGN_CORNERS
        cur_height = inp.shape[1]
        cur_width = inp.shape[2]
        cur_op = inp
        integer_ratios = all(int(h) == h for h in self._h_ratios) and all(int(w) == w for w in self._w_ratios)
        if not half_pixels and not align_corners and integer_ratios:
            # NOTE: This case is just broadcast. The common case in yolo architecture
            for h in self._h_ratios:
                cur_op = tf.repeat(cur_op, repeats=int(h), axis=1)
            for w in self._w_ratios:
                cur_op = tf.repeat(cur_op, repeats=int(w), axis=2)
        elif half_pixels and not align_corners:
            # NOTE: Standard resize with nearest neighbor interpolation, where the ratios are not integers.
            for h, w in zip(self._h_ratios, self._w_ratios):
                new_height = int(np.round(cur_op.shape[1] * h))
                new_width = int(np.round(cur_op.shape[2] * w))
                cur_op = tf.image.resize(cur_op, [new_height, new_width], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
        else:
            # NOTE: This case isn't very common, most cases should be handled by the above case.
            # The torch doesn't even support this case. only tf parser
            # (probably for legacy models, since tf itself doesn't support this case anymore)
            for h, w in zip(self._h_ratios, self._w_ratios):
                new_height = int(np.round(cur_height * h))
                new_width = int(np.round(cur_width * w))
                cur_op = CustomNearest(cur_op, [new_height, new_width], align_corners, half_pixels)
                cur_height = cur_op.shape[1]
                cur_width = cur_op.shape[2]
        if new_features > 1:
            tmp = tf.repeat(cur_op, repeats=new_features, axis=3)
            # After repeat, the shape is unknown, therefore we reshape for explicit shape
            # specification (for support of graph mode)
            cur_op = tf.reshape(tmp, [-1, cur_op.shape[1], cur_op.shape[2], cur_op.shape[3] * new_features])
        return cur_op

    def call_hw_sim(self, inputs, **kwargs):
        """Hardware simulation is identical to the native computation for this op."""
        return self.call_native(inputs, **kwargs)

    def create_hw_params(self, **kwargs):
        """No hardware params are created for this op."""

    def enforce_encoding(self, training=False, **kwargs):
        """
        Propagate the input scale and zero point to the output.

        The scale is copied as-is when the features are not repeated or when the input
        scale is scalar; otherwise it is repeated like the features are.
        """
        first_f_ratio = self._get_first_f_ratio()
        if first_f_ratio == 1 or self.input_scale_is_scalar():
            self.output_scale = self.input_scales[0]
        else:
            self.output_scale = tf.repeat(self.input_scales[0], repeats=int(first_f_ratio))
        self.output_zero_point = self.input_zero_points[0]

    def create_weight_quant_element(self, **kwargs):
        """No weight quantization element is created for this op."""

    def define_constraints(self, enc):
        """
        Register the output scale / zero point constraints on ``enc``.

        Mirrors ``enforce_encoding``: the output scale is the input scale (repeated when
        the features are repeated) and the output zero point equals the input zero point.
        """
        super().define_constraints(enc)

        # Compute output_scale
        first_f_ratio = self._get_first_f_ratio()
        if first_f_ratio == 1:
            enc.identity(f"{self.full_name}/output_scale:0", f"{self.full_name}/input_scale:0")
        else:
            enc.callback(
                f"{self.full_name}/output_scale:0",
                f"{self.full_name}/input_scale:0",
                tf.repeat,
                callback_name="tf.repeat",
                outs_shape=(self.output_shape[-1],),
                repeats=int(first_f_ratio),
            )

        # Compute output_zero_point
        enc.identity(f"{self.full_name}/output_zero_point:0", f"{self.full_name}/input_zero_point:0")

    @property
    def bit_exact_supported(self) -> bool:
        """This layer supports bit exact emulation."""
        return True
