import copy

from past.utils import old_div

from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.exceptions import UnsupportedModelError
from hailo_sdk_common.hailo_nn.hn_definitions import DepthToSpaceType, DepthToSpaceTypes, HnStage, LayerType
from hailo_sdk_common.hailo_nn.hn_layers.layer_with_params import LayerWithParams
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class DepthToSpaceLayer(LayerWithParams):
    """HN layer implementing depth-to-space (moving channel data into spatial blocks).

    For an input of shape ``[batch, h, w, c]`` and ``block_sizes == [bh, bw]`` the
    output shape is ``[-1, h * bh, w * bw, c / (bh * bw)]``; ``c`` must be divisible
    by ``bh * bw``.

    The output can optionally be cropped along height and/or width by
    ``height_slice`` / ``width_slice``. Each slice is ``[start, end, stride]`` with an
    exclusive ``end`` (the output size along that axis is ``end - start``); only
    ``stride == 1`` is accepted, and ``[0, 0, <any>]`` means "no slice".
    """

    _REQUIRES_NATIVE_WEIGHTS = False
    _REQUIRES_QUANTIZED_WEIGHTS = False
    _IS_REAL_LAYER = True

    def __init__(self):
        super().__init__()
        self._op = LayerType.depth_to_space
        # [height_block, width_block] (index 0 scales input height, index 1 input width).
        self._block_sizes = None
        self._depth_to_space_type = DepthToSpaceType.dcr
        # Shapes recorded when the op was parsed from a reshape -> transpose -> reshape
        # sequence; checked by validate_parsed_reshaped_features().
        self._first_reshape = None
        self._second_reshape = None
        # Optional output crops: [start, end, stride] or None.
        self._width_slice = None
        self._height_slice = None

    @classmethod
    def create(
        cls,
        original_name,
        input_vertex_order,
        block_size,
        output_shapes=None,
        first_reshape=None,
        second_reshape=None,
        depth_to_space_type=DepthToSpaceType.dcr,
    ):
        """Build a depth-to-space layer from parser data.

        Args:
            original_name: Name of the layer in the source model.
            input_vertex_order: Forwarded to the base ``create``.
            block_size: A scalar used for both axes, or a ``[height_block, width_block]``
                list/tuple.
            output_shapes: Optional output shapes, forwarded to the base ``create``.
            first_reshape: First reshape shape when the op was parsed from a
                reshape -> transpose -> reshape sequence.
            second_reshape: Second reshape shape of that sequence.
            depth_to_space_type: ``DepthToSpaceType`` ordering (defaults to DCR).

        Returns:
            The newly created layer.
        """
        layer = super().create(original_name, input_vertex_order, output_shapes)
        layer.depth_to_space_type = depth_to_space_type

        # Accept tuples as well as lists, and copy so the layer never aliases the
        # caller's sequence.
        if isinstance(block_size, (list, tuple)):
            layer.block_sizes = list(block_size)
        else:
            layer.block_sizes = [block_size, block_size]

        # Some D2S layers are implemented using a reshape->transpose->reshape flow that
        # needs to be validated; both shapes are required, otherwise neither is kept.
        if first_reshape and second_reshape:
            layer._first_reshape = first_reshape
            layer._second_reshape = second_reshape
        return layer

    def update_output_shapes(self, **kwargs):
        """Recompute output shapes, validating parsed reshapes at the PRE_FUSED stage.

        ``hn_stage`` is read from ``kwargs``; when it is absent the reshape validation
        is skipped. Other kwargs are not forwarded to the base implementation.
        """
        if kwargs.get("hn_stage") == HnStage.PRE_FUSED:
            self.validate_parsed_reshaped_features()
        super().update_output_shapes()

    @staticmethod
    def _is_active_slice(slice_params):
        """Return True when ``slice_params`` is set and is not the ``[0, 0, ...]`` marker."""
        return bool(slice_params) and list(slice_params[:-1]) != [0, 0]

    def _calc_output_shape(self):
        """Return ``[-1, height, width, channels]`` for the current input and params.

        Raises:
            UnsupportedModelError: if the input channels are not divisible by the block
                area, or the slice parameters are invalid.
        """
        type(self)._validate_input_features_size(self)
        self._validate_slice_params()

        if self._is_active_slice(self._width_slice):
            width = self._width_slice[1] - self._width_slice[0]
        else:
            in_width = self.input_shape[2]
            # A defused layer may override the input width it operates on.
            if "defuse_input_width" in self.defuse_params and self.defuse_input_width != 0:
                in_width = self.defuse_input_width
            width = self._block_sizes[1] * in_width

        if self._is_active_slice(self._height_slice):
            height = self._height_slice[1] - self._height_slice[0]
        else:
            height = self._block_sizes[0] * self.input_shape[1]

        channels = old_div(self.input_shape[3], (self._block_sizes[0] * self._block_sizes[1]))
        return [-1, height, width, channels]

    def to_hn(self, should_get_default_params=False):
        """Serialize to an HN dict, adding block sizes, D2S type and any set slices."""
        result = copy.deepcopy(super().to_hn(should_get_default_params))
        params = result["params"]
        # Copy so the returned dict does not alias the layer's internal list.
        params["block_sizes"] = copy.copy(self._block_sizes)
        params["depth_to_space_type"] = self._depth_to_space_type.value
        if self._width_slice:
            params["width_slice"] = [int(x) for x in self._width_slice]
        if self._height_slice:
            params["height_slice"] = [int(x) for x in self._height_slice]
        return result

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize into a protobuf node of type ``PROTO_NETWORK_DEPTH_TO_SPACE``.

        Slices are written with ``stride`` fixed to 1, the only stride the slice
        validation accepts.
        """
        node = super().to_pb(pb_wrapper, is_multi_scope)
        node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_DEPTH_TO_SPACE
        node.block_sizes.extend(self.block_sizes)
        node.depth_to_space_type = pb_wrapper.DEPTH_TO_SPACE_TYPE_TYPE_TO_PB[self._depth_to_space_type]

        if self._width_slice:
            node.width_slice.start, node.width_slice.end, node.width_slice.stride = (
                self._width_slice[0],
                self._width_slice[1],
                1,
            )

        if self._height_slice:
            node.height_slice.start, node.height_slice.end, node.height_slice.stride = (
                self._height_slice[0],
                self._height_slice[1],
                1,
            )

        return node

    @property
    def block_sizes(self):
        """``[height_block, width_block]``."""
        return self._block_sizes

    @block_sizes.setter
    def block_sizes(self, block_sizes):
        self._block_sizes = block_sizes

    @property
    def depth_to_space_type(self):
        """The ``DepthToSpaceType`` (DCR or CRD) ordering of this layer."""
        return self._depth_to_space_type

    @depth_to_space_type.setter
    def depth_to_space_type(self, depth_to_space_type):
        self._depth_to_space_type = depth_to_space_type

    @property
    def width_slice(self):
        """Optional ``[start, end, stride]`` crop along width, or None."""
        return self._width_slice

    @property
    def height_slice(self):
        """Optional ``[start, end, stride]`` crop along height, or None."""
        return self._height_slice

    def get_reshape_shapes(self):
        """Return ``(first_reshape, second_reshape)`` recorded by ``create`` (may be None)."""
        return self._first_reshape, self._second_reshape

    @classmethod
    def _validate_input_features_size(cls, layer):
        """Raise UnsupportedModelError unless input channels divide by the block area."""
        if layer.input_shape[3] % (layer.block_sizes[0] * layer.block_sizes[1]) != 0:
            raise UnsupportedModelError(
                f"Input features size ({layer.input_shape[3]}) needs to be multiply of "
                f"{(layer.block_sizes[0] * layer.block_sizes[1])} for {layer.full_name_msg}",
            )

    def validate_parsed_reshaped_features(self):
        """Check recorded reshape shapes match what a square-block D2S would produce.

        Only runs when both reshapes were recorded and ``block_sizes[0] == block_sizes[1]``.

        Raises:
            UnsupportedModelError: if the recorded shapes differ from the expected ones.
        """
        first_reshape, second_reshape = self.get_reshape_shapes()
        if not (first_reshape and second_reshape and self.block_sizes[0] == self.block_sizes[1]):
            return

        h, w, c = self.input_shape[1:]
        blocksize = self.block_sizes[0]
        out_channels = c // (blocksize**2)
        if self.depth_to_space_type == DepthToSpaceType.dcr:
            expected_first_shape = [blocksize, blocksize, out_channels, h, w]
        else:  # crd
            expected_first_shape = [out_channels, blocksize, blocksize, h, w]
        expected_second_shape = [out_channels, h * blocksize, w * blocksize]

        # [1:] drops the leading (presumably batch) dim; list() lets tuple-typed shapes
        # compare equal to the expected lists.
        if list(first_reshape[1:]) != expected_first_shape or list(second_reshape[1:]) != expected_second_shape:
            raise UnsupportedModelError(
                f"{self.full_name_msg} implemented by an unexpected reshape and transpose "
                f"combination. Expected first reshape was {expected_first_shape} and got "
                f"{first_reshape}, expected second reshape was {expected_second_shape} and "
                f"got {second_reshape}",
            )

    def _validate_slice_params(self):
        """Validate width/height slices against the un-sliced output extent.

        Raises:
            UnsupportedModelError: on malformed or out-of-range slice parameters.
        """
        total_width = self._block_sizes[1] * self.input_shape[2]
        total_height = self._block_sizes[0] * self.input_shape[1]
        self._validate_single_slice(self._width_slice, total_width, "width")
        self._validate_single_slice(self._height_slice, total_height, "height")

    @staticmethod
    def _validate_single_slice(slice_params, total_size, axis_name):
        """Validate one ``[start, end, stride]`` slice over an axis of ``total_size``.

        ``end`` is exclusive (output size is ``end - start``), so a valid active slice
        satisfies ``0 <= start < end <= total_size``.
        """
        if not slice_params:
            return

        if len(slice_params) != 3:
            raise UnsupportedModelError("Slice value must contains 3 parameters [start, end, stride]")

        # [0, 0, <any>] is the "no slice" marker.
        if list(slice_params[:-1]) == [0, 0]:
            return

        start, end, stride = slice_params
        if stride != 1:
            raise UnsupportedModelError("slice must be 1 for slice parameters")

        if any(x < 0 for x in slice_params):
            raise UnsupportedModelError("slice params must be non negative")

        # Exclusive end: end == total_size is the last valid value; start must be
        # strictly inside the axis and before end.
        if start >= total_size or end > total_size or start >= end:
            raise UnsupportedModelError(f"Invalid {axis_name} slice params - out of range")

    @staticmethod
    def _slice_from_hn(params, key):
        """Read an optional slice from HN params, normalizing ``[start, end]`` to stride 1.

        Returns None when the key is absent, empty/None, or all zeros.
        """
        slice_params = params.get(key)
        if not slice_params or not any(slice_params):
            return None
        # Copy so the layer does not alias the HN dict's list.
        slice_params = list(slice_params)
        if len(slice_params) == 2:
            slice_params.append(1)
        return slice_params

    @classmethod
    def from_hn(cls, hn):
        """Deserialize from an HN dict.

        Accepts either a scalar ``block_size`` or a ``block_sizes`` list; the D2S type
        defaults to DCR when absent.

        Raises:
            UnsupportedModelError: if input channels are not divisible by the block area.
        """
        layer = super().from_hn(hn)
        params = hn["params"]
        if "block_size" in params:
            layer.block_sizes = [params["block_size"], params["block_size"]]
        elif "block_sizes" in params:
            layer.block_sizes = params["block_sizes"]

        if "depth_to_space_type" in params:
            layer.depth_to_space_type = DepthToSpaceTypes[params["depth_to_space_type"]]
        else:
            layer.depth_to_space_type = DepthToSpaceType.dcr

        layer._width_slice = cls._slice_from_hn(params, "width_slice")
        layer._height_slice = cls._slice_from_hn(params, "height_slice")

        cls._validate_input_features_size(layer)
        return layer

    @classmethod
    def from_pb(cls, pb, pb_wrapper):
        """Deserialize from a protobuf node.

        Slices are always populated from the pb fields; an unset slice reads as
        ``[0, 0, stride]``, which the rest of the class treats as "no slice".
        """
        layer = super().from_pb(pb, pb_wrapper)
        layer.block_sizes = pb.block_sizes[:]
        layer.depth_to_space_type = pb_wrapper.DEPTH_TO_SPACE_TYPE_PB_TO_TYPE[pb.depth_to_space_type]
        layer._width_slice = [pb.width_slice.start, pb.width_slice.end, pb.width_slice.stride]
        layer._height_slice = [pb.height_slice.start, pb.height_slice.end, pb.height_slice.stride]
        cls._validate_input_features_size(layer)
        return layer

    @classmethod
    def from_layer(cls, old_layer):
        """Create a new layer carrying over the D2S parameters of ``old_layer``.

        Block sizes, D2S type and the optional width/height slices are copied; the
        slices change the output shape, so dropping them would silently alter it.
        """
        layer = super().from_layer(old_layer)
        layer.block_sizes = old_layer.block_sizes.copy()
        layer.depth_to_space_type = old_layer.depth_to_space_type
        layer._width_slice = copy.copy(getattr(old_layer, "width_slice", None))
        layer._height_slice = copy.copy(getattr(old_layer, "height_slice", None))
        # NOTE(review): parser reshape metadata (_first_reshape/_second_reshape) is not
        # carried over - confirm this is intended.
        return layer

    def get_equalization_handler_type(self, predecessor=None):
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def ibc_supported(self):
        return LayerSupportStatus.unsupported
