import inspect
from typing import Callable, Union

import numpy as np

from hailo_model_optimization.acceleras.atomic_ops.activation_op import ActivationOp
from hailo_model_optimization.acceleras.atomic_ops.mock_conv_op import MockConvOp
from hailo_model_optimization.acceleras.atomic_ops.passthru_op import PassthruOp
from hailo_model_optimization.acceleras.atomic_ops.proposal_generator_op import ProposalGeneratoOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_nms import NMS_FORCED_BITS
from hailo_model_optimization.acceleras.hailo_layers.layer_flow import LayerFlow
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerPrecisionConfig,
    LayerWeightsClippingConfig,
)
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ActivationType,
    BiasMode,
    DataPath,
    EquivClassification,
    LayerHandlerType,
    LayerType,
    OptimizationTarget,
    PrecisionMode,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import (
    AccelerasImplementationError,
)
from hailo_model_optimization.acceleras.utils.hn_npz_utils import LayerParams
from hailo_model_optimization.acceleras.utils.opt_utils import get_kernel_bits_and_sign_by_precision_mode


class HailoProposalGenerator(BaseHailoLayer):
    """Acceleras layer for the ``LayerType.PROPOSAL_GENERATOR`` HN layer.

    The layer has two inputs and one output. Internally it chains four atomic ops
    (see ``_build_flow``)::

        in1, in2 -> proposal_op -> mock_conv -> act_op -> output_op -> out1

    The layer reports itself as neither trainable nor differentiable, and
    ``has_activation`` is deliberately ``False`` although an activation op is held.
    """

    SUPPORTED_PRECISION_MODE = {PrecisionMode.a16_w16, PrecisionMode.a8_w8_a16, PrecisionMode.a16_w16_a16}
    SUPPORTED_BIAS_MODE = {
        BiasMode.double_scale_initialization,
        BiasMode.double_scale_decomposition,
        BiasMode.single_scale_decomposition,
    }
    # NOTE(review): quantization groups are declared as supported here, yet create_hw_params
    # raises when act_op has more than one group - confirm this is the intended combination.
    SUPPORTED_QUANTIZATION_GROUPS = True

    _hn_type = LayerType.PROPOSAL_GENERATOR
    OP_NAME = "proposal_generator_op"

    def __init__(
        self,
        name: str,
        input_division_factor: int = 1,
        activation: Union[str, Callable, ActivationType] = ActivationType.LINEAR,
        logger=None,
        **kwargs,
    ):
        """Create the atomic ops of the layer and initialize the base layer.

        Args:
            name: Layer name; also used as the prefix of every atomic op name.
            input_division_factor: Forwarded as-is to ``ProposalGeneratoOp``.
            activation: Activation forwarded as-is to ``ActivationOp``.
            logger: Logger forwarded to every atomic op and to the base layer.
            **kwargs: Extra keyword arguments forwarded to ``BaseHailoLayer.__init__``.
        """
        # Proposal generator operation
        self.proposal_op = ProposalGeneratoOp(
            f"{name}/proposal_generator_op",
            input_division_factor=input_division_factor,
            logger=logger,
        )
        self.mock_conv = MockConvOp(f"{name}/mock_op", logger=logger)
        # Activation atomic operation
        self.act_op = ActivationOp(f"{name}/act_op", activation=activation, logger=logger)
        self.output_op = PassthruOp(f"{name}/passthru_op", logger=logger)
        # We don't want to change the scales over the proposal_op, so we set the scale scalar dof to 1.
        # The parameter output_scale_scalar_dof is initialized but not used (implicitly, to
        # enforce the same input and output scales for the proposal_op).
        self.output_scale_scalar_dof = 1
        super().__init__(name=name, logger=logger, **kwargs)

    def import_weights(self, layer_params: LayerParams):
        """Load the layer parameters into the proposal op and the activation op.

        Args:
            layer_params: Layer parameters. They are expanded into keyword arguments for
                ``proposal_op.import_weights`` and passed unchanged to ``act_op.import_weights``.
        """
        param_dict = dict(layer_params)
        self.proposal_op.import_weights(**param_dict)
        self.act_op.import_weights(layer_params)

    def _export_weights(self):
        """Return a new dict holding the weights exported by the activation op only."""
        params = {}
        params.update(self.act_op.export_weights())
        return params

    def _build_flow(self) -> LayerFlow:
        """Build the internal graph: two inputs -> proposal -> mock conv -> activation -> passthru -> output."""
        layer_flow = LayerFlow()
        in1 = layer_flow.add_input()
        in2 = layer_flow.add_input()
        out1 = layer_flow.add_output()

        layer_flow.add_node(self.proposal_op)
        layer_flow.add_node(self.mock_conv)
        layer_flow.add_node(self.act_op)
        layer_flow.add_node(self.output_op)

        # Both layer inputs feed the proposal op, on input indices 0 and 1 respectively.
        layer_flow.add_edge(in1, self.proposal_op, input_index=0, data_path=DataPath.LAYER_IN)
        layer_flow.add_edge(in2, self.proposal_op, input_index=1, data_path=DataPath.LAYER_IN)
        layer_flow.add_edge(self.proposal_op, self.mock_conv, DataPath.ACCUMULATOR)
        layer_flow.add_edge(self.mock_conv, self.act_op, DataPath.ACCUMULATOR)
        layer_flow.add_edge(self.act_op, self.output_op, input_index=0, data_path=DataPath.LAYER_OUT)
        layer_flow.add_edge(self.output_op, out1, input_index=0, data_path=DataPath.LAYER_OUT)

        return layer_flow

    @classmethod
    def get_default_params(cls) -> dict:
        """Automatically get the method's default arguments (via __init__). It saves the trouble
        (and affiliate bugs) of forgetting to update this function every time one changes something
        in __init__.

        Returns:
            Mapping of every ``__init__`` parameter that declares a default to that default.
            Parameters without a default (``self``, ``name``, ``**kwargs``) are omitted.
        """
        signature = inspect.signature(cls.__init__)
        return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """Alternate constructor: build the layer from an HN element.

        Args:
            lname: Name given to the new layer.
            hn_element: HN description of the layer. Its optional ``"params"`` entry overrides
                the ``__init__`` defaults.
            logger: Logger forwarded to the layer.

        Returns:
            The new layer, after ``finalize_from_hn(hn_element)`` was called on it.
        """
        params = cls.get_default_params()
        params.update(hn_element.get("params", {}))

        layer = cls(
            name=lname,
            input_division_factor=params["input_division_factor"],
            activation=params["activation"],
            logger=logger,
        )
        layer.finalize_from_hn(hn_element)
        return layer

    def to_hn(self, out_degree=None):
        """Export the layer as an HN element.

        Args:
            out_degree: Forwarded to the base-class ``to_hn`` so that this override honours
                the argument it accepts.

        Returns:
            The HN element produced by the base class, with ``params["activation"]`` set to
            the value of the activation op's current activation name.
        """
        hn_element = super().to_hn(out_degree=out_degree)
        hn_element["params"].update(
            {
                "activation": self.act_op.act_name.value,
            }
        )
        return hn_element

    def get_equalization_handler_type(self, predecessor_index=None):
        """Report that this layer is unsupported for equalization and is not a source."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def _enforce_output_encoding(self):
        """Propagate the output encoding backward: output_op -> act_op output scale / zero point."""
        self.output_op.backward_encoding()
        self.act_op.output_scale = self.output_op.input_scales[0]
        self.act_op.output_zero_point = self.output_op.input_zero_points[0]

    def enforce_io_encoding(self, train_scales=False, training=False, **kwargs):
        """Tie the layer output encoding to the proposal op output encoding.

        ``train_scales``, ``training`` and ``**kwargs`` are accepted but not used here.
        """
        self.proposal_op.enforce_encoding()
        self.output_op.output_scale = self.proposal_op.output_scale
        self.set_output_scale(self.proposal_op.output_scale, 0)
        self.set_output_zero_point(self.proposal_op.output_zero_point, 0)

    def create_hw_params(
        self, weights_clipping: LayerWeightsClippingConfig, optimization_target: OptimizationTarget, hw_shifts=None
    ):
        """Create the hardware parameters of the mock conv and the activation op.

        The activation op is first configured with its current activation, then switched to
        ``ActivationType.CLIP`` with the bounds from ``_get_clip_min_max`` and configured again.

        Args:
            weights_clipping: Not used by this layer.
            optimization_target: Forwarded to the activation op calls.
            hw_shifts: Not used by this layer.

        Raises:
            AccelerasImplementationError: If the activation op has more than one quantization group.
        """
        if self.act_op.quantization_groups_num > 1:
            raise AccelerasImplementationError(
                f"For layer {self.full_name} we don't support quantization with quantization groups yet"
            )
        self._enforce_output_encoding()
        self.enforce_internal_encoding()
        kernel_value = self._get_quantized_value_base_on_activation(self.act_op.act_name)
        self.mock_conv.create_hw_params(kernel_value=kernel_value)
        self.act_op.create_hw_params(self.mock_conv.output_scale, optimization_target, nudging=False)
        self.act_op.pl_approximate(self.mock_conv.output_scale, optimization_target)
        self.act_op.update_mantissa_exponent_decomposition()

        # Change activation to CLIP to ensure that the output will not exceed the range [0, 2**12]
        # for the nms layer (see create_output_encoding_candidates).
        self.act_op.act_name, self.act_op.act_func = self.act_op.create_act_name_and_func(ActivationType.CLIP)
        act_native_params = self.act_op.act_numeric_params

        clip_min, clip_max = self._get_clip_min_max()
        act_native_params["clip_min"] = clip_min
        act_native_params["clip_max"] = clip_max

        # Re-import the parameters (now carrying the clip bounds) and redo the activation
        # approximation for the CLIP activation.
        self.act_op.import_weights(act_native_params)
        self.act_op.pl_approximate(self.mock_conv.output_scale, optimization_target)
        self.act_op.update_mantissa_exponent_decomposition()
        self.act_op.enforce_encoding()

        self.enforce_internal_encoding()

    def _get_quantized_value_base_on_activation(self, activation_name):
        """Return the kernel value handed to the mock conv; constant 64, ``activation_name`` is ignored."""
        return 64

    def enforce_internal_encoding(self, **kwargs):
        """Assumes that we already have the inputs & output, and calculates all the HW params in the middle.

        ``**kwargs`` is accepted but not used here.
        """
        super().enforce_internal_encoding()
        self._enforce_output_encoding()
        self.proposal_op.enforce_encoding()

        # Propagate through mock_conv
        self.mock_conv.input_scales = [self.proposal_op.output_scale]
        self.mock_conv.input_zero_points = [self.proposal_op.output_zero_point]
        self.mock_conv.enforce_encoding()  # don't train these weights

        # Setting the APU
        self.act_op.input_scales = [self.mock_conv.output_scale]
        self.act_op.input_zero_points = [self.mock_conv.output_zero_point]
        self.act_op.enforce_encoding()

    def is_trainable(self) -> bool:
        """Return ``False``: this layer is never trained."""
        return False

    def is_differentiable(self) -> bool:
        """Return ``False``: this layer is treated as non-differentiable."""
        return False

    def _get_precision_mode_supported_in_hw(self, arch):
        """Return the precision modes supported for ``arch``.

        For MERCURY, SAGE and PLUTO a fixed set is returned; any other target is delegated
        to the base class.
        """
        if arch in {OptimizationTarget.MERCURY, OptimizationTarget.SAGE, OptimizationTarget.PLUTO}:
            supported_precision_mode = {PrecisionMode.a16_w16, PrecisionMode.a8_w8_a16, PrecisionMode.a16_w16_a16}
        else:
            supported_precision_mode = super()._get_precision_mode_supported_in_hw(arch)
        return supported_precision_mode

    def create_quant_element_custom_behavior(
        self,
        precision_config: LayerPrecisionConfig,
        optimization_target: OptimizationTarget,
    ):
        """Create the weight quantization elements of the mock conv and the activation op.

        Args:
            precision_config: Supplies the precision mode (which determines the kernel bits
                and sign for the mock conv) and the quantization groups set on the activation op.
            optimization_target: Forwarded to ``act_op.create_weight_quant_element``.
        """
        precision_mode = precision_config.precision_mode
        quant_groups = precision_config.quantization_groups
        kernel_bits, signed = get_kernel_bits_and_sign_by_precision_mode(precision_mode)

        self.mock_conv.create_weight_quant_element(kernel_bits, signed)
        self.act_op.create_weight_quant_element(optimization_target)
        self.act_op.set_quantization_groups(quant_groups)

    @property
    def has_activation(self):
        """For the grouping of components (equiv_flow -> get_groups_components()), we wish to
        have all the post-processing as one group component. This way, the matching scales is done
        in the right way. For this reason, we set the has_activation property to false.
        """
        return False

    def fast_enforce_internal_encoding(self, **kwargs):
        """Intentionally a no-op for this layer."""
        pass

    @classmethod
    def get_default_precision_mode(cls):
        """Set default precision mode. Used, for example, for create mixed precision."""
        return PrecisionMode.a8_w8_a16

    def _get_clip_min_max(self):
        """Compute the clip bounds used when the activation is switched to CLIP.

        Returns:
            ``(clip_min, clip_max)`` where ``clip_min`` is ``0.0`` and ``clip_max`` is the largest
            output scale multiplied by ``2 ** min(NMS_FORCED_BITS, output_bits) - 1``.
        """
        clip_min = 0.0
        # Bit width of the first lossy element on the output op's output.
        output_bits = self.output_op.output_lossy_elements[0].bits
        clip_max = np.max(self.output_scale) * (2 ** np.minimum(NMS_FORCED_BITS, output_bits) - 1)
        return clip_min, clip_max
