from hailo_model_optimization.acceleras.atomic_ops.concat_op import ConcatOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_single_atomic import BaseHailoSingleAtomic
from hailo_model_optimization.acceleras.hailo_layers.layer_flow import LayerFlow
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    DEFAULT_CONCAT_AXIS,
    BiasMode,
    ConcatAxis,
    DataPath,
    EquivClassification,
    LayerHandlerType,
    LayerType,
    PrecisionMode,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasValueError, InvalidInputShape


class HailoConcat(BaseHailoSingleAtomic):
    """
    Represents `concat` layer in the hn

    The layer wraps a single ``ConcatOp`` atomic op that joins ``num_inputs`` inputs
    along ``axis``.

    Args:
        name: layer name
        num_inputs: number of inputs
        logger: the logger for the class
        axis: defaults to DEFAULT_CONCAT_AXIS - an enum describing the axis on which the layer concatenates.
        group_sizes: optional group sizes forwarded to the underlying ConcatOp; written to the hn
            params only when not None.
        **kwargs: forwarded unchanged to ``BaseHailoSingleAtomic.__init__``.

    """

    SUPPORTED_PRECISION_MODE = {
        PrecisionMode.a8_w8,
        PrecisionMode.a16_w16,
        PrecisionMode.a8_w8_a8,
        PrecisionMode.a8_w8_a16,
        PrecisionMode.a16_w16_a16,
    }
    SUPPORTED_BIAS_MODE = {
        BiasMode.single_scale_decomposition,
    }
    _hn_type = LayerType.CONCAT
    OP_NAME = "concat_op"

    def __init__(
        self,
        name: str,
        num_inputs: int,
        logger=None,
        axis: ConcatAxis = DEFAULT_CONCAT_AXIS,
        group_sizes=None,
        **kwargs,
    ):
        atomic_concat = ConcatOp(
            f"{name}/{self.OP_NAME}",
            concat_elements=num_inputs,
            logger=logger,
            axis=axis,
            group_sizes=group_sizes,
        )
        super().__init__(name=name, core_op=atomic_concat, logger=logger, **kwargs)

        self.encoding_const = False

    @property
    def spatial_concat(self):
        """True when the atomic op concatenates along a spatial (height or width) axis."""
        return self.atomic_op.spatial_h_concat or self.atomic_op.spatial_w_concat

    @property
    def is_precision_transparent(self) -> bool:
        """Concat always reports itself as precision transparent."""
        return True

    def _build_flow(self) -> LayerFlow:
        """
        Build the layer graph: one input node per concat input, all feeding the atomic
        concat op (with their positional ``input_index``), which feeds a single output node.
        """
        layer_flow = LayerFlow()
        in_nodes = [layer_flow.add_input() for _ in range(self.atomic_op.num_inputs)]
        out1 = layer_flow.add_output()

        layer_flow.add_node(self.atomic_op)

        for index, node in enumerate(in_nodes):
            layer_flow.add_edge(node, self.atomic_op, DataPath.LAYER_IN, input_index=index)
        layer_flow.add_edge(self.atomic_op, out1, DataPath.LAYER_OUT)
        return layer_flow

    def to_hn(self, out_degree=None):
        """
        Serialize the layer to its hn element.

        Writes ``concat_axis`` (and ``group_sizes`` when set) into the element's ``params``
        dict, creating it if missing, then delegates to the base implementation.
        """
        # setdefault returns the existing params dict (mutated in place) or inserts a new one.
        params = self._hn_element.setdefault("params", {})
        params["concat_axis"] = self.atomic_op._axis.value
        if self.atomic_op._group_sizes is not None:
            params["group_sizes"] = self.atomic_op._group_sizes
        return super().to_hn(out_degree=out_degree)

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """
        Build a ``HailoConcat`` from an hn element.

        Args:
            lname: layer name.
            hn_element: hn dict; ``input_shapes`` determines the number of inputs and the
                optional ``params`` may hold ``concat_axis``, ``group_sizes`` and
                ``spatial_w_concat``.
            logger: optional logger.

        Raises:
            AccelerasValueError: if more than one group is requested on an axis other than
                ``features`` or ``spatial_h``.
            ValueError: if ``spatial_w_concat`` is set but the axis is not ``spatial_w``.
        """
        params = hn_element.get("params", dict())
        spatial_w_concat = params.get("spatial_w_concat", False)
        axis = ConcatAxis(params.get("concat_axis", "features"))
        group_sizes = params.get("group_sizes")
        # Multi-group concatenation is only allowed along the features or spatial_h axes.
        if axis.value not in ("features", "spatial_h") and group_sizes is not None and len(group_sizes) > 1:
            raise AccelerasValueError(
                'concat layer does not support (asymmetric) group concatenation with axis not in ("features", "spatial_h"); '
                + f"got axis={axis.value} for concat op {lname}.",
            )
        if spatial_w_concat and axis != ConcatAxis.spatial_w:
            raise ValueError(f"In {lname} spatial_w_concat is True but concat axis is {axis.value}")
        layer = cls(
            name=lname,
            num_inputs=len(hn_element["input_shapes"]),
            logger=logger,
            axis=axis,
            group_sizes=group_sizes,
        )
        layer.finalize_from_hn(hn_element)
        return layer

    def get_equalization_handler_type(self, predecessor_index=None):
        """
        Classify this layer for equalization.

        Spatial concatenation is reported as unsupported; any other concat acts as a
        ``cc_aggregator``. The layer is never an equalization source. ``predecessor_index``
        is accepted for interface compatibility and is not used.
        """
        if self.spatial_concat:
            return EquivClassification(LayerHandlerType.unsupported, is_source=False)
        return EquivClassification(LayerHandlerType.cc_aggregator, is_source=False)

    def verify_layer_inputs_shape(self, input_shapes):
        """
        Check that all inputs agree on every dimension except the concat axis.

        Args:
            input_shapes: sequence of input shapes, one per concat input.

        Raises:
            InvalidInputShape: if no shapes are given, the inputs differ in rank, or any
                dimension other than the concat axis differs between inputs.
        """
        if len(input_shapes) == 0:
            raise InvalidInputShape(
                f"No input shapes given for concat layer {self.full_name}",
                self.full_name,
            )
        axis = self.atomic_op._get_concat_axis()
        reference = input_shapes[0]
        rank = len(reference)
        for in_shape in input_shapes[1:]:
            # A rank mismatch previously surfaced as IndexError (or was silently ignored
            # when the other input had more dims); report it as a shape error instead.
            rank_mismatch = len(in_shape) != rank
            dim_mismatch = any(
                ref_dim != dim for idx, (ref_dim, dim) in enumerate(zip(reference, in_shape)) if idx != axis
            )
            if rank_mismatch or dim_mismatch:
                raise InvalidInputShape(
                    f"Input shapes {input_shapes} doesn't match concat axis: {axis} in {self.full_name}",
                    self.full_name,
                )

    def _validate_zero_points(self, zero_points, params, key):
        """No-op: zero-point validation is skipped for concat (presumably overrides a base-class hook)."""
        pass
