import inspect
from typing import Union

import numpy as np

from hailo_model_optimization.acceleras.atomic_ops.activation_op import ActivationOp
from hailo_model_optimization.acceleras.atomic_ops.crosscorrelation_dw_op import CrossCorrelationDWOp
from hailo_model_optimization.acceleras.atomic_ops.passthru_op import PassthruOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.layer_flow import LayerFlow
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerPrecisionConfig,
    LayerWeightsClippingConfig,
)
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ActivationType,
    BiasMode,
    DataPath,
    EquivClassification,
    LayerHandlerType,
    LayerType,
    OptimizationTarget,
    PaddingType,
    PrecisionMode,
    StrideAlignType,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasNumerizationError
from hailo_model_optimization.acceleras.utils.hn_npz_utils import LayerParams
from hailo_model_optimization.acceleras.utils.opt_utils import get_kernel_bits_and_sign_by_precision_mode


class HailoCrossCorrelationDW(BaseHailoLayer):
    """
    Implement Hailo cross-correlation depth-wise layer,
        - takes two inputs,
        - the mac behaves as passthru + zp compensation
        - multiply the inputs in the APU
        - activation in the APU

    Internally the layer is a chain of three atomic ops (see ``_build_flow``):
    ``CrossCorrelationDWOp`` (consumes both layer inputs; the second one is wired on the
    ``LAYER_IN_WEIGHTS`` data path) -> ``ActivationOp`` -> ``PassthruOp``.

    The layer output scale is tied to the product of the two input scales by a scalar degree of
    freedom: ``forced_output_scale_scalar_dof`` when it is set (not ``None``), otherwise
    ``output_scale_scalar_dof``.
    """

    _hn_type = LayerType.DW

    SUPPORTED_PRECISION_MODE = {
        PrecisionMode.a8_w8,
        PrecisionMode.a8_w8_a8,
        PrecisionMode.a8_w8_a16,
    }
    SUPPORTED_BIAS_MODE = {
        BiasMode.double_scale_initialization,
        BiasMode.single_scale_decomposition,
        BiasMode.double_scale_decomposition,
    }
    SUPPORTED_QUANTIZATION_GROUPS = False

    def __init__(
        self,
        name: str,
        strides=(1, 1, 1, 1),
        dilations=(1, 1, 1, 1),
        padding: Union[str, PaddingType] = "VALID",
        stride_align: Union[str, StrideAlignType] = "NW",
        activation: Union[str, callable, ActivationType] = ActivationType.LINEAR,
        logger=None,
        **kwargs,
    ):
        """
        Create the layer and its atomic ops.

        Args:
            name: layer name; also used as the prefix of the atomic ops' names.
            strides, dilations, padding, stride_align: forwarded to ``CrossCorrelationDWOp``.
            activation: forwarded to ``ActivationOp``.
            logger: forwarded to every atomic op and to the base layer.
            **kwargs: forwarded to ``BaseHailoLayer.__init__``.
        """
        # The atomic ops are created before super().__init__ — presumably because the base
        # initializer builds the layer flow (``_build_flow``) from them; TODO confirm.

        # Cross-correlation depth-wise atomic operation
        self.crosscorrelation_dw_op = CrossCorrelationDWOp(
            f"{name}/crosscorrelation_dw_op",
            strides=strides,
            stride_align=stride_align,
            dilations=dilations,
            padding=padding,
            logger=logger,
        )

        # Activation atomic operation
        self.act_op = ActivationOp(
            f"{name}/act_op",
            activation=activation,
            logger=logger,
        )

        # Output atomic operation
        # enabling output quantization even as activation is fully native...
        self.output_op = PassthruOp(f"{name}/passthru_op", logger=logger)

        super().__init__(name=name, logger=logger, **kwargs)

        # APU
        self.output_scale_scalar_dof = 1.0  # initialize dof for the APU (float)
        self._forced_output_scale_scalar_dof = None  # user-forced dof; None means "not forced"

    @property
    def forced_output_scale_scalar_dof(self):
        """User-forced output/input scale ratio, or ``None`` when not forced."""
        return self._forced_output_scale_scalar_dof

    @forced_output_scale_scalar_dof.setter
    def forced_output_scale_scalar_dof(self, forced_output_scale_scalar_dof):
        self._forced_output_scale_scalar_dof = forced_output_scale_scalar_dof

    @property
    def pre_acc_shift(self):
        """Pre-accumulator shift of the underlying cross-correlation op."""
        return self.crosscorrelation_dw_op.pre_acc_shift

    def _build_flow(self) -> LayerFlow:
        """
        Wire the layer graph: two inputs -> cross-correlation -> activation -> passthru -> output.

        The first layer input feeds the op on ``LAYER_IN`` (index 0), and the second one on
        ``LAYER_IN_WEIGHTS`` (index 1).
        """
        layer_flow = LayerFlow()
        in1 = layer_flow.add_input()
        in2 = layer_flow.add_input()
        out1 = layer_flow.add_output()

        layer_flow.add_node(self.crosscorrelation_dw_op)
        layer_flow.add_node(self.act_op)
        layer_flow.add_node(self.output_op)

        layer_flow.add_edge(in1, self.crosscorrelation_dw_op, DataPath.LAYER_IN, input_index=0)
        layer_flow.add_edge(in2, self.crosscorrelation_dw_op, DataPath.LAYER_IN_WEIGHTS, input_index=1)
        layer_flow.add_edge(self.crosscorrelation_dw_op, self.act_op, DataPath.ACCUMULATOR)
        layer_flow.add_edge(self.act_op, self.output_op, DataPath.LAYER_OUT)
        layer_flow.add_edge(self.output_op, out1, DataPath.LAYER_OUT)

        return layer_flow

    def _enforce_output_encoding(self):
        """Propagate the output op's encoding backward to the activation op's output."""
        self.output_op.backward_encoding()
        self.act_op.output_scale = self.output_op.input_scales[0]
        self.act_op.output_zero_point = self.output_op.input_zero_points[0]

    def enforce_internal_encoding(self, training=False, **kwargs):
        """
        Assumes that we already have the inputs & output, and calculates all the HW params in the middle.

        Args:
            training: forwarded to ``ActivationOp.enforce_encoding``.
        """
        # forward -->
        self.crosscorrelation_dw_op.enforce_encoding()
        self.act_op.input_scales = [self.crosscorrelation_dw_op.output_scale]
        self.act_op.input_zero_points = [self.crosscorrelation_dw_op.output_zero_point]
        # backward <--
        self._enforce_output_encoding()
        # Setting the APU (activation sits between the accumulator and the output encodings)
        self.act_op.enforce_encoding(training=training)

    def fast_enforce_internal_encoding(self, **kwargs):
        """Intentionally a no-op for this layer."""
        pass

    def import_weights(self, layer_params: LayerParams):
        """
        Load parameters into the layer.

        Forwards ``layer_params`` to both the cross-correlation op and the activation op.

        Args:
            layer_params: layer's params from the npz

        """
        self.crosscorrelation_dw_op.import_weights(layer_params)
        self.act_op.import_weights(layer_params)

    def _export_weights(self):
        """Export weights; only the activation op's weights are exported."""
        return self.act_op.export_weights()

    def _export_layer_metadata(self):
        """Extend the base metadata with the forced output dof, when it is set."""
        export_vals = super()._export_layer_metadata()
        if self.forced_output_scale_scalar_dof is not None:
            export_vals["forced_output_scale_scalar_dof"] = self.forced_output_scale_scalar_dof
        return export_vals

    def _import_layer_metadata(self, npz):
        """Restore the forced output dof (``None`` when absent), then defer to the base class."""
        self.forced_output_scale_scalar_dof = npz.get("forced_output_scale_scalar_dof", None)
        return super()._import_layer_metadata(npz)

    @classmethod
    def get_default_params(cls):
        """
        Automatically get the method's default arguments (via __init__). It saves the trouble
        (and affiliate bugs) of forgetting to update this function every time one changes something
        in __init__.
        """
        signature = inspect.signature(cls.__init__)
        return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """
        Build a layer from an hn dict that is taken from an hn file.

        Values missing from ``hn_element["params"]`` fall back to the ``__init__`` defaults.
        """
        params = cls.get_default_params()
        params.update(hn_element.get("params", dict()))

        layer = cls(
            name=lname,
            strides=params["strides"],
            padding=params["padding"],
            dilations=params["dilations"],
            stride_align=params["stride_align"],
            activation=params["activation"],
            logger=logger,
        )
        layer.finalize_from_hn(hn_element)
        return layer

    def _force_output_scale(self):
        """
        Set the output scale to (input_scales[0] * input_scales[1]) * forced_output_scale_scalar_dof,
        when a forced dof has been set.
        """
        # NOTE(review): ``shape`` is a tuple, so ``shape != 0`` is always True for array-like scales.
        # The intended guard (``shape != ()``? non-empty size?) is unclear; behavior is left as-is.
        if self.forced_output_scale_scalar_dof is not None and self.output_scale.shape != 0:
            self.set_output_scale(self.input_scales[0] * self.input_scales[1] * self.forced_output_scale_scalar_dof, 0)

    def _create_out_in_scale_ratio(self):
        """
        Create the output_scale_scalar_dof. This is a degree of freedom between the layer's output and its inputs;
        because this layer has 2 inputs, the ratio is taken between the layer's output scale and the product of
        the two input scales.

        Raises:
            AccelerasNumerizationError: if the ratio is a numpy vector whose elements differ from the first
                element by a relative amount larger than 1e-6 (i.e. it is not effectively a scalar).
        """
        out_in_scale_ratio = self.output_scale / (self.input_scales[0] * self.input_scales[1])  # APU dof

        eps = 1e-6
        if out_in_scale_ratio.shape != ():
            if isinstance(out_in_scale_ratio, np.ndarray) and eps < np.max(
                np.abs(out_in_scale_ratio - out_in_scale_ratio[0]) / out_in_scale_ratio[0],
            ):
                # Possible fail case: coming from concat, so input scale is scalar while output is vector..
                raise AccelerasNumerizationError(
                    f"output_scale - input_scale ratio of {self.full_name} should be a scalar"
                )
            # create attribute to be used in scales-training context should it come
            out_in_scale_ratio = out_in_scale_ratio[0]
        self.output_scale_scalar_dof = out_in_scale_ratio

    def create_hw_params(
        self, weights_clipping: LayerWeightsClippingConfig, optimization_target: OptimizationTarget, hw_shifts=None
    ):
        """
        Compute the HW params of the cross-correlation and activation ops, then re-enforce internal encodings.

        Args:
            weights_clipping: accepted for interface compatibility; not used by this layer.
            optimization_target: forwarded to ``ActivationOp.create_hw_params``.
            hw_shifts: forwarded to ``CrossCorrelationDWOp.create_hw_params``.
        """
        self._enforce_output_encoding()
        self.crosscorrelation_dw_op.create_hw_params(preact_limvals=self.get_preact_limvals()[0], hw_shifts=hw_shifts)
        self.act_op.create_hw_params(
            self.crosscorrelation_dw_op.output_scale,
            optimization_target,
            nudging=False,
        )
        self.enforce_internal_encoding()

    def enforce_io_encoding(self, training=False, **kwargs):
        """
        Enforce encoding between layer inputs and layer output, after the APU (=activation).

        Sets the output op's scale to ``input_scales[0] * input_scales[1] * dof``. Here ``dof`` is
        ``forced_output_scale_scalar_dof`` when it is set (not ``None``), otherwise
        ``output_scale_scalar_dof``.
        """
        # ``None`` is the "not forced" sentinel (see _export_layer_metadata / _force_output_scale);
        # a truthiness test would wrongly ignore a forced 0 and fail on multi-element arrays.
        dof = self.forced_output_scale_scalar_dof
        if dof is None:
            dof = self.output_scale_scalar_dof
        self.output_op.output_scale = self.input_scales[0] * self.input_scales[1] * dof

    def create_quant_element_custom_behavior(
        self,
        precision_config: LayerPrecisionConfig,
        optimization_target: OptimizationTarget,
    ):
        """Create the MAC-data quant element (bit-width from the precision mode) and the activation weights quant element."""
        mac_data_bits, _ = get_kernel_bits_and_sign_by_precision_mode(precision_config.precision_mode)
        self.create_quant_element_by_data_path(DataPath.MAC_DATA, mac_data_bits)
        self.act_op.create_weight_quant_element(optimization_target)

    def get_equalization_handler_type(self, predecessor_index=None):
        """Equalization is not supported for this layer."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    @classmethod
    def get_default_bias_mode(cls):
        """Default bias mode: ``BiasMode.double_scale_initialization``."""
        return BiasMode.double_scale_initialization