import copy

from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.hn_definitions import ActivationTypes, LayerType
from hailo_sdk_common.hailo_nn.hn_layers.dense import DenseLayer
from hailo_sdk_common.hailo_nn.hn_layers.layer_with_activation import LayerWithActivation
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class FusedDenseLayer(
    DenseLayer,
    LayerWithActivation,
):
    """Dense layer that also carries an activation and batch-norm flags.

    On top of what the base classes provide, the layer tracks:

    * ``bn_enabled`` - serialized as ``batch_norm``.
    * ``pre_layer_bn`` - serialized as ``pre_layer_batch_norm``.
    * ``bn_info`` - taken from another layer in :meth:`move_params`.
    * ``should_squeeze_kernel`` - when set, :meth:`move_params` reshapes the
      kernel to ``self.kernel_shape``.
    """

    # Class-level flags; they are not read in this class.
    # NOTE(review): presumably consumed by base-class/weights handling - confirm.
    _REQUIRES_NATIVE_WEIGHTS = True
    _REQUIRES_QUANTIZED_WEIGHTS = True

    def __init__(self):
        super().__init__()
        self._bn_info = None
        self._pre_layer_bn = False
        self._bn_enabled = False
        self._op = LayerType.dense
        self._should_squeeze_kernel = False

    def to_hn(self, should_get_default_params=False):
        """Return the HN dictionary of this layer.

        Starts from a deep copy of the base-class result and adds the
        ``batch_norm`` and ``activation`` params. ``pre_layer_batch_norm`` and
        ``should_squeeze_kernel`` are written only when they are truthy.
        """
        # Work on a copy so the edits below do not touch the dict returned by the base class.
        result = copy.deepcopy(super().to_hn(should_get_default_params))
        result["params"]["batch_norm"] = self._bn_enabled
        if self.pre_layer_bn:
            result["params"]["pre_layer_batch_norm"] = True
        result["params"]["activation"] = self._activation.value
        if self._should_squeeze_kernel:
            result["params"]["should_squeeze_kernel"] = self._should_squeeze_kernel
        return result

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Return the protobuf node of this layer.

        The node produced by the base class is typed as ``PROTO_NETWORK_DENSE``
        and filled with every field that :meth:`from_pb` reads back.
        """
        node = super().to_pb(pb_wrapper, is_multi_scope)
        node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_DENSE
        node.batch_norm = self._bn_enabled
        # from_pb() restores this flag from the node, so it has to be written here as well;
        # otherwise a to_pb -> from_pb round trip silently resets pre_layer_bn.
        node.pre_layer_batch_norm = self._pre_layer_bn
        node.activation = pb_wrapper.ACTIVATION_TYPE_TO_PB[self._activation]
        node.should_squeeze_kernel = self._should_squeeze_kernel
        return node

    def move_params(self, layer):
        """Take over the parameters of ``layer``.

        After the base-class handling, ``bn_info`` is taken from ``layer`` when
        its op is ``base_batch_norm`` or ``conv``, and the kernel is reshaped to
        ``self.kernel_shape`` when ``should_squeeze_kernel`` is set.
        """
        super().move_params(layer)
        if layer.op in [LayerType.base_batch_norm, LayerType.conv]:
            self._bn_info = layer.bn_info

        if self.should_squeeze_kernel:
            # in case this dense layer used to be a conv1x1 (switched when comes after global avgpool)
            # we need to reshape the kernel to match the dense layer kernel shape rank
            self._kernel = self._kernel.reshape(self.kernel_shape)

    @property
    def bn_enabled(self):
        """Value serialized as the ``batch_norm`` param/field."""
        return self._bn_enabled

    @bn_enabled.setter
    def bn_enabled(self, bn_enabled):
        self._bn_enabled = bn_enabled

    @property
    def bn_info(self):
        """``bn_info`` taken in :meth:`move_params`; ``None`` until then."""
        return self._bn_info

    @property
    def pre_layer_bn(self):
        """Value serialized as the ``pre_layer_batch_norm`` param/field."""
        return self._pre_layer_bn

    @pre_layer_bn.setter
    def pre_layer_bn(self, pre_layer_bn):
        self._pre_layer_bn = pre_layer_bn

    @property
    def should_squeeze_kernel(self):
        """Whether :meth:`move_params` reshapes the kernel to ``kernel_shape``."""
        return self._should_squeeze_kernel

    @should_squeeze_kernel.setter
    def should_squeeze_kernel(self, should_squeeze_kernel):
        self._should_squeeze_kernel = should_squeeze_kernel

    @classmethod
    def from_hn(cls, hn):
        """Build a layer from an HN dictionary.

        ``hn["params"]`` must contain ``batch_norm`` and ``activation`` (both
        are indexed directly); ``pre_layer_batch_norm`` and
        ``should_squeeze_kernel`` are optional and keep their defaults when
        absent.
        """
        layer = super().from_hn(hn)
        if hn["params"]["batch_norm"]:
            layer._bn_enabled = True
        if "pre_layer_batch_norm" in hn["params"]:
            layer._pre_layer_bn = hn["params"]["pre_layer_batch_norm"]
        if "should_squeeze_kernel" in hn["params"]:
            layer._should_squeeze_kernel = hn["params"]["should_squeeze_kernel"]
        # NOTE(review): this looks the activation up by enum member *name*, while to_hn() writes
        # the member's *value*; the round trip relies on the two being equal - confirm.
        layer._activation = ActivationTypes[hn["params"]["activation"]]
        return layer

    @classmethod
    def from_pb(cls, pb, pb_wrapper):
        """Build a layer from a protobuf node, restoring the fields written by :meth:`to_pb`."""
        layer = super().from_pb(pb, pb_wrapper)
        layer._bn_enabled = pb.batch_norm
        layer._pre_layer_bn = pb.pre_layer_batch_norm
        layer._activation = pb_wrapper.ACTIVATION_PB_TO_TYPE[pb.activation]
        layer._should_squeeze_kernel = pb.should_squeeze_kernel
        return layer

    @classmethod
    def from_layer(cls, old_layer):
        """Build a layer from ``old_layer``, additionally copying its ``bn_enabled`` flag."""
        layer = super().from_layer(old_layer)
        layer.bn_enabled = old_layer.bn_enabled

        return layer

    @property
    def macs(self):
        """Number of multiply-accumulates: ``input_features * output_features``."""
        return self.input_features * self.output_features

    @property
    def ops(self):
        """
        Return the number of multiplications and additions.
        Multiplications are input_features * output_features
        Additions are (input_features - 1) * output_features
        and bias addition is output_features.
        """
        return self.input_features * self.output_features * 2

    @property
    def weights(self):
        """Parameter count: ``input_features * output_features`` plus ``output_features``."""
        return self.input_features * self.output_features + self.output_features

    def get_equalization_handler_type(self, predecessor=None):
        """Return a consumer/source classification; ``predecessor`` is not used."""
        return EquivClassification(LayerHandlerType.consumer, is_source=True)

    def get_params_sorter_handler_type(self, predecessor=None):
        """Return a consumer/source classification; ``predecessor`` is not used."""
        return EquivClassification(LayerHandlerType.consumer, is_source=True)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Return a consumer/source classification; ``predecessor`` is not used."""
        return EquivClassification(LayerHandlerType.consumer, is_source=True)

    def ibc_supported(self):
        """Return ``LayerSupportStatus.supported`` unconditionally."""
        return LayerSupportStatus.supported
