import math
from dataclasses import dataclass

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp, BaseWeightLossyElements
from hailo_model_optimization.acceleras.lossy_elements.base_lossy_element import BaseLossyElement
from hailo_model_optimization.acceleras.lossy_elements.identity_element import IdentityElement
from hailo_model_optimization.acceleras.lossy_elements.quant_element import MACDataQuantElement
from hailo_model_optimization.acceleras.utils.acceleras_definitions import ResizeBilinearPixelsMode
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasException
from hailo_model_optimization.acceleras.utils.resize_common import CustomBilinear


@dataclass
class ConvWeightsLossy(BaseWeightLossyElements):
    """Weight lossy elements for an op whose only weight is a single ``kernel``."""

    # Lossy element that the owning op applies to its kernel values.
    kernel: BaseLossyElement


class ResizeBilinearMacOp(BaseAtomicOp):
    RESIZE_RATIOS_THRESH = 0.005

    num_inputs = 1
    num_outputs = 1

    def __init__(
        self,
        name,
        w_ratios=None,
        h_ratios=None,
        f_ratios=None,
        resize_bilinear_pixels_mode=None,
        compilation_params: dict = {},
        logger=None,
        fully_native=None,
        **kwargs,
    ):
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)

        self._w_ratios = w_ratios
        self._h_ratios = h_ratios
        self._f_ratios = f_ratios
        self._resize_bilinear_pixels_mode = ResizeBilinearPixelsMode(resize_bilinear_pixels_mode)
        self._compilation_params = compilation_params
        self.weight_lossy_elements = ConvWeightsLossy(kernel=IdentityElement(name=f"{self.full_name}/ie:conv_weight"))
        self.kernel_scale = 1
        self.pre_acc_shift = 0
        self.shift_delta = 0

    def _validate_resize_ratios_threshold(self, input_shape):
        # For some ratios we will lose precision by multiplying and new_height/new_width will have a
        # slight difference from the integer value
        cur_height = input_shape[1]
        cur_width = input_shape[2]
        for h, w in zip(self._h_ratios, self._w_ratios):
            candiate_new_height = cur_height * h
            candiate_new_width = cur_width * w

            new_height = int(np.round(candiate_new_height))
            new_width = int(np.round(candiate_new_width))
            if (np.abs(new_height - candiate_new_height) > self.RESIZE_RATIOS_THRESH) or (
                np.abs(new_width - candiate_new_width) > self.RESIZE_RATIOS_THRESH
            ):
                raise AccelerasException(
                    f"Resize requires integer dimensions. Recieved [{candiate_new_height}, {candiate_new_width}]",
                )
            cur_height = new_height
            cur_width = new_width

    def _build(self, input_shape):
        """Compute the output height/width for ``input_shape`` and validate the resize ratios.

        Raises:
            AccelerasException: propagated from the ratio validation when a resized
                dimension is not close enough to an integer.

        """
        # The output size is stored first, so it is set even when validation raises.
        self._calculate_new_output_shapes(input_shape)
        self._validate_resize_ratios_threshold(input_shape)

    def _compute_output_shape(self, input_shape):
        """Return the resized output shape and refresh the cached bilinear kernel grid.

        Args:
            input_shape: indexable 4-entry shape; entries 0 and 3 are copied to the output
                unchanged, entries 1 and 2 are replaced by the resized height and width.

        Returns:
            list ``[input_shape[0], new_height, new_width, input_shape[3]]``.

        """
        self._calculate_new_output_shapes(input_shape)
        leading_dim, trailing_dim = input_shape[0], input_shape[3]
        resized_shape = [leading_dim, self._new_height, self._new_width, trailing_dim]
        # Side effect: the interpolation weights are recomputed for this input shape.
        self.kernel_grid = self.get_bilinear_weights(input_shape)
        return resized_shape

    def _calculate_new_output_shapes(self, input_shape):
        self._new_height = int(np.round(np.prod(self._h_ratios, initial=input_shape[1])))
        self._new_width = int(np.round(np.prod(self._w_ratios, initial=input_shape[2])))

    def import_independent_params(self, params):
        self.pre_acc_shift = params["mac_shift"]
        self.shift_delta = params["shift_delta"]
        self.kernel_scale = 1 / 126.0

    def call_native(self, inputs, **kwargs):
        """Resize ``inputs[0]`` to the precomputed output size with ``CustomBilinear``.

        Args:
            inputs: sequence whose first element is the tensor to resize.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            Whatever ``CustomBilinear`` returns for the configured pixel mode.

        """
        pixels_mode = self._resize_bilinear_pixels_mode
        target_size = (self._new_height, self._new_width)
        return CustomBilinear(
            inputs[0],
            target_size,
            pixels_mode == ResizeBilinearPixelsMode.ALIGN_CORNERS,
            pixels_mode == ResizeBilinearPixelsMode.HALF_PIXELS,
        )

    def call_hw_sim(self, inputs, **kwargs):
        """Run the native resize on the input divided by ``kernel_scale * 2**pre_acc_shift``.

        Args:
            inputs: sequence whose first element is the tensor to resize.
            **kwargs: forwarded to ``call_native``.

        """
        accumulator_gain = self.kernel_scale * 2**self.pre_acc_shift
        rescaled_input = inputs[0] / accumulator_gain
        return self.call_native([rescaled_input], **kwargs)

    def create_hw_params(self, **kwargs):
        self.kernel_scale = 1 / 126.0
        weight_bits = self.weight_lossy_elements.kernel.bits
        if weight_bits == 15:
            self.pre_acc_shift = 0
        else:
            self.pre_acc_shift = 1

    def enforce_encoding(self, training=False, **kwargs):
        """Derive the output scale and zero point from the first input's encoding.

        The input scale is used as-is when the first feature ratio is 1 or the input scale
        is scalar; otherwise it is expanded with ``tf.repeat`` by the first feature ratio.

        Args:
            training: accepted for interface compatibility; not used here.
            **kwargs: accepted for interface compatibility; not used here.

        """
        feature_ratio = self._f_ratios[0]
        scale_fits_output = feature_ratio == 1 or self.input_scale_is_scalar()
        source_scale = self.input_scales[0]
        if not scale_fits_output:
            source_scale = tf.repeat(source_scale, repeats=int(feature_ratio))
        self.output_scale = source_scale * self.kernel_scale * 2**self.pre_acc_shift
        source_zero_point = self.input_zero_points[0]
        self.output_zero_point = source_zero_point / (self.kernel_scale * 2**self.pre_acc_shift)

    def create_weight_quant_element(self, kernel_bits=8, signed=True):
        """Replace the kernel lossy element with a ``MACDataQuantElement``.

        Args:
            kernel_bits: bit width passed to the quant element.
            signed: signedness flag passed to the quant element.

        """
        kernel_quantizer = MACDataQuantElement(
            bits=kernel_bits,
            signed=signed,
            name=f"{self.full_name}/qe:kernel",
        )
        self.weight_lossy_elements = ConvWeightsLossy(kernel=kernel_quantizer)

    def export_independent_params(self):
        zp_kernel = 0
        return {
            "mac_shift": np.array(self.pre_acc_shift, np.float32),
            "shift_delta": np.array(self.shift_delta, np.float32),
            "kernel_zero_point": np.array(zp_kernel, np.float32),
            "weight_bits": np.array(self.weight_lossy_elements.kernel.bits, np.float32),
        }

    def export_hw_params(self):
        """Return the quantized kernel and related parameters as NumPy values.

        Returns:
            dict with ``kernel`` (int8 when the kernel element has at most 8 bits, int16
            otherwise), ``zp_kernel`` (int32 zero) and ``output_stage/mult_shift``
            (the pre-accumulator shift as uint8).

        """
        kernel_bits = self.weight_lossy_elements.kernel.bits
        if kernel_bits <= 8:
            kernel_dtype = np.int8
        else:
            kernel_dtype = np.int16
        quantized_kernel = self.final_quantized_kernel.numpy()
        return {
            "kernel": quantized_kernel.astype(kernel_dtype),
            "zp_kernel": np.array(0, dtype=np.int32),
            "output_stage/mult_shift": np.array(self.pre_acc_shift, np.uint8),
        }

    def export_quant_weights(self):
        """Return the quantized kernel converted to float32 under the ``quant_kernel`` key."""
        return {"quant_kernel": np.float32(self.final_quantized_kernel.numpy())}

    @property
    def final_quantized_kernel(self):
        """Kernel grid divided by ``kernel_scale``, cast to float32 and passed through the kernel lossy element."""
        kernel_element = self.weight_lossy_elements.kernel
        scaled_grid = tf.cast(self.kernel_grid / self.kernel_scale, tf.float32)
        return kernel_element(scaled_grid)

    def get_bilinear_weights(self, input_shape):
        """Build the per-output-pixel bilinear interpolation weights for this resize.

        Requires ``_new_height`` / ``_new_width`` to be set already (they are written by
        ``_calculate_new_output_shapes``).

        Args:
            input_shape: indexable shape; index 1 is the input height, index 2 the input width.

        Returns:
            NumPy array of shape ``(h_out, w_out, 4, 1)``. Along axis 2 the entries are
            ``(1 - r_x) * (1 - r_y)``, ``r_x * (1 - r_y)``, ``(1 - r_x) * r_y`` and
            ``r_x * r_y``, where ``r_x`` / ``r_y`` are the fractional parts of the sampling
            position of each output pixel.

        """
        align_corners = self._resize_bilinear_pixels_mode == ResizeBilinearPixelsMode.ALIGN_CORNERS
        half_pixels = self._resize_bilinear_pixels_mode == ResizeBilinearPixelsMode.HALF_PIXELS

        # NOTE(review): only the first ratio of each list is used here, while the output size
        # is the product of all ratios - confirm multi-entry ratio lists are handled upstream.
        h_ratio = self._h_ratios[0]
        w_ratio = self._w_ratios[0]

        h_out = self._new_height
        w_out = self._new_width

        h_in = input_shape[1]
        w_in = input_shape[2]

        # Sampling positions are built pre-multiplied by these factors (so the end points can
        # be rounded to integers) and divided back by the same factors further down.
        w_factor = 2 * w_out if not align_corners or w_out == 1 else 2 * (w_out - 1)
        h_factor = 2 * h_out if not align_corners or h_out == 1 else 2 * (h_out - 1)
        # First (y1, x1) and last (y2, x2) sampling positions, in factor-scaled units.
        # The first position maps to input coordinate >= 1 in every mode - presumably the
        # coordinates are 1-based / the input is padded by one pixel; TODO confirm.
        if half_pixels:
            y1, x1 = round((0.5 + 0.5 / h_ratio) * h_factor), round((0.5 + 0.5 / w_ratio) * w_factor)
            y2, x2 = round((h_in + 0.5 - 0.5 / h_ratio) * h_factor), round((w_in + 0.5 - 0.5 / w_ratio) * w_factor)
        else:
            y1, x1 = h_factor, w_factor
            if align_corners:
                y2, x2 = h_in * h_factor, w_in * w_factor
            else:
                y2, x2 = round((h_in + 1 - (1 / h_ratio)) * h_factor), round((w_in + 1 - (1 / w_ratio)) * w_factor)
        # Evenly spaced sampling positions, one per output row / column.
        y_range = np.linspace(y1, y2, h_out)
        x_range = np.linspace(x1, x2, w_out)
        # x_in_corner / y_in_corner: number of output columns / rows at each border whose
        # residual is overridden below (half-pixels mode only; 0 disables the override).
        if half_pixels:
            if int(w_ratio) == w_ratio:
                x_in_corner = int(w_ratio / 2)
            elif (x2 - x1) != 0:
                x_in_corner = math.ceil((w_factor - x1) / (w_factor / w_ratio))
            else:
                x_in_corner = 0
            if int(h_ratio) == h_ratio:
                y_in_corner = int(h_ratio / 2)
            elif (y2 - y1) != 0:
                y_in_corner = math.ceil((h_factor - y1) / (h_factor / h_ratio))
            else:
                y_in_corner = 0
        else:
            x_in_corner, y_in_corner = 0, 0
        xx_in, yy_in = np.meshgrid(x_range, y_range)
        # Calculate coordinates and residuals for bilinear
        xx_in_floor = xx_in // w_factor
        yy_in_floor = yy_in // h_factor
        xx_in = xx_in / w_factor
        yy_in = yy_in / h_factor
        # Fractional part of each sampling position, shaped (h_out, w_out, 1, 1).
        r_x = np.reshape((xx_in - xx_in_floor), (h_out, w_out, 1, 1))
        r_y = np.reshape((yy_in - yy_in_floor), (h_out, w_out, 1, 1))
        if half_pixels:
            # For odd integer ratios every ratio-th column / row (starting at the corner
            # offset) gets a zero residual - presumably those samples coincide exactly with
            # an input pixel; TODO confirm.
            if w_ratio % 2 != 0 and w_ratio == int(w_ratio):
                r_x[:, x_in_corner :: int(w_ratio), :, :] = 0
            if h_ratio % 2 != 0 and h_ratio == int(h_ratio):
                r_y[y_in_corner :: int(h_ratio), :, :, :] = 0
            # Border override: the leading columns / rows take residual 1 and the trailing
            # ones residual 0, i.e. each uses a single neighbour along that axis.
            if x_in_corner != 0:
                r_x[:, : int(x_in_corner), :, :] = 1
                r_x[:, int(-(x_in_corner)) :, :, :] = 0
            if y_in_corner != 0:
                r_y[: int(y_in_corner), :, :, :] = 1
                r_y[int(-(y_in_corner)) :, :, :, :] = 0
        elif not half_pixels and not align_corners:
            # Neither half-pixels nor align-corners: the last floor(ratio) columns / rows
            # get a zero residual.
            # NOTE(review): for a ratio below 1, floor(ratio) is 0 and the slice `[-0:]`
            # covers the whole axis, zeroing every residual - confirm this is intended.
            r_x[:, int(-(math.floor(w_ratio))) :, :, :] = 0
            r_y[int(-(math.floor(h_ratio))) :, :, :, :] = 0
        one_m_rx = 1 - r_x
        one_m_ry = 1 - r_y

        # Stack the four neighbour weights along axis 2.
        kernel = np.concatenate((one_m_rx * one_m_ry, r_x * one_m_ry, one_m_rx * r_y, r_x * r_y), axis=2)
        return kernel
