from typing import Union

import numpy as np

from hailo_model_optimization.acceleras.atomic_ops.depth_to_space_op import DepthToSpaceOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_single_atomic import BaseHailoSingleAtomic
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    BiasMode,
    DepthToSpaceType,
    EquivClassification,
    LayerHandlerType,
    LayerType,
    PrecisionMode,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasInitializationError


class HailoDepthToSpace(BaseHailoSingleAtomic):
    """Represents `depth_to_space` layer in the hn.

    The layer is a single-atomic layer whose only atomic op is a
    :class:`DepthToSpaceOp` named ``f"{layer_name}/{OP_NAME}"``.
    """

    SUPPORTED_PRECISION_MODE = {
        PrecisionMode.a8_w8,
        PrecisionMode.a16_w16,
        PrecisionMode.a8_w8_a8,
        PrecisionMode.a8_w8_a16,
        PrecisionMode.a16_w16_a16,
    }
    SUPPORTED_BIAS_MODE = {
        BiasMode.single_scale_decomposition,
    }
    SUPPORTED_QUANTIZATION_GROUPS = False
    _hn_type = LayerType.DEPTH_TO_SPACE
    OP_NAME = "depth_to_space_op"

    def __init__(
        self,
        name: str,
        mode: Union[DepthToSpaceType, str],
        block_sizes: tuple = (2, 2),
        logger=None,
        **kwargs,
    ):
        """Create the layer together with its ``DepthToSpaceOp``.

        Args:
            name: Layer name; the atomic op is named ``f"{name}/{OP_NAME}"``.
            mode: Depth-to-space variant, either a ``DepthToSpaceType`` member or
                a value accepted by ``DepthToSpaceType(...)``.
            block_sizes: Block sizes forwarded unchanged to the op.
            logger: Optional logger, forwarded to the op and the base class.
            **kwargs: Forwarded to ``BaseHailoSingleAtomic.__init__``.
        """
        depth_to_space_op = DepthToSpaceOp(
            name=f"{name}/{self.OP_NAME}",
            mode=DepthToSpaceType(mode),
            block_sizes=block_sizes,
            logger=logger,
        )
        super().__init__(name=name, core_op=depth_to_space_op, logger=logger, **kwargs)

        # NOTE(review): the meaning of `encoding_const` is defined by the base
        # class / encoding flow, not here — this layer just disables it.
        self.encoding_const = False

    @classmethod
    def _validate_input_features_size(cls, block_sizes, input_shape, layer_name):
        """Validate that ``block_sizes`` are compatible with ``input_shape``.

        Args:
            block_sizes: Sequence of block sizes; entries 0 and 1 are checked
                individually and the product of all entries must divide the
                number of input features.
            input_shape: Input shape; index 3 is read as the number of input
                features (channels).
            layer_name: Layer name, used only in error messages.

        Raises:
            AccelerasInitializationError: If a block size is smaller than 1, if
                neither of the two block sizes is at least 2, or if the number of
                input features is not a multiple of the block-size product.
        """
        first_size, second_size = block_sizes[0], block_sizes[1]
        # Reject non-positive sizes explicitly: with a zero size the product is 0
        # and numpy's integer ``x % 0`` returns 0 (only emitting a RuntimeWarning),
        # which would let the divisibility check below pass silently.
        # Also require at least one size to be >= 2; (1, 1) would leave the
        # tensor layout unchanged.
        if min(first_size, second_size) < 1 or max(first_size, second_size) < 2:
            raise AccelerasInitializationError(
                f"Block sizes must be >= 1 and at least one of them must be >= 2, "
                f"got {list(block_sizes)} for layer {layer_name}",
            )
        blocks_size_mult = int(np.prod(block_sizes))
        input_features = input_shape[3]
        # The depth of the input tensor must be divisible by the product of the block sizes.
        if input_features % blocks_size_mult != 0:
            raise AccelerasInitializationError(
                f"Input features size ({input_features}) needs to be a multiple of {blocks_size_mult} "
                f"for layer {layer_name}",
            )

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """Build the layer from its hn description.

        Block sizes are read from ``params["block_size"]`` (used for both
        axes) or, if that key is absent, from ``params["block_sizes"]``. The
        mode comes from ``params["depth_to_space_type"]``. The first entry of
        ``hn_element["input_shapes"]`` is validated against the block sizes
        before the layer is constructed, and the new layer is then finalized
        via ``finalize_from_hn``.

        Args:
            lname: Layer name.
            hn_element: hn dictionary describing this layer.
            logger: Optional logger forwarded to the layer.

        Returns:
            The constructed and finalized layer.

        Raises:
            AccelerasInitializationError: If neither ``block_size`` nor
                ``block_sizes`` is present, or if validation fails.
        """
        params = hn_element.get("params", {})
        if "block_size" in params:
            block_sizes = [params["block_size"], params["block_size"]]
        elif "block_sizes" in params:
            block_sizes = params["block_sizes"]
        else:
            raise AccelerasInitializationError(f"block_size or block_sizes must be in hn of {lname}")
        input_shape = list(hn_element["input_shapes"])[0]
        cls._validate_input_features_size(block_sizes, input_shape, lname)
        layer = cls(
            name=lname,
            mode=params["depth_to_space_type"],
            block_sizes=block_sizes,
            logger=logger,
        )
        layer.finalize_from_hn(hn_element)
        return layer

    def get_equalization_handler_type(self, predecessor_index=None):
        """Report that equalization is unsupported for this layer (and it is not a source)."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    @property
    def is_precision_transparent(self) -> bool:
        """Always ``True`` for this layer.

        Presumably because depth-to-space only relocates elements and does not
        change their values.
        """
        return True

    @property
    def atomic_ops(self):
        """List containing the layer's single atomic op."""
        return [self.atomic_op]
