import copy

import numpy as np
from past.utils import old_div

from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.exceptions import UnsupportedModelError
from hailo_sdk_common.hailo_nn.hn_definitions import DefuseType, LayerType, PaddingType, PaddingTypes
from hailo_sdk_common.hailo_nn.hn_layers.layer_with_params import LayerWithParams
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class FeatureInterleaveLayer(LayerWithParams):
    """HN layer whose op is ``LayerType.feature_interleave``.

    The layer carries two parameters: an interleaving ``rate`` and a
    ``padding`` type. Its output shape multiplies the height and width of
    the input by ``rate`` and divides the feature count by ``rate ** 2``.
    """

    # Class-level flags; NOTE(review): presumably consumed by the base class
    # to decide whether weights must be present - confirm in LayerWithParams.
    _REQUIRES_NATIVE_WEIGHTS = False
    _REQUIRES_QUANTIZED_WEIGHTS = False

    def __init__(self):
        """Create a layer with no rate and no padding set yet."""
        super().__init__()
        self._op = LayerType.feature_interleave
        # Both are filled in later by from_pb() / from_hn().
        self._rate = None
        self._padding = None

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize the layer into a protobuf node.

        The base-class node is extended with the feature-interleave node
        type, the rate (stored as both the height and the width stride) and
        the protobuf value matching this layer's padding type.
        """
        pb_node = super().to_pb(pb_wrapper, is_multi_scope)
        pb_node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_FEATURE_INTERLEAVE
        # The same rate is written to both stride fields.
        pb_node.strides.height = self._rate
        pb_node.strides.width = self._rate
        pb_node.padding = pb_wrapper.PADDING_TYPE_TO_PB[self.padding]
        return pb_node

    def _calc_output_shape(self):
        """Compute the output shape from the input shape, rate and padding.

        Returns:
            A 4-element list: ``[input_shape[0], height, width, features]``.

        Raises:
            UnsupportedModelError: if ``input_shape[3]`` is not a multiple of
                ``rate ** 2``.
        """
        in_shape = self.input_shape
        in_features = in_shape[3]
        rate_squared = np.power(self._rate, 2)
        if in_features % rate_squared != 0:
            raise UnsupportedModelError(
                f"{self.full_name_msg} output_features does not divide by interleaving_rate^2. This is not allowed",
            )

        out_height = in_shape[1] * self._rate
        out_width = in_shape[2] * self._rate

        # For these defuse types the full input feature count is used even
        # when a "defuse_features" entry exists in the defuse params.
        full_feature_defuse_types = [
            DefuseType.none,
            DefuseType.deconv,
            DefuseType.compute_lanes,
        ]
        use_defuse_features = (
            "defuse_features" in self.defuse_params and self.defuse_type not in full_feature_defuse_types
        )
        features = self.defuse_features if use_defuse_features else in_features

        # With deconv padding one `rate` worth of rows and columns is dropped.
        if self.padding == PaddingType.deconv:
            out_height = out_height - self._rate
            out_width = out_width - self._rate

        out_features = int(old_div(features, rate_squared))
        return [in_shape[0], out_height, out_width, out_features]

    def _get_output_shape(self, validate=False, layer_name=None, layer_index=None):
        """Return the computed output shape.

        ``validate``, ``layer_name`` and ``layer_index`` are accepted but are
        not used by this implementation.
        """
        return self._calc_output_shape()

    def to_hn(self, should_get_default_params=False):
        """Return the HN dictionary of this layer, including rate and padding.

        A deep copy of the base-class dictionary is extended, so the
        dictionary returned by the base class is not modified.
        """
        hn_dict = copy.deepcopy(super().to_hn(should_get_default_params))
        layer_params = hn_dict["params"]
        layer_params["rate"] = self._rate
        layer_params["padding"] = self._padding.value
        return hn_dict

    @property
    def padding(self):
        """Padding type of the layer (``None`` until it is loaded)."""
        return self._padding

    @property
    def rate(self):
        """Interleaving rate of the layer (``None`` until it is loaded)."""
        return self._rate

    @classmethod
    def from_pb(cls, pb, pb_wrapper):
        """Build a layer from a protobuf node.

        The rate is read from the node's height stride; the padding is mapped
        back from its protobuf value.
        """
        new_layer = super().from_pb(pb, pb_wrapper)
        new_layer._rate = pb.strides.height
        new_layer._padding = pb_wrapper.PADDING_PB_TO_TYPE[pb.padding]
        return new_layer

    @classmethod
    def from_hn(cls, hn):
        """Build a layer from an HN dictionary with ``rate`` and ``padding`` params."""
        new_layer = super().from_hn(hn)
        hn_params = hn["params"]
        new_layer._rate = hn_params["rate"]
        new_layer._padding = PaddingTypes[hn_params["padding"]]
        return new_layer

    # TODO: set proper instead of undefined
    def get_equalization_handler_type(self, predecessor=None):
        """Return an 'unsupported', non-source classification; ``predecessor`` is unused."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        """Return an 'unsupported', non-source classification; ``predecessor`` is unused."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Return an 'unsupported', non-source classification; ``predecessor`` is unused."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def ibc_supported(self):
        """Return ``LayerSupportStatus.unsupported``."""
        return LayerSupportStatus.unsupported
