from typing import Union

import numpy as np

from hailo_model_optimization.acceleras.atomic_ops.resize_bilinear_ppu_op import ResizeBilinearPpuOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_single_atomic import BaseHailoSingleAtomic
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerWeightsClippingConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    BiasMode,
    EquivClassification,
    LayerHandlerType,
    LayerType,
    OptimizationTarget,
    PrecisionMode,
    ResizeBilinearPixelsMode,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import (
    AccelerasInitializationError,
    AccelerasUnsupportedError,
)


class HailoResizeBilinearPpu(BaseHailoSingleAtomic):
    """Represents `resize` layer in the hn.

    A single-atomic Hailo layer whose core op is a ``ResizeBilinearPpuOp``.
    Only height/width resizing is accepted when built from the hn: the
    features ratio list must contain exactly one value equal to 1
    (enforced by ``_validate_ratios``).
    """

    _hn_type = LayerType.RESIZE
    OP_NAME = "resize_bilinear_ppu_op"

    SUPPORTED_PRECISION_MODE = {
        PrecisionMode.a8_w8,
        PrecisionMode.a8_w8_a8,
        PrecisionMode.a8_w8_a16,
    }
    SUPPORTED_BIAS_MODE = {
        BiasMode.single_scale_decomposition,
        BiasMode.double_scale_initialization,
    }

    SUPPORTED_QUANTIZATION_GROUPS = False

    # Fallback features ratio used by ``from_hn`` when the hn omits ``resize_f_ratio_list``.
    _DEFAULT_F_RATIO = 1.0

    def __init__(
        self,
        name: str,
        h_ratios: list = None,
        w_ratios: list = None,
        f_ratios: list = None,
        resize_bilinear_pixels_mode: Union[str, ResizeBilinearPixelsMode] = "disabled",
        compilation_params: dict = None,
        logger=None,
        **kwargs,
    ):
        """Build the layer and its ``ResizeBilinearPpuOp`` core op.

        Args:
            name: Layer name; the atomic op is named ``f"{name}/{OP_NAME}"``.
            h_ratios: Height resize ratios (defaults to an empty list).
            w_ratios: Width resize ratios (defaults to an empty list).
            f_ratios: Features resize ratios (defaults to an empty list).
            resize_bilinear_pixels_mode: Anything accepted by ``ResizeBilinearPixelsMode(...)``.
                NOTE(review): this default ("disabled") differs from the
                ``ALIGN_CORNERS`` fallback used by ``from_hn`` — confirm intended.
            compilation_params: Passed through to the atomic op; defaults to a
                fresh empty dict per instance.
            logger: Forwarded to both the atomic op and the base layer.
            **kwargs: Forwarded to ``BaseHailoSingleAtomic.__init__``.
        """
        h_ratios = [] if h_ratios is None else h_ratios
        w_ratios = [] if w_ratios is None else w_ratios
        f_ratios = [] if f_ratios is None else f_ratios
        # A new dict per call: a shared ``{}`` default would be the same object
        # for every layer built without explicit compilation params.
        compilation_params = {} if compilation_params is None else compilation_params

        self._h_ratios = h_ratios
        self._w_ratios = w_ratios
        self._f_ratios = f_ratios
        bilinear_pixel_mode = ResizeBilinearPixelsMode(resize_bilinear_pixels_mode)
        atomic_resize = ResizeBilinearPpuOp(
            f"{name}/{self.OP_NAME}",
            h_ratios=h_ratios,
            w_ratios=w_ratios,
            f_ratios=f_ratios,
            resize_bilinear_pixels_mode=bilinear_pixel_mode,
            compilation_params=compilation_params,
            logger=logger,
        )
        super().__init__(name=name, core_op=atomic_resize, logger=logger, **kwargs)

    def _validate_ratios(self):
        """Check that the atomic op's resize ratios are mutually consistent.

        Raises:
            AccelerasInitializationError: if the numbers of height and width
                ratios differ; if that number differs from the length of
                ``compilation_params["hw_layer_type_list"]``; if there is not
                exactly one features ratio; or if that features ratio is not 1.
        """
        resize_op = self.atomic_op
        num_h_ratios = len(resize_op._h_ratios)
        num_w_ratios = len(resize_op._w_ratios)
        if num_h_ratios != num_w_ratios:
            raise AccelerasInitializationError(
                f"Different number of height and width ratios (# of heights={num_h_ratios}, # of widths={num_w_ratios})",
            )

        # Each (h, w) ratio pair is expected to have a matching hw layer type entry.
        num_hw_layer_types = len(resize_op._compilation_params["hw_layer_type_list"])
        if num_h_ratios != num_hw_layer_types:
            raise AccelerasInitializationError(
                f"Different number of ratios and hw_layer_types (# of ratios={num_h_ratios}, # of hw_layer_types={num_hw_layer_types})",
            )

        if len(resize_op._f_ratios) != 1:
            raise AccelerasInitializationError(
                f"Resize op {resize_op.full_name} only supports one value for features broadcast, got {resize_op._f_ratios}",
            )

        f_ratio = resize_op._f_ratios[0]
        if f_ratio != 1:
            raise AccelerasInitializationError(
                f"Resize bilinear {resize_op.full_name} does not support broadcast over features, ratio={f_ratio}",
            )

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """Construct, validate and finalize a layer from an hn element dict.

        Reads ``params.resize_h_ratio_list`` and ``params.resize_w_ratio_list``
        (required). It also reads ``params.resize_f_ratio_list`` (defaults to
        ``[_DEFAULT_F_RATIO]``) and ``params.resize_bilinear_pixels_mode``
        (defaults to ``ALIGN_CORNERS``), plus the top-level ``compilation_params``.

        Raises:
            AccelerasUnsupportedError: if ``params`` contains an ``activation`` key.
            AccelerasInitializationError: if the ratios fail ``_validate_ratios``.
        """
        params = hn_element["params"]
        h_ratios = params["resize_h_ratio_list"]
        w_ratios = params["resize_w_ratio_list"]
        f_ratios = params.get("resize_f_ratio_list", [cls._DEFAULT_F_RATIO])
        resize_bilinear_pixels_mode = params.get(
            "resize_bilinear_pixels_mode",
            ResizeBilinearPixelsMode.ALIGN_CORNERS,
        )
        if "activation" in params:
            raise AccelerasUnsupportedError("Acceleras currently doesn't support resize with data activation")
        compilation_params = hn_element["compilation_params"]
        layer = cls(
            name=lname,
            h_ratios=h_ratios,
            w_ratios=w_ratios,
            f_ratios=f_ratios,
            resize_bilinear_pixels_mode=resize_bilinear_pixels_mode,
            compilation_params=compilation_params,
            logger=logger,
        )
        layer._validate_ratios()
        layer.finalize_from_hn(hn_element)
        return layer

    def get_equalization_handler_type(self, predecessor_index=None):
        """Return the equalization classification for this layer.

        The layer is ``unsupported`` when any features ratio differs from 1
        and ``transparent`` otherwise; it is never a source.
        ``predecessor_index`` is accepted for interface compatibility and is unused.
        """
        resize_op = self.atomic_op
        if (np.array(resize_op._f_ratios) != 1).any():
            handler = LayerHandlerType.unsupported
        else:
            handler = LayerHandlerType.transparent
        return EquivClassification(handler, is_source=False)

    @property
    def atomic_ops(self):
        """List containing the layer's single atomic op."""
        return [self.atomic_op]

    def create_hw_params(
        self, weights_clipping: LayerWeightsClippingConfig, optimization_target: OptimizationTarget, hw_shifts=None
    ):
        """Delegate hw-param creation to the atomic op.

        ``weights_clipping``, ``optimization_target`` and ``hw_shifts`` are
        accepted for interface compatibility and are not used here.
        """
        self.atomic_op.create_hw_params()

    def enforce_io_encoding(self, training=False, **kwargs):
        """Delegate encoding enforcement to the atomic op (``training``/``kwargs`` unused)."""
        self.atomic_op.enforce_encoding()

    def _export_weights(self):
        """Return an empty dict; this layer exports no weights."""
        return {}

    def is_jit_compile_supported(self, training=False):
        """jit_compile is not supported because tf.compat.v1.resize grads do not work with jit_compile"""
        return False
