from typing import Union

import numpy as np

from hailo_model_optimization.acceleras.atomic_ops.activation_op import ActivationOp
from hailo_model_optimization.acceleras.atomic_ops.bias_add_op import AddBiasOp
from hailo_model_optimization.acceleras.atomic_ops.passthru_op import PassthruOp
from hailo_model_optimization.acceleras.atomic_ops.resize_bilinear_mac_op import ResizeBilinearMacOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.layer_flow import LayerFlow
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerPrecisionConfig,
    LayerWeightsClippingConfig,
)
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    BiasMode,
    DataPath,
    EquivClassification,
    LayerHandlerType,
    LayerType,
    OptimizationTarget,
    PrecisionMode,
    ResizeBilinearPixelsMode,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import (
    AccelerasInitializationError,
    AccelerasUnsupportedError,
)
from hailo_model_optimization.acceleras.utils.hn_npz_utils import LayerParams
from hailo_model_optimization.acceleras.utils.opt_utils import (
    get_decomposition_count_by_bias_mode,
    get_kernel_bits_and_sign_by_precision_mode,
)


# TODO: Verify the qnpz export for this layer.
class HailoResizeBilinearMac(BaseHailoLayer):
    """Represents a bilinear `resize` layer in the hn.

    The layer chains four atomic ops (see ``_build_flow``):
    ``ResizeBilinearMacOp`` -> ``AddBiasOp`` -> ``ActivationOp`` (linear) -> ``PassthruOp``.
    When built through ``from_hn``, hn elements that declare an activation are rejected, and the
    feature ratio must be a single value equal to 1 (no broadcast over features).
    """

    _hn_type = LayerType.RESIZE

    SUPPORTED_PRECISION_MODE = {
        PrecisionMode.a8_w8,
        PrecisionMode.a16_w16,
        PrecisionMode.a8_w8_a8,
        PrecisionMode.a8_w8_a16,
        PrecisionMode.a16_w16_a16,
    }
    SUPPORTED_BIAS_MODE = {
        BiasMode.double_scale_initialization,
    }
    SUPPORTED_QUANTIZATION_GROUPS = False
    _DEFAULT_F_RATIO = 1.0

    def __init__(
        self,
        name: str,
        h_ratios: list = None,
        w_ratios: list = None,
        f_ratios: list = None,
        resize_bilinear_pixels_mode: Union[str, ResizeBilinearPixelsMode] = "disabled",
        compilation_params: dict = None,
        bias_initializer=None,
        logger=None,
        **kwargs,
    ):
        """Create the layer and its atomic ops.

        Args:
            name: Layer name; used as the prefix of every atomic op name.
            h_ratios: Height resize ratios. ``None`` is treated as an empty list.
            w_ratios: Width resize ratios. ``None`` is treated as an empty list.
            f_ratios: Feature resize ratios. ``None`` is treated as an empty list.
            resize_bilinear_pixels_mode: Pixel mode forwarded to ``ResizeBilinearMacOp``.
            compilation_params: Compilation parameters forwarded to ``ResizeBilinearMacOp``.
                ``None`` is treated as a fresh, empty dict.
            bias_initializer: Initializer for the non-trainable, non-correctable bias add op.
            logger: Logger passed to every atomic op and to the base layer.
            **kwargs: Forwarded to ``BaseHailoLayer.__init__``.
        """
        h_ratios = [] if h_ratios is None else h_ratios
        w_ratios = [] if w_ratios is None else w_ratios
        f_ratios = [] if f_ratios is None else f_ratios
        # Allocate a new dict per instance: a `{}` default would be a single object shared (and
        # potentially mutated) by every layer constructed without explicit compilation params.
        compilation_params = {} if compilation_params is None else compilation_params

        self.bias_add_op = AddBiasOp(
            f"{name}/bias_add_op",
            bias_initializer,
            trainable=False,
            logger=logger,
            is_correctable=False,
        )
        self.act_op = ActivationOp(f"{name}/act_op", activation="linear", logger=logger)
        self.output_op = PassthruOp(f"{name}/passthru_op", logger=logger)
        self.resize_bilinear_mac_op = ResizeBilinearMacOp(
            f"{name}/resize_bilinear_mac_op",
            w_ratios=w_ratios,
            h_ratios=h_ratios,
            f_ratios=f_ratios,
            resize_bilinear_pixels_mode=resize_bilinear_pixels_mode,
            compilation_params=compilation_params,
            logger=logger,
        )
        # The atomic ops are created before the base initializer runs.
        # NOTE(review): presumably BaseHailoLayer.__init__ relies on them (e.g. via _build_flow) — confirm.
        super().__init__(name=name, logger=logger, **kwargs)

    def _validate_ratios(self):
        """Check that the ratios stored on the resize op are mutually consistent.

        Raises:
            AccelerasInitializationError: If the numbers of height and width ratios differ, if the
                number of ratios differs from the length of ``compilation_params["hw_layer_type_list"]``,
                if there is not exactly one feature ratio, or if that feature ratio is not 1.
        """
        resize_op = self.resize_bilinear_mac_op
        num_h_ratios = len(resize_op._h_ratios)
        num_w_ratios = len(resize_op._w_ratios)
        if num_h_ratios != num_w_ratios:
            raise AccelerasInitializationError(
                f"Different number of height and width ratios (# of heights={num_h_ratios}, # of widths={num_w_ratios})",
            )

        num_hw_layer_types = len(resize_op._compilation_params["hw_layer_type_list"])
        if num_h_ratios != num_hw_layer_types:
            raise AccelerasInitializationError(
                f"Different number of ratios and hw_layer_types (# of ratios={num_h_ratios}, # of hw_layer_types={num_hw_layer_types})",
            )

        if len(resize_op._f_ratios) != 1:
            raise AccelerasInitializationError(
                f"Resize op {resize_op.full_name} only supports one value for features broadcast, got {resize_op._f_ratios}",
            )

        f_ratio = resize_op._f_ratios[0]
        if f_ratio != 1:
            raise AccelerasInitializationError(
                f"Resize bilinear {resize_op.full_name} does not support broadcast over features, ratio={f_ratio}",
            )

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """Build and validate the layer from an hn element.

        ``hn_element["params"]`` must contain ``resize_h_ratio_list`` and ``resize_w_ratio_list``.
        ``resize_f_ratio_list`` defaults to ``[_DEFAULT_F_RATIO]``, and ``resize_bilinear_pixels_mode``
        defaults to ``ResizeBilinearPixelsMode.ALIGN_CORNERS``. The compilation params are read from
        ``hn_element["compilation_params"]``.

        Raises:
            AccelerasUnsupportedError: If the hn params declare an activation.
            AccelerasInitializationError: If the ratios fail ``_validate_ratios``.
        """
        params = hn_element["params"]
        h_ratios = params["resize_h_ratio_list"]
        w_ratios = params["resize_w_ratio_list"]
        f_ratios = params.get("resize_f_ratio_list", [cls._DEFAULT_F_RATIO])
        resize_bilinear_pixels_mode = params.get(
            "resize_bilinear_pixels_mode",
            ResizeBilinearPixelsMode.ALIGN_CORNERS,
        )
        if "activation" in params:
            raise AccelerasUnsupportedError("Acceleras currently doesn't support resize with data activation")
        compilation_params = hn_element["compilation_params"]
        layer = cls(
            name=lname,
            h_ratios=h_ratios,
            w_ratios=w_ratios,
            f_ratios=f_ratios,
            resize_bilinear_pixels_mode=resize_bilinear_pixels_mode,
            compilation_params=compilation_params,
            logger=logger,
        )
        layer._validate_ratios()
        layer.finalize_from_hn(hn_element)
        return layer

    def get_equalization_handler_type(self, predecessor_index=None):
        """Return the equalization classification of this layer (never a source).

        The handler is ``unsupported`` when any feature ratio differs from 1 and ``transparent`` otherwise.
        ``predecessor_index`` is not used.
        """
        resize_op = self.resize_bilinear_mac_op
        if (np.array(resize_op._f_ratios) != 1).any():
            handler = LayerHandlerType.unsupported
        else:
            handler = LayerHandlerType.transparent
        return EquivClassification(handler, is_source=False)

    @property
    def atomic_ops(self):
        """The layer's atomic ops, in data-flow order."""
        return [self.resize_bilinear_mac_op, self.bias_add_op, self.act_op, self.output_op]

    def create_quant_element_custom_behavior(
        self,
        precision_config: LayerPrecisionConfig,
        optimization_target: OptimizationTarget,
    ):
        """Create the weight quantization elements of the atomic ops.

        Kernel bits and signedness are derived from the precision mode. The bias decomposition
        count is derived from the bias mode.
        """
        bias_mode = precision_config.bias_mode
        precision_mode = precision_config.precision_mode
        kernel_bits, signed = get_kernel_bits_and_sign_by_precision_mode(precision_mode)
        num_decomposition = get_decomposition_count_by_bias_mode(bias_mode)
        self.resize_bilinear_mac_op.create_weight_quant_element(kernel_bits, signed)
        self.bias_add_op.create_weight_quant_element(kernel_bits, signed, num_decomposition)
        self.act_op.create_weight_quant_element(optimization_target)

    def create_hw_params(
        self, weights_clipping: LayerWeightsClippingConfig, optimization_target: OptimizationTarget, hw_shifts=None
    ):
        """Create the hardware parameters of the atomic ops.

        ``weights_clipping`` and ``hw_shifts`` are accepted for interface compatibility but are not used.
        The steps run in this order:
        1. Pull the output encoding back onto the activation op.
        2. Create the resize op's hw params and enforce its encoding.
        3. Copy its pre-accumulator shift to the bias-add op.
        4. Create the activation hw params from the resize op's output scale, without nudging.
        5. Re-enforce the internal encodings.
        6. Create the bias-add hw params.
        """
        self._enforce_output_encoding()
        self.resize_bilinear_mac_op.create_hw_params()
        self.resize_bilinear_mac_op.enforce_encoding()
        self.bias_add_op.pre_acc_shift = self.resize_bilinear_mac_op.pre_acc_shift
        self.act_op.create_hw_params(self.resize_bilinear_mac_op.output_scale, optimization_target, nudging=False)
        self.enforce_internal_encoding()
        self.bias_add_op.create_hw_params()

    def _export_weights(self):
        """Return the exported weights; only the bias is exported, under the ``"bias"`` key."""
        return {"bias": self.bias_add_op.export_weights()}

    def _enforce_output_encoding(self):
        """Propagate the output op's input encoding (index 0) back onto the activation op's output."""
        self.output_op.backward_encoding()
        self.act_op.output_scale = self.output_op.input_scales[0]
        self.act_op.output_zero_point = self.output_op.input_zero_points[0]

    def enforce_internal_encoding(self, training=False, **kwargs):
        """Forward-path computation of the encoding enforcement.

        The output encoding is first pulled back from the output op. The resize, bias-add and
        activation ops then have their encodings enforced in their natural order, with each op's
        input encoding taken from its predecessor's output. The bias-add output scale is set equal
        to its input scale. Because the layer's output scale is enforced to equal its input scale
        (see ``enforce_io_encoding``), a single sequential pass suffices.
        """
        self._enforce_output_encoding()
        self.resize_bilinear_mac_op.enforce_encoding()
        self.bias_add_op.input_scales = [self.resize_bilinear_mac_op.output_scale]
        self.bias_add_op.output_scale = self.bias_add_op.input_scales[0]
        self.bias_add_op.input_zero_points = [self.resize_bilinear_mac_op.output_zero_point]
        self.bias_add_op.enforce_encoding()
        self.act_op.input_scales = [self.bias_add_op.output_scale]
        self.act_op.input_zero_points = [self.bias_add_op.output_zero_point]
        self.act_op.enforce_encoding(training=training)

    def fast_enforce_internal_encoding(self, **kwargs):
        """No-op for this layer."""
        pass

    def enforce_io_encoding(self, training=False, **kwargs):
        """Set the layer's output scale to its input scale.

        Presumably this applies to output index 0 — confirm against ``set_output_scale``.
        """
        self.set_output_scale(self.input_scale, 0)

    def _build_flow(self) -> LayerFlow:
        """Wire input -> resize -> bias add -> activation -> passthru -> output."""
        layer_flow = LayerFlow()

        in1 = layer_flow.add_input()
        out1 = layer_flow.add_output()

        layer_flow.add_node(self.resize_bilinear_mac_op)
        layer_flow.add_node(self.bias_add_op)
        layer_flow.add_node(self.act_op)
        layer_flow.add_node(self.output_op)

        layer_flow.add_edge(in1, self.resize_bilinear_mac_op, DataPath.LAYER_IN)
        layer_flow.add_edge(self.resize_bilinear_mac_op, self.bias_add_op, DataPath.ACCUMULATOR)
        layer_flow.add_edge(self.bias_add_op, self.act_op, DataPath.ACCUMULATOR)
        layer_flow.add_edge(self.act_op, self.output_op, DataPath.LAYER_OUT)
        layer_flow.add_edge(self.output_op, out1, DataPath.LAYER_OUT)
        return layer_flow

    def _is_precision_config_supported(self, precision_mode, bias_mode, arch):
        """Return True iff both the precision mode and the bias mode are supported (``arch`` is ignored)."""
        return precision_mode in self.SUPPORTED_PRECISION_MODE and bias_mode in self.SUPPORTED_BIAS_MODE

    @classmethod
    def get_default_bias_mode(cls):
        """Return the default bias mode, ``BiasMode.double_scale_initialization``."""
        return BiasMode.double_scale_initialization

    def _load_activation(self, layer_params):
        """Import the activation op's weights from ``layer_params``."""
        self.act_op.import_weights(layer_params)

    # TODO: check this thing
    def import_weights(self, layer_params: LayerParams):
        """Import the layer weights.

        The activation op is always loaded with an empty params dict; nothing is read from
        ``layer_params`` for it. Only ``"bias"`` is read from ``layer_params``, and only when present.
        """
        self._load_activation(layer_params={})
        if "bias" in layer_params.keys():
            self.bias_add_op.import_weights(bias=layer_params["bias"])

    def is_jit_compile_supported(self, training=False):
        """jit_compile is not supported because tf.compat.v1.resize grads do not work with jit_compile"""
        return False
