import copy
from collections import OrderedDict

from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.exceptions import UnsupportedModelError
from hailo_sdk_common.hailo_nn.hn_definitions import ActivationTypes, DefuseType, LayerType
from hailo_sdk_common.hailo_nn.hn_layers.batch_norm import BatchNormLayer
from hailo_sdk_common.hailo_nn.hn_layers.layer_with_activation import LayerWithActivation
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class FusedBatchNormLayer(BatchNormLayer, LayerWithActivation):
    """Batch-norm layer that also carries an activation and an element-wise-add flag.

    The activation is stored in ``_activation``. Both it and ``ew_add_enabled``
    are written out by ``to_hn`` / ``to_pb`` and restored by ``from_hn`` /
    ``from_pb``.

    Geometry is fixed: unit strides and dilations, and a 1x1 kernel spanning
    all input features.
    """

    # Class-level flags read elsewhere in the SDK (presumably: this layer needs
    # both native and quantized weights — confirm against their consumers).
    _REQUIRES_NATIVE_WEIGHTS = True
    _REQUIRES_QUANTIZED_WEIGHTS = True

    def __init__(self):
        super().__init__()
        # Assigned after the bases' __init__, so this overrides any op they set.
        self._op = LayerType.batch_norm

    @property
    def padding(self):
        """Return ``self._padding``; note that this class never assigns that attribute itself."""
        return self._padding

    @property
    def dilations(self):
        """Unit dilation on all four axes; a fresh list is built on every access."""
        return [1] * 4

    @property
    def strides(self):
        """Unit stride on all four axes; a fresh list is built on every access."""
        return [1] * 4

    @property
    def kernel_shape(self):
        """Return ``[1, 1, 1, input_features]``.

        ``to_pb`` reads slots 0, 1 and 3 as height, width and features; it
        ignores slot 2.
        """
        return [1] * 3 + [self.input_features]

    def move_params(self, layer):
        """Take over ``layer``'s params, plus its BN info when it is a base batch norm."""
        super().move_params(layer)
        from_base_bn = layer.op == LayerType.base_batch_norm
        if from_base_bn:
            self._bn_info = layer._bn_info

    def to_hn(self, should_get_default_params=False):
        """Serialize to an HN dict.

        The resulting ``params`` hold only the activation name and the
        element-wise-add flag. The parent's output is deep-copied first, so the
        edits below never touch the object the parent returned.
        """
        hn = copy.deepcopy(super().to_hn(should_get_default_params))
        # 'params' is rebuilt from scratch here instead of coming from a
        # LayerWithParams base (this class does not derive from one). It
        # replaces any 'params' entry the parent serialization produced.
        hn["params"] = OrderedDict(
            [
                ("activation", self._activation.value),
                ("elementwise_add", self.ew_add_enabled),
            ],
        )
        return hn

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize to a protobuf node typed as a batch-norm layer.

        On top of the parent's node, this sets the activation, the kernel shape
        (height, width, features), unit strides, and the element-wise-add flag.
        """
        node = super().to_pb(pb_wrapper, is_multi_scope)
        node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_BATCH_NORM
        node.activation = pb_wrapper.ACTIVATION_TYPE_TO_PB[self._activation]
        # The third element of the 4-element kernel shape has no protobuf field.
        height, width, _unused, features = self.kernel_shape
        node.kernel_shape.height = height
        node.kernel_shape.width = width
        node.kernel_shape.features = features
        node.strides.height = 1
        node.strides.width = 1
        node.ew_add = self.ew_add_enabled
        return node

    def get_equalization_handler_type(self, predecessor=None):
        """Classify as a source consumer for equalization (``predecessor`` is ignored)."""
        return EquivClassification(LayerHandlerType.consumer, is_source=True)

    def get_params_sorter_handler_type(self, predecessor=None):
        """Classify as skipped for params sorting (``predecessor`` is ignored)."""
        return EquivClassification(LayerHandlerType.skip, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Classify as featurewise for dead-channel removal (``predecessor`` is ignored)."""
        return EquivClassification(LayerHandlerType.featurewise, is_source=False)

    def ibc_supported(self):
        """Report IBC as supported for this layer."""
        return LayerSupportStatus.supported

    @classmethod
    def from_hn(cls, hn):
        """Build the layer from an HN dict and restore its activation from ``hn["params"]``.

        Raises:
            UnsupportedModelError: if ``hn`` has no ``"params"`` entry.
        """
        if "params" not in hn:
            raise UnsupportedModelError(
                f'layer {hn["name"]} of type {hn["type"]} requires params, but the HN does not contain them',
            )

        layer = super().from_hn(hn)
        activation_name = hn["params"]["activation"]
        layer._activation = ActivationTypes[activation_name]
        return layer

    @classmethod
    def from_pb(cls, pb, pb_wrapper):
        """Build the layer from a protobuf node.

        The node's activation is mapped back through
        ``pb_wrapper.ACTIVATION_PB_TO_TYPE``.
        """
        layer = super().from_pb(pb, pb_wrapper)
        pb_activation = pb.activation
        layer._activation = pb_wrapper.ACTIVATION_PB_TO_TYPE[pb_activation]
        return layer

    @classmethod
    def from_layer(cls, old_layer):
        """Build the layer from an existing one, carrying over its ``bn_info``."""
        layer = super().from_layer(old_layer)
        layer.bn_info = old_layer.bn_info
        return layer

    def _calc_output_shape(self):
        """Return ``[-1, *spatial_dims, features]``.

        ``features`` is the input's last dimension. It is replaced by
        ``defuse_features`` only when ``"defuse_features"`` is in
        ``defuse_params`` and ``defuse_type`` is neither ``none`` nor
        ``compute_lanes``.
        """
        features = self.input_shape[-1]
        keeps_input_features = (DefuseType.none, DefuseType.compute_lanes)
        if "defuse_features" in self.defuse_params and self.defuse_type not in keeps_input_features:
            features = self.defuse_features
        return [-1, *self.input_shape[1:-1], features]
