from typing import Union

from hailo_model_optimization.acceleras.atomic_ops.format_conversion_op import (
    SpatialReshapeHeightFeaturesOp,
    SpatialReshapeOp,
    TransposeWidthFeaturesOp,
    format_conversion_factory,
)
from hailo_model_optimization.acceleras.atomic_ops.reshape_op import ReshapeOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_single_atomic import BaseHailoSingleAtomic
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    EquivClassification,
    FormatConversionType,
    LayerHandlerType,
    LayerType,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasValueError

# Format-conversion types that HailoFormatConversion.get_equalization_handler_type
# classifies as transparent for equalization; every other conversion type is
# reported as unsupported.
EQUALIZTION_TRANSPARENT_CONVERSION = (
    FormatConversionType.spatial_reshape,
    FormatConversionType.spatial_flatten,  # backwards compatibility
    FormatConversionType.spatial_expand,  # backwards compatibility
    FormatConversionType.merge_windowed_attention,
    FormatConversionType.split_windowed_attention,
    FormatConversionType.transpose_height_width,
)
# Correctly spelled alias of the tuple above; the misspelled name is kept so that
# existing importers keep working.
EQUALIZATION_TRANSPARENT_CONVERSION = EQUALIZTION_TRANSPARENT_CONVERSION


class HailoFormatConversion(BaseHailoSingleAtomic):
    """
    Non-arithmetic layer for changing the format of its input.

    The data movement itself is delegated to a single atomic op (the layer's core op),
    which is selected from the conversion type.

    Supported format conversions include:

    Bayer filter:
        Use when the input has 2 green pixels for every four
        [r, g]
        [g, b]   ==> [r, g, b]
        and the different permutations of this position: rggb, bggr, grbg, gbrg
    Features to width features / Flat to frames / Frames to flat:
        Reshape to the requested output ``shape`` (ReshapeOp).
    Spatial reshape:
        Reshape the spatial dimensions to ``spatial_reshape_sizes``, with input/output
        windows (SpatialReshapeOp).
    Spatial flatten / Spatial expand:
        Legacy types; when loaded through ``from_hn`` they are converted to spatial reshape.
        Spatial flatten reshapes [-1, h, w, f] -> [-1, 1, h * w, f].
    Reshape height features:
        SpatialReshapeHeightFeaturesOp, configured like spatial reshape.
    Twelve to eight bit:
        Reduce the number of bits.
    Transpose width features:
        Transpose the features axis with the width axis (``groups`` is forwarded to
        TransposeWidthFeaturesOp).

    Any other conversion type is built through ``format_conversion_factory``.

    Attributes:
        conversion: Type of format conversion this layer performs.
    """

    _hn_type = LayerType.FORMAT_CONVERSION
    OP_NAME = "format_conversion_op"

    def __init__(
        self,
        name: str,
        conversion: Union[FormatConversionType, str],
        shape=None,
        groups=None,
        input_windows=None,
        output_windows=None,
        spatial_reshape_sizes=None,
        in_emulation_graph=True,
        logger=None,
        **kwargs,
    ):
        """
        Args:
            name: Layer name; the core op is named ``f"{name}/{OP_NAME}"``.
            conversion: Conversion type, or its string value.
            shape: Target shape; required for features_to_width_features, flat_to_frames
                and frames_to_flat.
            groups: Forwarded to TransposeWidthFeaturesOp.
            input_windows: Forwarded to the spatial reshape ops.
            output_windows: Forwarded to the spatial reshape ops.
            spatial_reshape_sizes: Forwarded to the spatial reshape ops.
            in_emulation_graph: Passed to ``format_conversion_factory`` for the remaining
                conversion types and stored on the layer.
            logger: Logger passed to the core op and to the base layer.
            **kwargs: Forwarded to ``BaseHailoSingleAtomic.__init__``.

        Raises:
            AccelerasValueError: If ``shape`` is missing for a reshape-based conversion.
        """
        conversion = FormatConversionType(conversion)
        op_name = f"{name}/{self.OP_NAME}"
        if conversion in (
            FormatConversionType.features_to_width_features,
            FormatConversionType.flat_to_frames,
            FormatConversionType.frames_to_flat,
        ):
            # need the output shape (parse takes care of the shape)
            if shape is None:
                raise AccelerasValueError(f"Shape is required for conversion type {conversion.value}")
            format_conversion_op = ReshapeOp(
                op_name,
                reshape_size=shape,
                logger=logger,
            )
        elif conversion == FormatConversionType.transpose_width_features:
            format_conversion_op = TransposeWidthFeaturesOp(
                op_name,
                groups=groups,
                logger=logger,
            )
        elif conversion == FormatConversionType.spatial_reshape:
            format_conversion_op = SpatialReshapeOp(
                op_name,
                spatial_reshape_sizes=spatial_reshape_sizes,
                input_windows=input_windows,
                output_windows=output_windows,
                logger=logger,
            )
        elif conversion == FormatConversionType.reshape_height_features:
            format_conversion_op = SpatialReshapeHeightFeaturesOp(
                op_name,
                spatial_reshape_sizes=spatial_reshape_sizes,
                input_windows=input_windows,
                output_windows=output_windows,
                logger=logger,
            )
        else:
            op_class = format_conversion_factory(conversion, in_emulation_graph)
            format_conversion_op = op_class(op_name, logger=logger)
        super().__init__(name=name, core_op=format_conversion_op, logger=logger, **kwargs)
        self.conversion = conversion
        self._in_emulation_graph = in_emulation_graph

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """
        Create the layer from its HN element.

        The legacy ``spatial_expand`` / ``spatial_flatten`` conversion types and the legacy
        ``expand_spatial_sizes`` parameter are mapped to spatial reshape. When a spatial
        reshape has no explicit sizes, they are taken from dims 1:3 of the first output shape.

        Args:
            lname: Layer name.
            hn_element: HN description of the layer. ``params`` is required; ``output_shapes``
                is needed by reshape-based conversions and by spatial reshape without
                explicit sizes.
            logger: Optional logger.

        Returns:
            The constructed layer, after ``finalize_from_hn`` was applied.

        Raises:
            AccelerasValueError: If a required shape / spatial size cannot be determined.
        """
        params = hn_element["params"]
        conversion = FormatConversionType(params["conversion_type"])
        if conversion in (
            FormatConversionType.spatial_expand,
            FormatConversionType.spatial_flatten,
        ):  # for backward compatibility
            conversion = FormatConversionType.spatial_reshape

        # output_shapes may be absent; a missing shape is then reported by the explicit
        # checks below (and in __init__) as AccelerasValueError instead of a bare KeyError.
        output_shapes = hn_element.get("output_shapes")
        shape = output_shapes[0] if output_shapes else None

        groups = params.get("groups")
        in_emulation_graph = hn_element.get("in_emulation_graph", True)

        if "expand_spatial_sizes" in params:  # for backward compatibility
            spatial_reshape_sizes = params["expand_spatial_sizes"]
        else:
            spatial_reshape_sizes = params.get("spatial_reshape_sizes")
        if conversion == FormatConversionType.spatial_reshape and spatial_reshape_sizes is None:
            # hn does not contain spatial_reshape_sizes, get it from output_shapes
            if shape is None:
                raise AccelerasValueError(
                    f"either spatial_reshape_sizes or output_shapes are required for conversion type {conversion.value}"
                )
            spatial_reshape_sizes = shape[1:3]

        input_windows = params.get("input_windows", [1, 1, 1])
        output_windows = params.get("output_windows", [1, 1, 1])

        layer = cls(
            name=lname,
            conversion=conversion,
            shape=shape,
            groups=groups,
            input_windows=input_windows,
            output_windows=output_windows,
            spatial_reshape_sizes=spatial_reshape_sizes,
            in_emulation_graph=in_emulation_graph,
            logger=logger,
        )
        layer.finalize_from_hn(hn_element)
        return layer

    def get_equalization_handler_type(self, predecessor_index=None):
        """
        Classify this layer for equalization.

        Conversions listed in EQUALIZTION_TRANSPARENT_CONVERSION are transparent, all others
        are unsupported; the layer is never an equalization source. ``predecessor_index`` is
        not used.
        """
        if self.conversion in EQUALIZTION_TRANSPARENT_CONVERSION:
            return EquivClassification(LayerHandlerType.transparent, is_source=False)
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)
