import copy

from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.exceptions import UnsupportedModelError
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType, SpaceToDepthType, SpaceToDepthTypes
from hailo_sdk_common.hailo_nn.hn_layers.layer_with_params import LayerWithParams
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification

# Name of the space-to-depth type parameter; same string as the
# "space_to_depth_type" key in a space-to-depth layer's HN "params" dict.
SPACE_TO_DEPTH_TYPE = "space_to_depth_type"


class SpaceToDepthLayer(LayerWithParams):
    """Space-to-depth layer.

    Moves non-overlapping spatial blocks of size ``block_sizes = [block_h, block_w]``
    into the feature dimension. ``input_shape[1]`` is treated as height,
    ``input_shape[2]`` as width and ``input_shape[3]`` as features. An input of
    spatial size ``(H, W)`` with ``F`` features yields ``(H // block_h, W // block_w)``
    with ``F * block_h * block_w`` features (see ``_calc_output_shape``).

    ``space_to_depth_type`` (a ``SpaceToDepthType`` member) selects the variant.
    ``_validate_input`` requires 16x16 blocks for the SERIAL variant.
    """

    _REQUIRES_NATIVE_WEIGHTS = False
    _REQUIRES_QUANTIZED_WEIGHTS = False
    _IS_REAL_LAYER = True

    # Default [block_h, block_w]. This list belongs to the class and is shared,
    # so it is always copied before being stored on an instance. Otherwise an
    # in-place edit of one layer's block sizes would change the default for
    # every other layer.
    _DEFAULT_BLOCK_SIZE = [2, 2]

    def __init__(self):
        super().__init__()
        self._op = LayerType.space_to_depth
        self._block_sizes = list(self._DEFAULT_BLOCK_SIZE)
        self._space_to_depth_type = SpaceToDepthType.classic_dcr

    @classmethod
    def create(
        cls,
        original_name,
        input_vertex_order,
        block_sizes,
        output_shapes=None,
        space_to_depth_type=SpaceToDepthType.classic_dcr,
    ):
        """Build a layer with the given block sizes and space-to-depth variant.

        Args:
            original_name: Forwarded to ``LayerWithParams.create``.
            input_vertex_order: Forwarded to ``LayerWithParams.create``.
            block_sizes: Sequence ``[block_h, block_w]``. It is copied into a
                new list, so later changes to the caller's sequence do not
                affect the layer.
            output_shapes: Forwarded to ``LayerWithParams.create``.
            space_to_depth_type: A ``SpaceToDepthType`` member.

        Returns:
            The new layer. This method itself does not call ``_validate_input``.
        """
        layer = super().create(original_name, input_vertex_order, output_shapes)
        layer._block_sizes = list(block_sizes)
        layer._space_to_depth_type = space_to_depth_type
        return layer

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize to a protobuf node, adding the block sizes and space-to-depth type."""
        node = super().to_pb(pb_wrapper, is_multi_scope)
        node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_SPACE_TO_DEPTH
        node.block_sizes.extend(self.block_sizes)
        node.space_to_depth_type = pb_wrapper.SPACE_TO_DEPTH_TYPE_TYPE_TO_PB[self.space_to_depth_type]
        return node

    @classmethod
    def from_pb(cls, pb, pb_wrapper):
        """Deserialize from a protobuf node and validate the result.

        Raises:
            UnsupportedModelError: if ``_validate_input`` rejects the layer.
        """
        layer = super().from_pb(pb, pb_wrapper)
        # Slicing copies the repeated protobuf field into a plain list.
        layer.block_sizes = pb.block_sizes[:]
        layer.space_to_depth_type = pb_wrapper.SPACE_TO_DEPTH_TYPE_PB_TO_TYPE[pb.space_to_depth_type]
        cls._validate_input(layer)
        return layer

    @classmethod
    def from_layer(cls, old_layer):
        """Create a copy of ``old_layer``, with its own copy of the block sizes."""
        layer = super().from_layer(old_layer)
        # list() rather than .copy(): it also works when block_sizes is a tuple.
        layer.block_sizes = list(old_layer.block_sizes)
        layer.space_to_depth_type = old_layer.space_to_depth_type
        return layer

    @classmethod
    def from_hn(cls, hn):
        """Deserialize from an HN dict and validate the result.

        Recognised ``hn["params"]`` keys:
          * ``"block_size"``: scalar, used for both height and width
            (takes precedence over ``"block_sizes"``);
          * ``"block_sizes"``: ``[block_h, block_w]``;
          * ``"space_to_depth_type"``: looked up in ``SpaceToDepthTypes``.
            Defaults to ``SpaceToDepthType.classic_dcr``.

        If ``hn`` has no ``"params"`` key, the layer keeps whatever values
        ``super().from_hn`` produced (presumably the ``__init__`` defaults; TODO confirm).

        Raises:
            UnsupportedModelError: if ``_validate_input`` rejects the layer.
        """
        layer = super().from_hn(hn)
        if "params" in hn:
            params = hn["params"]
            if "block_size" in params:
                layer.block_sizes = [params["block_size"], params["block_size"]]
            elif "block_sizes" in params:
                # Copy so the layer does not share a list with the caller's HN dict.
                layer.block_sizes = list(params["block_sizes"])
            else:
                layer.block_sizes = list(cls._DEFAULT_BLOCK_SIZE)
            if SPACE_TO_DEPTH_TYPE in params:
                layer.space_to_depth_type = SpaceToDepthTypes[params[SPACE_TO_DEPTH_TYPE]]
            else:
                layer.space_to_depth_type = SpaceToDepthType.classic_dcr

        cls._validate_input(layer)
        return layer

    def to_hn(self, should_get_default_params=False):
        """Serialize to an HN dict.

        Writes ``"block_sizes"`` and the space-to-depth type's ``.value`` under ``"params"``.
        """
        result = copy.deepcopy(super().to_hn(should_get_default_params))
        # Copy so the returned dict does not share a list with this layer.
        result["params"]["block_sizes"] = list(self._block_sizes)
        result["params"][SPACE_TO_DEPTH_TYPE] = self._space_to_depth_type.value
        return result

    def _calc_output_shape(self):
        """Validate the input, then return the output shape.

        Returns ``[-1, out_h, out_w, out_f]``, or ``[-1, 1, out_h * out_w, out_f]``
        when ``self._spatial_flatten_output`` (defined outside this class) is set.
        The leading -1 is presumably the batch dimension.
        """
        type(self)._validate_input(self)
        block_h, block_w = self._block_sizes[0], self._block_sizes[1]
        # Divisibility was enforced by _validate_input, so floor division is exact.
        out_h = self.input_shape[1] // block_h
        out_w = self.input_shape[2] // block_w
        out_f = self.input_shape[3] * (block_h * block_w)
        return [-1, 1, out_h * out_w, out_f] if self._spatial_flatten_output else [-1, out_h, out_w, out_f]

    @classmethod
    def _validate_input(cls, layer):
        """Check that ``layer``'s type, block sizes and input shape are compatible.

        Raises:
            UnsupportedModelError: if the space-to-depth type is unknown; if the
                SERIAL type is used with blocks other than (16, 16); if a block
                size is not positive; or if the input height/width is not a
                multiple of the block height/width.
        """
        height = layer.input_shape[1]
        width = layer.input_shape[2]
        block_height = layer.block_sizes[0]
        block_width = layer.block_sizes[1]

        if layer.space_to_depth_type.value not in [s2d_type.value for s2d_type in SpaceToDepthType]:
            raise UnsupportedModelError(f"{layer.full_name_msg}: None-Existent space_to_depth_type")
        if layer.space_to_depth_type.value == SpaceToDepthType.serial.value and (block_height, block_width) != (16, 16):
            raise UnsupportedModelError(
                f"{layer.full_name_msg}: the block sizes are: ({block_height}), ({block_width}). should be (16, 16) in SERIAL type",
            )
        # Guard before the modulo checks: a zero block size would otherwise
        # surface as ZeroDivisionError instead of a model error.
        if block_height <= 0 or block_width <= 0:
            raise UnsupportedModelError(
                f"{layer.full_name_msg}: the block sizes are: ({block_height}), ({block_width}). should be positive",
            )
        if height % block_height != 0:
            raise UnsupportedModelError(
                f"{layer.full_name_msg}: Input height size ({height}) should  be multiply of {block_height}",
            )
        if width % block_width != 0:
            raise UnsupportedModelError(
                f"{layer.full_name_msg}: Input width size ({width}). should be multiply of {block_width}",
            )

    @property
    def block_sizes(self):
        """Block sizes as ``[block_h, block_w]``."""
        return self._block_sizes

    @block_sizes.setter
    def block_sizes(self, block_sizes):
        self._block_sizes = block_sizes

    @property
    def space_to_depth_type(self):
        """The ``SpaceToDepthType`` variant of this layer."""
        return self._space_to_depth_type

    @space_to_depth_type.setter
    def space_to_depth_type(self, space_to_depth_type):
        self._space_to_depth_type = space_to_depth_type

    def get_equalization_handler_type(self, predecessor=None):
        """Report this layer as unsupported for equalization; ``predecessor`` is ignored."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        """Report this layer as unsupported for params sorting; ``predecessor`` is ignored."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Report this layer as unsupported for dead-channel removal; ``predecessor`` is ignored."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def ibc_supported(self):
        """Return ``LayerSupportStatus.unsupported`` for IBC."""
        return LayerSupportStatus.unsupported
