from typing import Callable, Union

import numpy as np
from tensorflow.keras.layers import InputSpec

from hailo_model_optimization.acceleras.atomic_ops.activation_op import ActivationOp
from hailo_model_optimization.acceleras.atomic_ops.bias_add_op import AddBiasOp
from hailo_model_optimization.acceleras.atomic_ops.mock_conv_op import MockConvOp
from hailo_model_optimization.acceleras.atomic_ops.passthru_op import PassthruOp
from hailo_model_optimization.acceleras.encoding.encoding_data import EncodingType
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.layer_flow import LayerFlow
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerPrecisionConfig,
    LayerWeightsClippingConfig,
)
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ActivationType,
    BiasMode,
    DataPath,
    EquivClassification,
    LayerHandlerType,
    LayerType,
    OptimizationTarget,
    PrecisionMode,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import (
    AccelerasImplementationError,
)
from hailo_model_optimization.acceleras.utils.hn_npz_utils import LayerParams
from hailo_model_optimization.acceleras.utils.opt_utils import (
    get_decomposition_count_by_bias_mode,
    get_kernel_bits_and_sign_by_precision_mode,
)


class HailoStandaloneActivation(BaseHailoLayer):
    """
    Simulate a layer dedicated solely to computing an activation.

    Creates a BaseHailoLayer whose flow chains four atomic ops (see
    ``_build_flow``):
    1. mock_conv   - MockConvOp
    2. bias_add_op - AddBiasOp (created with ``is_correctable=False``)
    3. act_op      - ActivationOp for the specified activation type
    4. output_op   - PassthruOp
    # TODO: https://hailotech.atlassian.net/browse/SDK-22986
    Args:
        name: name of layer; also used as the prefix of the atomic op names.
        activation: activation type. When given as `str` it must be included in the values of enum
        hailo_model_optimization.acceleras.utils.acceleras_definitions.ActivationType.
        bias_initializer: forwarded as-is to the AddBiasOp.
        trainable: forwarded as-is to the AddBiasOp.
        logger: logger to be passed to layer and to its atomic ops
        **kwargs: forwarded to ``BaseHailoLayer.__init__``.

    Example (illustrative, not a verified doctest):
        >>> import tensorflow as tf
        >>> sigmoid_layer = HailoStandaloneActivation(name='sigmoid1', activation='sigmoid')
        >>> rand_data = tf.random.normal([1, 8, 8, 16], 0, 1, tf.float32)  # rank 4, see input_spec
        >>> sigmoid_result = sigmoid_layer(rand_data)

    """

    # Precision / bias modes this layer declares as supported.
    SUPPORTED_PRECISION_MODE = {
        PrecisionMode.a8_w8,
        PrecisionMode.a16_w16,
        PrecisionMode.a8_w8_a8,
        PrecisionMode.a8_w8_a16,
        PrecisionMode.a16_w16_a8,
        PrecisionMode.a16_w16_a16,
    }
    SUPPORTED_BIAS_MODE = {
        BiasMode.double_scale_initialization,
        BiasMode.double_scale_decomposition,
        BiasMode.single_scale_decomposition,
    }
    SUPPORTED_QUANTIZATION_GROUPS = False
    # NOTE(review): only the annotation is declared here; the instance attribute
    # written in this class is `output_scale_scalar_dof` (no leading underscore).
    # Presumably the base class maps one onto the other - confirm.
    _output_scale_scalar_dof: float

    _hn_type = LayerType.ACTIVATION

    def __init__(
        self,
        name: str,
        activation: Union[str, Callable, ActivationType] = "linear",
        bias_initializer=None,
        trainable=False,
        logger=None,
        **kwargs,
    ):
        """
        Build the four atomic ops and initialize the base layer.

        Args:
            name: layer name, used as prefix for the atomic op names.
            activation: activation passed to the ActivationOp.
            bias_initializer: passed to the AddBiasOp.
            trainable: passed to the AddBiasOp.
            logger: logger passed to the ops and to the base layer.
            **kwargs: passed to ``BaseHailoLayer.__init__``.
        """
        # NOTE(review): the ops are created *before* super().__init__();
        # presumably the base constructor calls _build_flow(), which needs them - confirm.
        self.mock_conv = MockConvOp(f"{name}/mock_op", logger=logger)
        self.bias_add_op = AddBiasOp(
            f"{name}/bias_add_op",
            bias_initializer=bias_initializer,
            trainable=trainable,
            is_correctable=False,
            logger=logger,
        )
        self.act_op = ActivationOp(f"{name}/act_op", activation=activation, logger=logger)
        self.output_op = PassthruOp(f"{name}/passthru_op", logger=logger)
        super().__init__(name=name, logger=logger, **kwargs)

        # The layer expects rank-4 inputs.
        self.input_spec = InputSpec(ndim=4)
        # Ratio applied to the input scale to obtain the output scale (see enforce_io_encoding).
        self.output_scale_scalar_dof = 1
        # Optional override of the ratio above, used by _force_output_scale.
        self._forced_output_scale_scalar_dof = None
        self.encoding_const = False

    @property
    def pre_acc_shift(self):
        """Pre-accumulator shift of the layer, taken from the mock conv op."""
        return self.mock_conv.pre_acc_shift

    @property
    def forced_output_scale_scalar_dof(self):
        """Forced output/input scale ratio, or None when nothing is forced."""
        return self._forced_output_scale_scalar_dof

    @forced_output_scale_scalar_dof.setter
    def forced_output_scale_scalar_dof(self, forced_output_scale_scalar_dof):
        self._forced_output_scale_scalar_dof = forced_output_scale_scalar_dof

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """
        Class method to create HailoStandaloneActivation from an hn element.

        Args:
            lname: `str` layer name
            hn_element: mapping describing the layer in the hn. Only
                ``hn_element["params"]["activation"]`` is read here (defaulting to
                "linear" when either key is missing); the whole element is then
                handed to ``finalize_from_hn``.
            logger: logger to be passed to layer

        Returns:
            The newly created layer.

        Example (illustrative; the hn element is abbreviated and may need more
        keys for ``finalize_from_hn`` - confirm against a real hn):
            >>> hn_element = {"params": {"activation": "sigmoid"}}
            >>> sigmoid_layer = HailoStandaloneActivation.from_hn("sigmoid1", hn_element)

        """
        activation = hn_element.get("params", {}).get("activation", "linear")
        layer = cls(name=lname, activation=activation, logger=logger)
        layer.finalize_from_hn(hn_element)
        return layer

    def to_hn(self, out_degree=None):
        """
        Export the layer to its hn representation.

        Writes the current activation name into ``self._hn_element["params"]``
        (creating the "params" entry when missing) and delegates the rest to
        the base class.
        """
        if "params" not in self._hn_element:
            self._hn_element["params"] = {}
        self._hn_element["params"].update({"activation": self.act_op.act_name.value})
        return super().to_hn(out_degree=out_degree)

    def _build_flow(self) -> LayerFlow:
        """
        Build the layer flow:
        input -> mock_conv -> bias_add_op -> act_op -> output_op -> output.
        """
        layer_flow = LayerFlow()

        in1 = layer_flow.add_input()
        out1 = layer_flow.add_output()

        layer_flow.add_node(self.mock_conv)
        layer_flow.add_node(self.bias_add_op)
        layer_flow.add_node(self.act_op)
        layer_flow.add_node(self.output_op)

        # Edges are tagged with the data path they carry.
        layer_flow.add_edge(in1, self.mock_conv, DataPath.LAYER_IN)
        layer_flow.add_edge(self.mock_conv, self.bias_add_op, DataPath.ACCUMULATOR)
        layer_flow.add_edge(self.bias_add_op, self.act_op, DataPath.ACCUMULATOR)
        layer_flow.add_edge(self.act_op, self.output_op, DataPath.LAYER_OUT)
        layer_flow.add_edge(self.output_op, out1, DataPath.LAYER_OUT)
        return layer_flow

    def import_weights(self, layer_params: LayerParams):
        """
        Load the layer parameters into the atomic ops.

        The mock conv and the activation op each receive the full
        ``layer_params``; the bias is imported only when a "bias" entry exists.
        """
        self.mock_conv.import_weights(layer_params)
        bias = layer_params.get("bias", None)
        if bias is not None:
            self.import_native_bias(bias)
        self._load_activation(layer_params)

    @property
    def is_changing_bias_supported(self) -> bool:
        """This layer always reports that its bias may be changed."""
        return True

    def export_native_bias(self):
        """Return the weights exported by the bias-add op."""
        return self.bias_add_op.export_weights()

    def import_native_bias(self, bias):
        """Load ``bias`` into the bias-add op."""
        self.bias_add_op.import_weights(bias)

    def _load_activation(self, layer_params):
        """Load the activation op parameters from ``layer_params``."""
        self.act_op.import_weights(layer_params)

    def get_equalization_handler_type(self, predecessor_index=None):
        """Classify the layer for equalization: activation handler, not a source."""
        return EquivClassification(LayerHandlerType.activation, is_source=False)

    def get_quarot_handler_type(self, predecessor_index=None):
        """Classify the layer for quarot: activation handler, not a source."""
        return EquivClassification(LayerHandlerType.activation, is_source=False)

    def enforce_io_encoding(self, training=False, **kwargs):
        """Set the output op scale to ``input_scale * output_scale_scalar_dof``."""
        self.output_op.output_scale = self.input_scale * self.output_scale_scalar_dof

    @property
    def bias(self):
        """Bias held by the bias-add op."""
        return self.bias_add_op.bias

    def create_quant_element_custom_behavior(
        self,
        precision_config: LayerPrecisionConfig,
        optimization_target: OptimizationTarget,
    ):
        """
        Create the weight quantization elements of the atomic ops.

        Args:
            precision_config: supplies bias_mode, precision_mode and quantization_groups.
            optimization_target: forwarded to the activation op.
        """
        bias_mode = precision_config.bias_mode
        precision_mode = precision_config.precision_mode
        quant_groups = precision_config.quantization_groups

        # The kernel is always requested as signed for this layer.
        kernel_bits, signed = get_kernel_bits_and_sign_by_precision_mode(precision_mode, force_signed_kernel=True)
        num_decomposition = get_decomposition_count_by_bias_mode(bias_mode)

        self.mock_conv.create_weight_quant_element(kernel_bits, signed)
        self.bias_add_op.create_weight_quant_element(kernel_bits, signed, num_decomposition)
        self.act_op.create_weight_quant_element(optimization_target)

        # set quantization groups
        self.act_op.set_quantization_groups(quant_groups)

    def _enforce_output_encoding(self):
        """
        Propagate the output encoding backwards: run the output op backward
        encoding, then copy its input scale / zero point onto the activation
        op output.
        """
        self.output_op.backward_encoding()
        self.act_op.output_scale = self.output_op.input_scales[0]
        self.act_op.output_zero_point = self.output_op.input_zero_points[0]

    def enforce_internal_encoding(self, training=False, **kwargs):
        """
        Enforce the encodings of the atomic ops.

        The output encoding is first propagated backwards onto the activation
        op (``_enforce_output_encoding``); the remaining encodings are then
        enforced sequentially in the natural order of the ops:
        mock_conv -> bias_add_op -> act_op. The bias-add op keeps the scale of
        the mock conv output (its output scale equals its input scale).

        Original note: as the mock_conv op weight is set to be an exact power
        of two, the apu encodings do not have to be set here.
        NOTE(review): _get_quantized_value_base_on_activation can also return
        126 or 2**15 - 1, which are not powers of two - confirm the note above
        still holds.
        """
        self._enforce_output_encoding()
        self.mock_conv.enforce_encoding(training=training)
        self.bias_add_op.input_scales = [self.mock_conv.output_scale]
        self.bias_add_op.output_scale = self.bias_add_op.input_scales[0]
        self.bias_add_op.input_zero_points = [self.mock_conv.output_zero_point]
        self.bias_add_op.enforce_encoding()
        self.act_op.input_scales = [self.bias_add_op.output_scale]
        self.act_op.input_zero_points = [self.bias_add_op.output_zero_point]
        self.act_op.enforce_encoding(training=training)

    def fast_enforce_internal_encoding(self, **kwargs):
        """Intentionally a no-op for this layer."""
        pass

    def _export_weights(self):
        """
        Collect the exportable parameters: the mock conv weights, the native
        bias under the "bias" key, and the activation op weights.
        """
        params = {}
        params.update(self.mock_conv.export_weights())
        params["bias"] = self.export_native_bias()
        params.update(self.act_op.export_weights())
        return params

    def _export_layer_metadata(self):
        """Extend the base metadata with the forced scale ratio, when one is set."""
        export_vals = super()._export_layer_metadata()
        if self.forced_output_scale_scalar_dof is not None:
            export_vals["forced_output_scale_scalar_dof"] = self.forced_output_scale_scalar_dof
        return export_vals

    def _import_layer_metadata(self, npz):
        """Restore the forced scale ratio (None when absent), then the base metadata."""
        self.forced_output_scale_scalar_dof = npz.get("forced_output_scale_scalar_dof", None)
        return super()._import_layer_metadata(npz)

    def neg_weights(self):
        """Set the mock conv kernel to -1 (an assignment, not a sign flip of the current value)."""
        self.mock_conv.kernel = -1

    def _force_output_scale(self):
        """
        Force output_scale and output_zero_point when a forced ratio is set.

        Only acts when ``forced_output_scale_scalar_dof`` is not None and the
        current output scale is not a 0-d (shape ``()``) value.
        """
        # force output_scale and output_zero_point to be as we want
        if self.forced_output_scale_scalar_dof is not None and self.output_scale.shape != ():
            # TODO change it to be the real output factor
            self.set_output_scale(self.input_scale * self.forced_output_scale_scalar_dof, 0)
            self.set_output_zero_point(self.input_zero_point, 0)

    def _create_out_in_scale_ratio(self):
        """
        create the output_scale_scalar_dof: the maximum of
        ``output_scale / input_scales``.
        """
        _out_in_scale_ratio = self.output_scale / self.input_scales
        self.output_scale_scalar_dof = np.max(_out_in_scale_ratio)

    def get_kernel_np(self):
        """Return the quantized mock conv kernel as a numpy value."""
        return self.mock_conv.kernel_q.numpy()

    def get_bias_np(self):
        """Return the bias of the bias-add op as a numpy value."""
        return self.bias_add_op.bias.numpy()

    def create_hw_params(
        self, weights_clipping: LayerWeightsClippingConfig, optimization_target: OptimizationTarget, hw_shifts=None
    ):
        """
        Create the hardware parameters of the atomic ops.

        Args:
            weights_clipping: not used by this layer.
            optimization_target: forwarded to the activation op.
            hw_shifts: optional sequence; only its first element is used, as the
                pre-accumulator shift of the mock conv.

        Raises:
            AccelerasImplementationError: when the activation op has more than
                one quantization group.
        """
        if self.act_op.quantization_groups_num > 1:
            raise AccelerasImplementationError(
                f"For layer {self.full_name} we don't support qunatization with quantization groups yet",
            )
        self._enforce_output_encoding()
        kernel_value = self._get_quantized_value_base_on_activation(self.act_op.act_name)
        pre_acc_shift = hw_shifts[0] if hw_shifts is not None else None
        self.mock_conv.create_hw_params(kernel_value, pre_acc_shift=pre_acc_shift)
        # The bias-add op shares the shift chosen by the mock conv.
        self.bias_add_op.pre_acc_shift = self.mock_conv.pre_acc_shift

        # From accumulator scale candidate, create the "ideal" output factor (*finalized*).
        self.act_op.create_hw_params(self.mock_conv.output_scale, optimization_target, nudging=False)

        self.enforce_internal_encoding()
        self.bias_add_op.create_hw_params()
        self._has_hw_params = True

    def get_numeric_kernel_np(self):
        """
        Return the quantized mock conv kernel multiplied by an all-ones array.

        NOTE(review): the shape of the ones array is
        ``kernel_shape * input_shape[-1]``; whether this scales the dimensions
        (array-like kernel_shape) or repeats them (list/tuple kernel_shape)
        depends on the type of ``mock_conv.kernel_shape`` - confirm.
        """
        return self.mock_conv.kernel_q.numpy() * (np.ones(self.mock_conv.kernel_shape * self.input_shape[-1]))

    def _get_quantized_value_base_on_activation(self, activation_name):
        """
        Pick the quantized kernel value handed to the mock conv.

        Returns ``2**15 - 1`` in a16_w16_a8 precision mode regardless of the
        activation; otherwise 126 for INV_POS / MINUS_INV_POS and 64 for every
        other activation.
        """
        # TODO: find the best quantize param for layer
        if self.get_precision_mode() == PrecisionMode.a16_w16_a8:
            return 2**15 - 1
        if activation_name in (ActivationType.INV_POS, ActivationType.MINUS_INV_POS):
            return 126
        return 64

    def define_encodings(self, flow):
        """
        Register the layer encodings on ``flow``: the base encodings, a
        shape-``()`` scale encoding for output_scale_scalar_dof, and mark the
        activation op output_factor_by_group encoding as scalar.
        """
        super().define_encodings(flow)
        flow.add_encoding(f"{self.full_name}/output_scale_scalar_dof:0", EncodingType.Scale, scalar=False, shape=())
        flow.nodes[f"{self.act_op.full_name}/output_factor_by_group:0"]["encoding"].scalar = True

    def define_constraints(self, enc):
        """
        Register the layer constraints on ``enc``, on top of the base ones.

        * Unless both the mock conv and the bias-add op have constant
          encodings, their mac_shift encodings are tied together.
        * output_scale_scalar_dof is pinned to its current value when both the
          mock conv and the output op have constant encodings; otherwise it is
          derived from output_scale / input_scale.
        """
        super().define_constraints(enc)
        if not (self.mock_conv.encoding_const and self.bias_add_op.encoding_const):
            enc.identity(f"{self.mock_conv.full_name}/mac_shift:0", f"{self.bias_add_op.full_name}/mac_shift:0")

        if self.mock_conv.encoding_const and self.output_op.encoding_const:
            enc.identity(f"{self.full_name}/output_scale_scalar_dof:0", self.output_scale_scalar_dof)
        else:
            enc.div(
                enc.dummy(0), f"{self.output_op.full_name}/output_scale:0", f"{self.mock_conv.full_name}/input_scale:0"
            )
            # Reduce the ratio to a single value: take the first element unless it is already 0-d.
            enc.callback(
                f"{self.full_name}/output_scale_scalar_dof:0",
                enc.dummy(0),
                lambda x: x[0] if x.shape != () else x,
                outs_scalar=True,
                outs_shape=(),
            )

    def update_encoding(self, encodings):
        """Apply the base update, then read output_scale_scalar_dof back from ``encodings``."""
        super().update_encoding(encodings)
        self.output_scale_scalar_dof = encodings[f"{self.full_name}/output_scale_scalar_dof:0"]

    def _get_precision_mode_supported_in_hw(self, arch):
        """
        Return the precision modes supported for ``arch``: the full
        SUPPORTED_PRECISION_MODE set for MERCURY, SAGE and PLUTO, and the base
        class answer for any other target.
        """
        if arch in {OptimizationTarget.MERCURY, OptimizationTarget.SAGE, OptimizationTarget.PLUTO}:
            return self.SUPPORTED_PRECISION_MODE
        return super()._get_precision_mode_supported_in_hw(arch)