import itertools

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp
from hailo_model_optimization.acceleras.utils.acceleras_definitions import ResizeBilinearPixelsMode
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasException
from hailo_model_optimization.acceleras.utils.resize_common import CustomBilinear


class ResizeBilinearPpuOp(BaseAtomicOp):
    """
    Produces an AtomicOp that contains one bilinear resize operation.

    The spatial resize is described by per-axis lists of ratios (``h_ratios`` /
    ``w_ratios``) that are applied cumulatively: the output height is the input
    height multiplied by every entry of ``h_ratios`` (likewise for the width with
    ``w_ratios``). An empty list leaves that axis unchanged. ``f_ratios[0]`` is only
    used by :meth:`enforce_encoding` to expand a non-scalar input scale.

    Currently only supports fully native mode: the hardware-simulation path simply
    reuses the native implementation.
    """

    # Maximal allowed distance between a resized dimension and its nearest integer.
    # Larger deviations mean the given ratios do not produce integer dimensions.
    RESIZE_RATIOS_THRESH = 0.005

    num_inputs = 1
    num_outputs = 1

    def __init__(
        self,
        name: str,
        h_ratios: list = None,
        w_ratios: list = None,
        f_ratios: list = None,
        resize_bilinear_pixels_mode: ResizeBilinearPixelsMode = ResizeBilinearPixelsMode.DISABLED,
        compilation_params: dict = None,
        logger=None,
        fully_native=None,
        **kwargs,
    ):
        """
        Args:
            name: Name of the op.
            h_ratios: Height resize ratios, applied cumulatively. Defaults to no resize.
            w_ratios: Width resize ratios, applied cumulatively. Defaults to no resize.
            f_ratios: Feature ratios; only ``f_ratios[0]`` is read (see ``enforce_encoding``).
            resize_bilinear_pixels_mode: Selects align-corners / half-pixels sampling.
            compilation_params: Compilation parameters stored on the op. Defaults to a
                new empty dict per instance.
            logger: Optional logger forwarded to the base op.
            fully_native: Forwarded to the base op.
        """
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        self._h_ratios = [] if h_ratios is None else h_ratios
        self._w_ratios = [] if w_ratios is None else w_ratios
        self._f_ratios = [] if f_ratios is None else f_ratios
        self._resize_bilinear_pixels_mode = resize_bilinear_pixels_mode
        # A fresh dict per instance: a shared `{}` default would leak mutations
        # between every op created without explicit compilation params.
        self._compilation_params = {} if compilation_params is None else compilation_params

    def _validate_resize_ratios_threshold(self, input_shape):
        """
        Verify that every step of the resize chain yields (near-)integer dimensions.

        Ratios are applied one step at a time, rounding after each step. Some ratios
        lose float precision when multiplied, so the new height/width can be slightly
        off an integer value; deviations up to ``RESIZE_RATIOS_THRESH`` are tolerated.
        When ``h_ratios`` and ``w_ratios`` differ in length, the missing ratios are
        treated as 1. This matches how the output shape is computed, so every given
        ratio is validated.

        Args:
            input_shape: Input tensor shape; indices 1 and 2 are read as height and width.

        Raises:
            AccelerasException: If a step's height or width deviates from the nearest
                integer by more than ``RESIZE_RATIOS_THRESH``.
        """
        cur_height = input_shape[1]
        cur_width = input_shape[2]
        for h_ratio, w_ratio in itertools.zip_longest(self._h_ratios, self._w_ratios, fillvalue=1):
            candidate_height = cur_height * h_ratio
            candidate_width = cur_width * w_ratio

            new_height = int(np.round(candidate_height))
            new_width = int(np.round(candidate_width))
            if (np.abs(new_height - candidate_height) > self.RESIZE_RATIOS_THRESH) or (
                np.abs(new_width - candidate_width) > self.RESIZE_RATIOS_THRESH
            ):
                raise AccelerasException(
                    f"Resize requires integer dimensions. Received [{candidate_height}, {candidate_width}]",
                )
            cur_height = new_height
            cur_width = new_width

    def _build(self, input_shape):
        """Compute the output spatial size and validate that the ratios give integer dimensions."""
        self._calculate_new_output_shapes(input_shape)
        self._validate_resize_ratios_threshold(input_shape)

    def _compute_output_shape(self, input_shape):
        """Return ``[batch, new_height, new_width, channels]`` for the given input shape."""
        self._calculate_new_output_shapes(input_shape)
        return [input_shape[0], self._new_height, self._new_width, input_shape[3]]

    def _calculate_new_output_shapes(self, input_shape):
        """Set ``_new_height``/``_new_width`` from the product of all ratios, rounded to int."""
        self._new_height = int(np.round(np.prod(self._h_ratios, initial=input_shape[1])))
        self._new_width = int(np.round(np.prod(self._w_ratios, initial=input_shape[2])))

    def call_native(self, inputs, **kwargs):
        """
        Resize ``inputs[0]`` to the computed output size with ``CustomBilinear``.

        The pixel-sampling flags passed to ``CustomBilinear`` come from
        ``resize_bilinear_pixels_mode``. Requires ``_build`` or
        ``_compute_output_shape`` to have run first, since those set the output size.
        """
        output_size = (self._new_height, self._new_width)
        align_corners = self._resize_bilinear_pixels_mode == ResizeBilinearPixelsMode.ALIGN_CORNERS
        half_pixels = self._resize_bilinear_pixels_mode == ResizeBilinearPixelsMode.HALF_PIXELS
        return CustomBilinear(inputs[0], output_size, align_corners, half_pixels)

    def call_hw_sim(self, inputs, **kwargs):
        """Hardware simulation is identical to the native computation for this op."""
        return self.call_native(inputs, **kwargs)

    def create_hw_params(self, **kwargs):
        """Intentionally a no-op for this op."""

    def enforce_encoding(self, training=False, **kwargs):
        """
        Propagate the input quantization encoding to the output.

        The zero point passes through unchanged. The scale also passes through
        unchanged when it is scalar or when the feature ratio is 1. Otherwise each
        element of the input scale is repeated ``f_ratios[0]`` times via ``tf.repeat``.
        An empty ``f_ratios`` (the constructor default) is treated as a feature ratio
        of 1 instead of raising ``IndexError``.
        """
        f_ratio = self._f_ratios[0] if self._f_ratios else 1
        if f_ratio == 1 or self.input_scale_is_scalar():
            self.output_scale = self.input_scales[0]
        else:
            self.output_scale = tf.repeat(self.input_scales[0], repeats=int(f_ratio))
        self.output_zero_point = self.input_zero_points[0]

    def create_weight_quant_element(self, **kwargs):
        """Intentionally a no-op: this op has no weights to quantize."""
