from typing import Union

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops._misc_internals import get_tf_same_padding
from hailo_model_optimization.acceleras.atomic_ops.activation_op import ActivationOp
from hailo_model_optimization.acceleras.atomic_ops.avgpool_op import AvgPoolOp
from hailo_model_optimization.acceleras.atomic_ops.bias_add_op import AddBiasOp
from hailo_model_optimization.acceleras.atomic_ops.passthru_op import PassthruOp
from hailo_model_optimization.acceleras.encoding.encoding_data import EncodingType
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.layer_flow import LayerFlow
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerPrecisionConfig,
    LayerWeightsClippingConfig,
)
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ActivationType,
    BiasMode,
    DataPath,
    EquivClassification,
    LayerHandlerType,
    LayerType,
    OptimizationTarget,
    PaddingType,
    PrecisionMode,
    StrideAlignType,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import (
    AccelerasImplementationError,
    AccelerasNumerizationError,
    AccelerasValueError,
    InvalidInputShape,
)
from hailo_model_optimization.acceleras.utils.hn_npz_utils import LayerParams, get_hn_padding
from hailo_model_optimization.acceleras.utils.opt_utils import (
    get_decomposition_count_by_bias_mode,
    get_kernel_bits_and_sign_by_precision_mode,
)


class HailoAvgPool(BaseHailoLayer):
    """
    Hailo's avgpool layer. Based on AvgPoolOp as core operation.

    The layer chains four atomic ops (see ``_build_flow``):
    AvgPoolOp -> AddBiasOp -> ActivationOp -> PassthruOp.
    """

    # Precision modes this layer declares as supported.
    SUPPORTED_PRECISION_MODE = {
        PrecisionMode.a8_w8,
        PrecisionMode.a8_w4,
        PrecisionMode.a16_w16,
        PrecisionMode.a8_w4_a8,
        PrecisionMode.a8_w4_a16,
        PrecisionMode.a8_w8_a8,
        PrecisionMode.a8_w8_a16,
        PrecisionMode.a16_w16_a16,
    }
    # Bias modes this layer declares as supported.
    SUPPORTED_BIAS_MODE = {
        BiasMode.double_scale_initialization,
        BiasMode.double_scale_decomposition,
        BiasMode.single_scale_decomposition,
    }
    # Quantization groups are not supported; create_hw_params raises when more than one group is configured.
    SUPPORTED_QUANTIZATION_GROUPS = False
    # Annotation only (no value assigned here). NOTE(review): presumably the backing field of the
    # ``output_scale_scalar_dof`` attribute used throughout this class - confirm against BaseHailoLayer.
    _output_scale_scalar_dof: float
    # HN layer type associated with this class.
    _hn_type = LayerType.AVGPOOL

    def __init__(
        self,
        name: str,
        pool_size=(2, 2),
        strides=None,
        padding: Union[str, PaddingType] = "VALID",
        stride_align: Union[str, StrideAlignType] = "NW",
        activation: Union[str, callable, ActivationType] = "linear",
        is_max_width_avgpool=False,
        bias_initializer=None,
        trainable=False,
        logger=None,
        **kwargs,
    ):
        """
        Create the layer's atomic ops and initialize the base layer.

        Args:
            name: layer name; also used as the prefix of every atomic op name.
            pool_size: pooling window, forwarded to AvgPoolOp as ``kernel_size``.
            strides: pooling strides; any falsy value (e.g. None) falls back to ``pool_size``.
            padding: padding mode, forwarded to AvgPoolOp.
            stride_align: stride alignment, forwarded to AvgPoolOp.
            activation: activation, forwarded to ActivationOp.
            is_max_width_avgpool: stored on the layer; when True the bias op is marked as not correctable.
            bias_initializer: forwarded to AddBiasOp.
            trainable: forwarded to AddBiasOp.
            logger: forwarded to every atomic op and to the base layer constructor.
            **kwargs: forwarded to the base layer constructor.
        """
        # Default the strides to the pooling window when none were given.
        strides = strides or pool_size
        avgpool_op = AvgPoolOp(
            f"{name}/avgpool_op",
            kernel_size=pool_size,
            strides=strides,
            padding=padding,
            stride_align=stride_align,
            logger=logger,
        )

        # The atomic ops are assigned before the base constructor runs.
        # NOTE(review): presumably the base __init__ relies on them (e.g. through _build_flow) - confirm.
        self.avgpool_op = avgpool_op
        self.bias_add_op = AddBiasOp(f"{name}/bias_add_op", bias_initializer, trainable=trainable, logger=logger)
        self.act_op = ActivationOp(f"{name}/act_op", activation, logger=logger)
        self.output_op = PassthruOp(f"{name}/passthru_op", logger=logger)
        super().__init__(name=name, logger=logger, **kwargs)

        # Only rank-4 inputs are accepted.
        self.input_spec = tf.keras.layers.InputSpec(ndim=4)
        # Initial output/input scale ratio; recomputed by _create_out_in_scale_ratio / update_encoding.
        self.output_scale_scalar_dof = 1
        self._is_max_width_avgpool = is_max_width_avgpool
        if is_max_width_avgpool:
            # Both global avgpool and tiled avgpool are implemented in the HW using a scalar for bias
            self.bias_add_op.is_correctable = False

        self.encoding_const = False

    @property
    def pre_acc_shift(self):
        return self.avgpool_op.pre_acc_shift

    def _build_flow(self) -> LayerFlow:
        """
        Build the layer's internal data flow.

        The flow is a single chain:
        input -> avgpool_op -> bias_add_op -> act_op -> output_op -> output.
        Nodes are registered first, then the edges, in chain order.
        """
        flow = LayerFlow()

        source = flow.add_input()
        sink = flow.add_output()

        for op in (self.avgpool_op, self.bias_add_op, self.act_op, self.output_op):
            flow.add_node(op)

        # (producer, consumer, data path carried by the edge)
        chain = (
            (source, self.avgpool_op, DataPath.LAYER_IN),
            (self.avgpool_op, self.bias_add_op, DataPath.ACCUMULATOR),
            (self.bias_add_op, self.act_op, DataPath.ACCUMULATOR),
            (self.act_op, self.output_op, DataPath.LAYER_OUT),
            (self.output_op, sink, DataPath.LAYER_OUT),
        )
        for producer, consumer, data_path in chain:
            flow.add_edge(producer, consumer, data_path)
        return flow

    @classmethod
    def _validate_elwa(cls, elwa_value):
        if elwa_value:
            raise ValueError(
                f"elementwise_add value was {elwa_value}, "
                f"but expected {not elwa_value} in {cls.__name__} initialization",
            )

    def import_weights(self, layer_params: LayerParams):
        # avgpool doesn't really have weights, but since we implement it with conv for its core, it does.
        # TODO: do we want to treat avgpool as trainable layer or not?
        #       if yes - we need to load the params here. if no - leave it as is.
        #       but in the current stage - we can't load params after bias correction,
        #       because we allow modification of the avgpool bias.
        #       SDK-23644
        self.avgpool_op.import_weights(layer_params)
        if "bias" in layer_params.keys():
            # Might happened when loading weights after bias correction
            self.bias_add_op.import_weights(layer_params["bias"])
        self.act_op.import_weights(layer_params)

    def get_numeric_kernel_np(self):
        return self.avgpool_op.final_quantized_kernel.numpy() * (
            np.ones(self.avgpool_op.kernel_shape_prod * self.input_shape[-1])
        )

    @classmethod
    def get_default_params(cls):
        # TODO: this is temporary solution until we have pydantic scheme
        defaults = {
            "strides": [1, 1, 1, 1],
            "dilations": [1, 1, 1, 1],
            "padding": "SAME",
            "activation": "linear",
            "groups": 1,
            "elementwise_add": False,
        }
        return dict(defaults)

    @property
    def bias(self):
        return self.bias_add_op.bias

    def _enforce_output_encoding(self):
        self.output_op.backward_encoding()
        self.act_op.output_scale = self.output_op.input_scales[0]
        self.act_op.output_zero_point = self.output_op.input_zero_points[0]

    def create_quant_element_custom_behavior(
        self, precision_config: LayerPrecisionConfig, optimization_target: OptimizationTarget
    ):
        """
        Create the weight quantization elements of the atomic ops.

        Kernel bits and sign are derived from the precision mode, and the bias
        decomposition count from the bias mode. Both the avgpool and bias ops use
        the kernel bits and sign; the activation op is configured from the
        optimization target and receives the quantization groups.
        """
        bias_mode = precision_config.bias_mode
        precision_mode = precision_config.precision_mode
        groups = precision_config.quantization_groups

        weight_bits, is_signed = get_kernel_bits_and_sign_by_precision_mode(precision_mode)
        decomposition_count = get_decomposition_count_by_bias_mode(bias_mode)

        self.avgpool_op.create_weight_quant_element(weight_bits, is_signed)
        self.bias_add_op.create_weight_quant_element(weight_bits, is_signed, decomposition_count)
        self.act_op.create_weight_quant_element(optimization_target)

        # set quantization groups - we now dont support it but maybe in the future?
        self.act_op.set_quantization_groups(groups)

    def create_hw_params(
        self, weights_clipping: LayerWeightsClippingConfig, optimization_target: OptimizationTarget, hw_shifts=None
    ):
        """
        Create the hardware parameters of all atomic ops and mark the layer as having them.

        Args:
            weights_clipping: accepted as part of the signature; not read by this implementation.
            optimization_target: forwarded to the activation op.
            hw_shifts: forwarded to the avgpool op.

        Raises:
            AccelerasImplementationError: if the activation op has more than one quantization group.
        """
        if self.act_op.quantization_groups_num > 1:
            raise AccelerasImplementationError(
                f"For layer {self.full_name} we don't support qunatization with quantization groups yet",
            )

        # Create *candidates* for scales and *finalize* pre_acc_shift
        self._enforce_output_encoding()
        pre_act_stats = self.act_op.get_input_stats(0)
        # Largest absolute value among the min/max statistics at the activation input.
        max_final_accumulator_by_channel = np.maximum(np.abs(pre_act_stats.min), np.abs(pre_act_stats.max))

        self.avgpool_op.create_hw_params(max_final_accumulator_by_channel, hw_shifts=hw_shifts)

        # The bias op uses the same pre-accumulator shift as the avgpool op.
        self.bias_add_op.pre_acc_shift = self.avgpool_op.pre_acc_shift

        # From accumulator scale candidate, create the "ideal" output factor (*finalized*).
        self.act_op.create_hw_params(self.avgpool_op.output_scale, optimization_target, nudging=False)

        # Re-propagate the encodings through the ops before the bias op creates its own hw params.
        self.enforce_internal_encoding()
        self.bias_add_op.create_hw_params()
        self._has_hw_params = True

    def enforce_internal_encoding(self, training=False, **kwargs):
        """
        This is a forward path computation of the encoding enforcement.
        As we enforce that the output scale is equal to the input scale, we only need to sequentially enforce the
        encodings of the atomic ops in their natural order.
        Note that as we set the avgpool op weight to be an exact power of two, we don't have to set the apu encodings.

        Args:
            training: forwarded to the avgpool op and the activation op (the bias op is called without it).
            **kwargs: accepted but not used by this implementation.
        """
        # Pull the output encoding back into the activation op first.
        self._enforce_output_encoding()
        self.avgpool_op.enforce_encoding(training=training)
        # avgpool output -> bias op input; the bias op's output scale mirrors its own first input scale.
        self.bias_add_op.input_scale = self.avgpool_op.output_scale
        self.bias_add_op.output_scale = self.bias_add_op.input_scales[0]
        self.bias_add_op.input_zero_point = self.avgpool_op.output_zero_point
        self.bias_add_op.enforce_encoding()
        # bias op output -> activation op input.
        self.act_op.input_scale = self.bias_add_op.output_scale
        self.act_op.input_zero_point = self.bias_add_op.output_zero_point
        self.act_op.enforce_encoding(training=training)

    def fast_enforce_internal_encoding(self, **kwargs):
        pass

    def enforce_io_encoding(self, training=False, **kwargs):
        output_scale = self.input_scale * self.output_scale_scalar_dof  # ! rewrites output_scale_scalar_dof via setter
        self.output_op.output_scale = output_scale  # don't invoke output_scale_scalar_dof calculation

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """
        Alternate constructor: build the layer from an HN element.

        The element's ``params`` are merged over ``get_default_params()``. The
        layer counts as "max width" when the kernel width equals the width of the
        first input shape.

        Raises:
            ValueError: if ``elementwise_add`` is truthy or ``groups`` is not 1.
        """
        hn_params = cls.get_default_params()
        hn_params.update(hn_element.get("params", dict()))

        kernel_shape = hn_params["kernel_shape"]
        padding, stride_align = get_hn_padding(hn_params)
        spans_full_width = kernel_shape[2] == hn_element["input_shapes"][0][2]
        spatial_strides = hn_params["strides"][1:3]

        cls._validate_elwa(hn_params["elementwise_add"])

        if hn_params.get("groups", 1) != 1:
            raise ValueError("AvgPool doesn't support group param")

        init_args = {
            "name": lname,
            "pool_size": kernel_shape[1:3],
            "strides": spatial_strides,
            "padding": padding,
            "stride_align": stride_align,
            "activation": hn_params.get("activation", "linear"),
            "is_max_width_avgpool": spans_full_width,
            "logger": logger,
        }
        layer = cls(**init_args)
        layer.finalize_from_hn(hn_element)
        return layer

    def get_equalization_handler_type(self, predecessor_index=None):
        """
        Classify the layer for equalization.

        The layer is ``unsupported`` when the avgpool op pads with a non-zero
        constant, otherwise ``transparent``; it is never a source.
        ``predecessor_index`` is not used.
        """
        if self.avgpool_op.padding_const_value != 0:
            handler = LayerHandlerType.unsupported
        else:
            handler = LayerHandlerType.transparent
        return EquivClassification(handler, is_source=False)

    def _export_weights(self):
        # avgpool doesn't really have weights, but since we implement it with conv for its core, it does.
        self._logger.debug("Avgpool export_weights was triggered, but nothing happened")
        bias = self.bias_add_op.export_weights()
        if self._is_max_width_avgpool and np.any(bias != bias[0]):
            raise AccelerasValueError(f"Bias vector a in {self.full_name} has to be a scalar. received: {bias}")
        dict_params = {"bias": bias}
        avgpool_params = self.avgpool_op.export_weights()
        dict_params.update(avgpool_params)
        activation_params = self.act_op.export_weights()
        dict_params.update(activation_params)
        return dict_params

    def _create_out_in_scale_ratio(self):
        """
        create the output_scale_scalar_dof
        """
        out_in_scale_ratio = self.output_scale / self.input_scale
        eps = 1e-6
        if out_in_scale_ratio.shape != ():
            if isinstance(out_in_scale_ratio, np.ndarray) and eps < np.max(
                np.abs(out_in_scale_ratio - out_in_scale_ratio[0]) / out_in_scale_ratio[0],
            ):
                # Possible fail case: coming from concat, so input scale is scalar while output is vector..
                raise AccelerasNumerizationError(
                    f"output_scale - input_scale ratio of {self.full_name} should be a scalar"
                )
            # create attribute to be used in scales-training context should it come
            out_in_scale_ratio = out_in_scale_ratio[0]
        self.output_scale_scalar_dof = out_in_scale_ratio

    def get_kernel_np(self):
        return self.avgpool_op.final_quantized_kernel.numpy()

    def get_bias_np(self):
        return self.bias_add_op.bias.numpy()

    @property
    def output_factors(self):
        return self.act_op.output_factor_by_group

    def _get_bias_mode_supported_in_hw(self, arch):
        """
        Return the bias modes supported for ``arch``.

        For the MERCURY, SAGE and PLUTO targets, a max-width avgpool supports
        only ``single_scale_decomposition``; any other avgpool supports
        everything in ``SUPPORTED_BIAS_MODE`` except ``double_scale_decomposition``.
        All other targets defer to the base class.
        """
        if arch not in {OptimizationTarget.MERCURY, OptimizationTarget.SAGE, OptimizationTarget.PLUTO}:
            return super()._get_bias_mode_supported_in_hw(arch)
        if self._is_max_width_avgpool:
            return {BiasMode.single_scale_decomposition}
        return self.SUPPORTED_BIAS_MODE - {BiasMode.double_scale_decomposition}

    def _get_precision_mode_supported_in_hw(self, arch):
        """
        Return the precision modes supported for ``arch``.

        For the MERCURY, SAGE and PLUTO targets, a max-width avgpool is limited
        to a fixed set of 8-bit and 16-bit weight modes; any other avgpool gets
        ``SUPPORTED_PRECISION_MODE`` minus ``a16_w16_a8``. All other targets
        defer to the base class.
        """
        if arch not in {OptimizationTarget.MERCURY, OptimizationTarget.SAGE, OptimizationTarget.PLUTO}:
            return super()._get_precision_mode_supported_in_hw(arch)
        if not self._is_max_width_avgpool:
            return self.SUPPORTED_PRECISION_MODE - {PrecisionMode.a16_w16_a8}
        return {
            PrecisionMode.a8_w8,
            PrecisionMode.a16_w16,
            PrecisionMode.a8_w8_a8,
            PrecisionMode.a8_w8_a16,
            PrecisionMode.a16_w16_a16,
        }

    def _is_precision_config_supported(self, precision_mode, bias_mode, arch):
        """Return the base class verdict on the precision configuration, coerced to a plain bool."""
        return bool(super()._is_precision_config_supported(precision_mode, bias_mode, arch))

    def _get_kernel_bits(self):
        return self.avgpool_op.weight_lossy_elements.kernel.bits

    def define_encodings(self, flow):
        """
        Register the layer-specific encodings on top of the base class ones.

        Adds a shape-``()`` scale encoding for ``output_scale_scalar_dof`` and
        flags the activation op's ``output_factor_by_group`` encoding as scalar.
        """
        super().define_encodings(flow)

        dof_key = f"{self.full_name}/output_scale_scalar_dof:0"
        flow.add_encoding(dof_key, EncodingType.Scale, scalar=False, shape=())

        factor_key = f"{self.act_op.full_name}/output_factor_by_group:0"
        flow.nodes[factor_key]["encoding"].scalar = True

    def define_constraints(self, enc):
        """
        Register the layer-specific encoding constraints on top of the base class ones.

        Two constraints are added:
        * the avgpool and bias ops share the same ``mac_shift``, unless both ops are encoding-constant;
        * ``output_scale_scalar_dof`` is either pinned to its current value (when both the output op and the
          avgpool op are encoding-constant) or derived from the output op's output scale and the avgpool op's
          input scale.
        """
        super().define_constraints(enc)
        if not (self.avgpool_op.encoding_const and self.bias_add_op.encoding_const):
            enc.identity(f"{self.avgpool_op.full_name}/mac_shift:0", f"{self.bias_add_op.full_name}/mac_shift:0")
        if self.output_op.encoding_const and self.avgpool_op.encoding_const:
            enc.identity(f"{self.full_name}/output_scale_scalar_dof:0", self.output_scale_scalar_dof)
        else:
            # NOTE(review): presumably dummy(0) receives output_scale / input_scale - confirm the
            # argument order of enc.div.
            enc.div(
                enc.dummy(0), f"{self.output_op.full_name}/output_scale:0", f"{self.avgpool_op.full_name}/input_scale:0"
            )
            # Reduce the intermediate value to a scalar: keep the first element unless it is already shape ().
            enc.callback(
                f"{self.full_name}/output_scale_scalar_dof:0",
                enc.dummy(0),
                lambda x: x[0] if x.shape != () else x,
                outs_scalar=True,
                outs_shape=(),
            )

    def update_encoding(self, encodings):
        """
        Apply solved encodings to the layer.

        After the base class update, the avgpool op's ``pre_acc_shift`` is taken
        from its ``mac_shift`` entry and ``output_scale_scalar_dof`` from the
        layer's own entry.
        """
        super().update_encoding(encodings)

        shift_key = f"{self.avgpool_op.full_name}/mac_shift:0"
        self.avgpool_op.pre_acc_shift = encodings[shift_key]

        dof_key = f"{self.full_name}/output_scale_scalar_dof:0"
        self.output_scale_scalar_dof = encodings[dof_key]

    def verify_layer_inputs_shape(self, input_shapes):
        """
        Check that the first input shape yields at least one output row and column.

        For any padding other than VALID, the spatial dims are first enlarged by
        the pads that ``get_tf_same_padding`` returns.

        Raises:
            InvalidInputShape: if the floored output height or width is not positive.
        """
        kernel = self.avgpool_op.kernel_size
        stride = self.avgpool_op.strides
        height = input_shapes[0][1]
        width = input_shapes[0][2]

        if self.avgpool_op.padding != PaddingType.VALID:
            # pads = pad_beg_h, pad_end_h, pad_beg_w, pad_end_w
            pads = get_tf_same_padding(None, height, width, kernel[0], kernel[1], stride[0], stride[1])
            height = height + pads[0] + pads[1]
            width = width + pads[2] + pads[3]

        output_sizes = (
            (height - kernel[0]) / stride[0] + 1,
            (width - kernel[1]) / stride[1] + 1,
        )
        if any(np.floor(size) <= 0 for size in output_sizes):
            raise InvalidInputShape(
                f"Input shapes {input_shapes} doesn't match layer's parameters in {self.full_name}",
                self.full_name,
            )

    def is_global_avgpool(self):
        return self.input_shapes[0][1:3] == self.avgpool_op.kernel_shape[:2]

    def get_macs(self) -> int:
        _, high, widht, chanels_in = self.input_shape
        kh, kw, kc, n_kernels = self.avgpool_op.kernel_shape
        s_h, s_w = self.avgpool_op.strides

        macs = kh * kw * (high // s_h) * (widht // s_w) * chanels_in
        return macs
