import copy

from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.hn_definitions import ActivationType, ActivationTypes, DefuseType, LayerType
from hailo_sdk_common.hailo_nn.hn_layers.layer_with_activation import LayerWithActivation
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class ActivationLayer(LayerWithActivation):
    """Standalone activation layer (op ``LayerType.base_activation``).

    The layer's activation type is held in ``self.activation`` / ``self._activation``.
    Some activation types carry extra parameters (listed in ``ACTIVATION_TO_PARAPMS``);
    ``create`` stores only the parameters relevant to the chosen activation.
    """

    # Whether each activation type requires native weights. Types missing from this
    # table are treated as not requiring them (see ``requires_weights`` and
    # ``requires_native_weights``).
    _REQUIRES_NATIVE_WEIGHTS = {
        ActivationType.linear: False,
        ActivationType.relu: False,
        ActivationType.relu6: False,
        ActivationType.relu1: False,
        ActivationType.leaky: True,
        ActivationType.elu: False,
        ActivationType.sigmoid: False,
        ActivationType.exp: False,
        ActivationType.tanh: False,
        ActivationType.threshold: True,
        ActivationType.biased_delta: True,
        ActivationType.prelu: True,
        ActivationType.softplus: False,
        ActivationType.silu: False,
        ActivationType.gelu: False,
        ActivationType.mish: False,
        ActivationType.inv_pos: False,
        ActivationType.hardswish: False,
        ActivationType.swish: True,
        ActivationType.less: True,
        ActivationType.log: False,
        ActivationType.sqrt: False,
        ActivationType.hardsigmoid: True,
        ActivationType.clip: True,
        ActivationType.inv_sqrt: False,
        ActivationType.softsign: False,
        ActivationType.delta: False,
        ActivationType.greater: True,
        ActivationType.pow: True,
        ActivationType.hdr_compression: False,
        ActivationType.relu_positive_square: False,
        ActivationType.pwl: True,
        ActivationType.exp_decompose: False,
        ActivationType.shift: False,
    }

    # Whether each activation type requires quantized weights. The two tables are
    # currently identical, so this is built as a copy instead of a second hand-kept
    # listing. It is still a distinct dict object. If an entry ever needs to differ,
    # write it as ``{**_REQUIRES_NATIVE_WEIGHTS, ActivationType.<x>: <bool>}``.
    _REQUIRES_QUANTIZED_WEIGHTS = dict(_REQUIRES_NATIVE_WEIGHTS)

    # Maps activation types to the names of their parameters saved in the NPZ.
    # These are also the attribute names that ``create`` assigns for each type.
    ACTIVATION_TO_PARAPMS = {
        ActivationType.leaky: ["leaky_alpha"],
        ActivationType.threshold: ["activation_threshold"],
        ActivationType.biased_delta: ["activation_delta_bias"],
        ActivationType.prelu: ["prelu_slope"],
        ActivationType.swish: ["swish_beta"],
        ActivationType.less: ["activation_less_values"],
        ActivationType.hardsigmoid: ["hardsigmoid_alpha", "hardsigmoid_beta"],
        ActivationType.clip: ["clip_min", "clip_max"],
        ActivationType.greater: ["activation_greater_values"],
        ActivationType.pow: ["pow_exponent"],
    }
    # Correctly spelled alias (same dict object); the misspelled name above is kept
    # so existing callers keep working.
    ACTIVATION_TO_PARAMS = ACTIVATION_TO_PARAPMS

    # NOTE(review): presumably tells the base layer machinery that rank-3 inputs are
    # accepted; its consumer is outside this class — confirm.
    _IS_RANK3_SUPPORTED = True

    def __init__(self):
        super().__init__()
        self._op = LayerType.base_activation

    @classmethod
    def create(
        cls,
        original_name,
        input_vertex_order,
        activation,
        leaky_alpha=None,
        activation_threshold=None,
        delta_bias=None,
        output_shapes=None,
        prelu_slope=None,
        swish_beta=None,
        activation_values=None,
        hardsigmoid_alpha=None,
        hardsigmoid_beta=None,
        clip_min=None,
        clip_max=None,
        pow_exponent=None,
    ):
        """Create an activation layer and set only the parameters its activation uses.

        Args:
            original_name: Forwarded to the base ``create``.
            input_vertex_order: Forwarded to the base ``create``.
            activation: The ``ActivationType`` of the layer.
            leaky_alpha: Stored as ``leaky_alpha`` for ``leaky``.
            activation_threshold: Stored as ``activation_threshold`` for ``threshold``.
            delta_bias: Stored as ``activation_delta_bias`` for ``biased_delta``.
            output_shapes: Forwarded to the base ``create``.
            prelu_slope: Stored as ``prelu_slope`` for ``prelu``.
            swish_beta: Stored as ``swish_beta`` for ``swish``.
            activation_values: Stored as ``activation_less_values`` for ``less`` or
                ``activation_greater_values`` for ``greater``.
            hardsigmoid_alpha, hardsigmoid_beta: Stored for ``hardsigmoid``.
            clip_min, clip_max: Stored for ``clip``.
            pow_exponent: Stored as ``pow_exponent`` for ``pow``.

        Returns:
            The new layer. Parameters that don't apply to ``activation`` are ignored.
        """
        layer = super().create(original_name, input_vertex_order, output_shapes)
        layer.activation = activation
        if activation == ActivationType.leaky:
            layer.leaky_alpha = leaky_alpha
        elif activation == ActivationType.prelu:
            layer.prelu_slope = prelu_slope
        elif activation == ActivationType.threshold:
            layer.activation_threshold = activation_threshold
        elif activation == ActivationType.biased_delta:
            layer.activation_delta_bias = delta_bias
        elif activation == ActivationType.swish:
            layer.swish_beta = swish_beta
        elif activation == ActivationType.less:
            layer.activation_less_values = activation_values
        elif activation == ActivationType.greater:
            layer.activation_greater_values = activation_values
        elif activation == ActivationType.hardsigmoid:
            layer.hardsigmoid_alpha = hardsigmoid_alpha
            layer.hardsigmoid_beta = hardsigmoid_beta
        elif activation == ActivationType.clip:
            layer.clip_min = clip_min
            layer.clip_max = clip_max
        elif activation == ActivationType.pow:
            layer.pow_exponent = pow_exponent
        return layer

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize to a protobuf node, tagging it as a base-activation node with its activation type."""
        node = super().to_pb(pb_wrapper, is_multi_scope)
        node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_BASE_ACTIVATION
        node.activation = pb_wrapper.ACTIVATION_TYPE_TO_PB[self.activation]
        return node

    def to_hn(self, should_get_default_params=False):
        """Serialize to an HN dict (a deep copy of the base result) with ``params["activation"]`` set."""
        result = copy.deepcopy(super().to_hn(should_get_default_params))
        result["params"]["activation"] = self.activation.value
        return result

    @staticmethod
    def requires_weights(activation):
        """Return whether ``activation`` requires native weights.

        Activation types missing from ``_REQUIRES_NATIVE_WEIGHTS`` are assumed not to
        require weights. Unlike ``requires_native_weights``, this does not log.
        """
        return ActivationLayer._REQUIRES_NATIVE_WEIGHTS.get(activation, False)

    def _weights_requirement(self, table, weights_kind):
        """Look up ``self._activation`` in ``table``; warn and return False if it is missing.

        Args:
            table: One of the ``_REQUIRES_*_WEIGHTS`` dicts.
            weights_kind: Word inserted into the warning text ("native" or "quantized").
        """
        if self._activation not in table:
            self._logger.warning(
                f"Layer {self.name} of activation type {self._activation.value} does not specify whether "
                f"{weights_kind} weights are required. Assuming False.",
            )
            return False

        return table[self._activation]

    @property
    def requires_native_weights(self):
        """Whether this layer's activation requires native weights (warns and assumes False if unspecified)."""
        # Always reads ActivationLayer's own table, not a subclass override.
        return self._weights_requirement(ActivationLayer._REQUIRES_NATIVE_WEIGHTS, "native")

    @property
    def requires_quantized_weights(self):
        """Whether this layer's activation requires quantized weights (warns and assumes False if unspecified)."""
        # Always reads ActivationLayer's own table, not a subclass override.
        return self._weights_requirement(ActivationLayer._REQUIRES_QUANTIZED_WEIGHTS, "quantized")

    def get_equalization_handler_type(self, predecessor=None):
        """Return an ``unexpected``, non-source classification for equalization."""
        return EquivClassification(LayerHandlerType.unexpected, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        """Return an ``unexpected``, non-source classification for the params sorter."""
        return EquivClassification(LayerHandlerType.unexpected, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Return an ``unexpected``, non-source classification for dead-channel removal."""
        return EquivClassification(LayerHandlerType.unexpected, is_source=False)

    def ibc_supported(self):
        """Return ``LayerSupportStatus.unexpected`` for IBC support."""
        return LayerSupportStatus.unexpected

    @classmethod
    def from_hn(cls, hn):
        """Build the layer from an HN dict, resolving ``hn["params"]["activation"]`` via ``ActivationTypes``."""
        layer = super().from_hn(hn)
        # Assigns the private field directly rather than going through the ``activation`` setter.
        layer._activation = ActivationTypes[hn["params"]["activation"]]
        return layer

    @classmethod
    def from_pb(cls, pb, pb_wrapper):
        """Build the layer from a protobuf node, mapping ``pb.activation`` back to an ``ActivationType``."""
        layer = super().from_pb(pb, pb_wrapper)
        layer._activation = pb_wrapper.ACTIVATION_PB_TO_TYPE[pb.activation]
        return layer

    @classmethod
    def from_layer(cls, old_layer):
        """Build the layer from ``old_layer``, copying its activation unless it is an rnn, lstm or equal layer."""
        layer = super().from_layer(old_layer)
        # NOTE(review): presumably these ops don't expose a usable ``activation`` — confirm.
        if old_layer.op not in [LayerType.rnn, LayerType.lstm, LayerType.equal]:
            layer.activation = old_layer.activation
        return layer

    def _calc_output_shape(self):
        """Compute the output shape from the input shape.

        - Rank-4 input with ``_transpose_output_width_features`` set: the last two
          dims are swapped.
        - ``spatial_w`` defuse with a nonzero ``defuse_input_width``: that width
          replaces dim 2.
        - Otherwise: ``self.input_shape`` itself (the same object, not a copy).
        """
        if self._transpose_output_width_features and len(self.input_shape) == 4:
            return [self.input_shape[0], self.input_shape[1], self.input_shape[3], self.input_shape[2]]

        # NOTE(review): this branch indexes dim 3, so it assumes a rank-4 input even though
        # rank 3 is flagged as supported — confirm spatial_w defuse never sees rank-3 input.
        if (
            self.defuse_type is DefuseType.spatial_w
            and "defuse_input_width" in self.defuse_params
            and self.defuse_input_width != 0
        ):
            return [self.input_shape[0], self.input_shape[1], self.defuse_input_width, self.input_shape[3]]

        return self.input_shape
