from typing import Union

from hailo_model_optimization.acceleras.atomic_ops.space_to_depth_op import SpaceToDepthOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_single_atomic import BaseHailoSingleAtomic
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    BiasMode,
    EquivClassification,
    LayerHandlerType,
    LayerType,
    PrecisionMode,
    SpaceToDepthType,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import (
    AccelerasImplementationError,
    AccelerasInitializationError,
    InvalidInputShape,
)


class HailoSpaceToDepth(BaseHailoSingleAtomic):
    """Represents `space_to_depth` layer in the hn.

    The layer wraps a single `SpaceToDepthOp` atomic op, which is configured
    with the block sizes, the space-to-depth variant and the
    `spatial_flatten_output` flag.
    """

    SUPPORTED_PRECISION_MODE = {
        PrecisionMode.a8_w8,
        PrecisionMode.a16_w16,
        PrecisionMode.a8_w8_a8,
        PrecisionMode.a16_w16_a16,
    }
    SUPPORTED_BIAS_MODE = {
        BiasMode.single_scale_decomposition,
    }
    SUPPORTED_QUANTIZATION_GROUPS = False
    _hn_type = LayerType.SPACE_TO_DEPTH
    # Suffix appended to the layer name to build the atomic op's name.
    OP_NAME = "space_to_depth_op"

    def __init__(
        self,
        name: str,
        block_sizes: tuple = (2, 2),
        space_to_depth_type: Union[str, SpaceToDepthType] = "classic_dcr",
        spatial_flatten_output: bool = False,
        logger=None,
        **kwargs,
    ):
        """Build the layer around a freshly created `SpaceToDepthOp`.

        Args:
            name: Layer name; the op is named ``"{name}/{OP_NAME}"``.
            block_sizes: Two-element (height, width) block sizes forwarded to the op.
            space_to_depth_type: Variant of the operation, either as a
                `SpaceToDepthType` member or as its value; it is converted to
                `SpaceToDepthType` before being handed to the op.
            spatial_flatten_output: Forwarded to the op unchanged.
            logger: Logger forwarded to both the op and the base layer.
            **kwargs: Extra keyword arguments forwarded to the base layer.
        """
        space_to_depth_op = SpaceToDepthOp(
            f"{name}/{self.OP_NAME}",
            block_sizes=block_sizes,
            space_to_depth_type=SpaceToDepthType(space_to_depth_type),
            spatial_flatten_output=spatial_flatten_output,
            logger=logger,
        )
        super().__init__(name=name, core_op=space_to_depth_op, logger=logger, **kwargs)

        # NOTE(review): semantics of this flag are defined outside this file; it
        # is set after the base constructor so it overrides whatever that sets.
        self.encoding_const = False

    @classmethod
    def _validate_input(cls, block_sizes, input_shape, space_to_depth_type, layer_name):
        """Validate the space-to-depth configuration against the input shape.

        Args:
            block_sizes: Two-element sequence of (block_height, block_width).
            input_shape: Input shape; index 1 is read as height and index 2 as width.
            space_to_depth_type: `SpaceToDepthType` member or its value.
            layer_name: Layer name, used as a prefix in error messages.

        Raises:
            AccelerasInitializationError: If `space_to_depth_type` is not a valid
                `SpaceToDepthType`, if the type is SERIAL and the block sizes are
                not (16, 16), or if the input height/width is not a multiple of
                the corresponding block size.
        """
        # NOTE(review): the original comment states only block_sizes=(2,2) or
        # (16,16) are supported, but only the SERIAL/(16,16) pairing is enforced here.
        height = input_shape[1]
        width = input_shape[2]
        block_height = block_sizes[0]
        block_width = block_sizes[1]

        # Enum lookup by value raises ValueError for an unknown value, so a
        # membership test on the converted value can never fail. Translate the
        # ValueError into the layer-specific initialization error instead.
        try:
            resolved_type = SpaceToDepthType(space_to_depth_type)
        except ValueError as err:
            raise AccelerasInitializationError(f"{layer_name}: None-Existent space_to_depth_type") from err

        if resolved_type == SpaceToDepthType.SERIAL and tuple(block_sizes) != (16, 16):
            raise AccelerasInitializationError(
                f"{layer_name}: the block sizes are: ({block_height}), ({block_width}). should be (16, 16) in SERIAL type",
            )
        if height % block_height != 0:
            raise AccelerasInitializationError(
                f"{layer_name}: Input height size ({height}) should  be multiply of {block_height}",
            )
        if width % block_width != 0:
            raise AccelerasInitializationError(
                f"{layer_name}: Input width size ({width}) should be multiply of {block_width}",
            )

    @property
    def is_precision_transparent(self) -> bool:
        """Always True for this layer."""
        return True

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """Create the layer from its hn description.

        Args:
            lname: Name to give the layer.
            hn_element: Mapping describing the layer. Reads the optional
                ``"params"`` mapping and the first entry of ``"input_shapes"``.
            logger: Logger forwarded to the constructor.

        Returns:
            The constructed layer, after `finalize_from_hn(hn_element)` was called on it.

        Raises:
            AccelerasImplementationError: If params contain neither
                ``"block_size"`` nor ``"block_sizes"``.
            AccelerasInitializationError: If the configuration fails `_validate_input`.
        """
        params = hn_element.get("params", dict())
        if "block_size" in params:
            # A single scalar block size applies to both spatial dimensions.
            block_sizes = [params["block_size"], params["block_size"]]
        elif "block_sizes" in params:
            block_sizes = params["block_sizes"]
        else:
            raise AccelerasImplementationError("Missing block size field")

        space_to_depth_type = params.get("space_to_depth_type", "classic_dcr")
        input_shape = list(hn_element["input_shapes"])[0]
        spatial_flatten_output = params.get("spatial_flatten_output", False)
        # Validate before constructing so a bad configuration fails early.
        cls._validate_input(block_sizes, input_shape, space_to_depth_type, lname)
        layer = cls(
            name=lname,
            block_sizes=block_sizes,
            space_to_depth_type=space_to_depth_type,
            spatial_flatten_output=spatial_flatten_output,
            logger=logger,
        )
        layer.finalize_from_hn(hn_element)
        return layer

    def get_equalization_handler_type(self, predecessor_index=None):
        """Return an `unsupported`, non-source equalization classification.

        `predecessor_index` is accepted but not used.
        """
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def verify_layer_inputs_shape(self, input_shapes):
        """Check that one block fits inside the first input's spatial area.

        Args:
            input_shapes: Sequence of input shapes; only the first is inspected,
                using indices 1 and 2 as the two spatial dimensions.

        Raises:
            InvalidInputShape: If the product of the two block sizes exceeds the
                product of the two spatial dimensions.
        """
        block_size = self.atomic_op.block_sizes
        if block_size[0] * block_size[1] > input_shapes[0][1] * input_shapes[0][2]:
            raise InvalidInputShape(
                f"Input shapes {input_shapes} doesn't match block size {block_size} in {self.full_name}",
                self.full_name,
            )