import copy

from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    DEFAULT_CONCAT_AXIS,
    ConcatAxis,
    LayerHandlerType,
    LayerSupportStatus,
)
from hailo_sdk_common.hailo_nn.exceptions import UnsupportedModelError
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.hailo_nn.hn_layers.layer_with_params import LayerWithParams
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class ConcatLayer(LayerWithParams):
    """Layer that concatenates the outputs of its predecessor layers along one axis.

    The axis is a ``ConcatAxis`` value (e.g. ``features``, ``spatial_h``, ``spatial_w``).
    Shapes are handled as ``[batch, height, width, features]`` (indices 0..3), following the
    ``concat_b/h/w/f`` naming used below. Every dimension other than the concatenated one must
    match across inputs. When the inputs come from dense layers, only the batch dimension is
    compared. An optional ``group_sizes`` list is carried through (de)serialization and is
    validated in :meth:`from_hn`.
    """

    _REQUIRES_NATIVE_WEIGHTS = False
    _REQUIRES_QUANTIZED_WEIGHTS = False
    _IS_REAL_LAYER = True
    _IS_RANK3_SUPPORTED = True

    def __init__(self):
        super().__init__()
        self._op = LayerType.concat
        # Predecessor layers, in concatenation order.
        self._input_list = []
        self._number_of_inputs_supported = None
        self._axis = DEFAULT_CONCAT_AXIS

    @classmethod
    def create(cls, original_name, input_vertex_order, output_shapes=None, axis=DEFAULT_CONCAT_AXIS, group_sizes=None):
        """Build a concat layer via the base-class factory, then set its axis and group sizes."""
        layer = super().create(original_name, input_vertex_order, output_shapes)
        layer.axis = axis
        layer.group_sizes = group_sizes
        return layer

    def append_to_input_list(self, inp):
        """Append ``inp`` as the next input (predecessor) of the concatenation."""
        self._input_list.append(inp)

    @property
    def input_list(self):
        """Predecessor layers feeding this concat, in concatenation order."""
        return self._input_list

    @input_list.setter
    def input_list(self, input_list):
        self._input_list = input_list

    @property
    def input_height(self):
        """Input dim #1, or 1 when the inputs come from dense layers."""
        if self.is_from_dense():
            return 1
        return self._get_shape_single_dim(self._input_shapes, 1)

    @property
    def input_width(self):
        """Input dim #2, or 1 when the inputs come from dense layers."""
        if self.is_from_dense():
            return 1
        return self._get_shape_single_dim(self._input_shapes, 2)

    @property
    def unsafe_output_shape(self):
        """Stored output shape, returned without running the input-dimension validation."""
        return self._get_output_shape(validate=False)

    @property
    def axis(self):
        """The ``ConcatAxis`` along which the inputs are concatenated."""
        return self._axis

    @axis.setter
    def axis(self, axis):
        self._axis = axis

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize to a protobuf node: type, output features, concat axis and any group sizes."""
        node = super().to_pb(pb_wrapper, is_multi_scope)
        node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_CONCAT
        node.kernel_shape.features = self.output_shape[3]
        node.concat_axis = pb_wrapper.CONCAT_AXIS_TYPE_TO_PB[self._axis]
        if self._group_sizes:
            node.group_sizes.extend(self._group_sizes)
        return node

    def _report_spatial_dims_not_equal(self, validate):
        """Raise an error that lists the output shape of every input.

        Raises:
            UnsupportedModelError: always.
        """
        # NOTE(review): the wording assumes a feature-axis concat; it is also raised for
        # spatial-axis concats whose non-concatenated dims disagree.
        output_shapes = "".join(
            f"{in_item.name}: output_shape={self.pred_layer_output_shape(in_item, validate)}\n"
            for in_item in self._input_list
        )
        raise UnsupportedModelError(
            f"Unsupported dimensions at {self.full_name_msg} with output shapes\n"
            f"{output_shapes}. The concat layer operates on the feature dim (#3) and all other "
            f"spatial dimensions (#0, #1, #2) must all be equal.",
        )

    def _validate_concat_spatial_dims(self, validate=False):
        """Check that all inputs match the first input on every non-concatenated dimension.

        Dim #0 is always compared. Dims #1, #2 and #3 are compared unless they are the
        ``spatial_h``, ``spatial_w`` or ``features`` concat axis, respectively.

        Raises:
            UnsupportedModelError: if any input disagrees with the first one.
        """
        first_input_shape = self.pred_layer_output_shape(self._input_list[0], validate)
        for in_item in self._input_list:
            in_item_output_shape = self.pred_layer_output_shape(in_item, validate)
            if (
                (first_input_shape[0] != in_item_output_shape[0])
                or (self.axis != ConcatAxis.spatial_h and first_input_shape[1] != in_item_output_shape[1])
                or (self.axis != ConcatAxis.spatial_w and first_input_shape[2] != in_item_output_shape[2])
                or (self.axis != ConcatAxis.features and first_input_shape[3] != in_item_output_shape[3])
            ):
                self._report_spatial_dims_not_equal(validate)  # always raises

    def _validate_concat_from_dense_spatial_dims(self, validate=False):
        """Check that inputs coming from dense layers all agree on dim #0.

        Raises:
            UnsupportedModelError: if any input's dim #0 differs from the first input's.
        """
        first_input_shape = self.pred_layer_output_shape(self._input_list[0], validate)
        for in_item in self._input_list:
            in_item_output_shape = self.pred_layer_output_shape(in_item, validate)
            if first_input_shape[0] != in_item_output_shape[0]:
                self._report_spatial_dims_not_equal(validate)  # always raises

    def _calc_output_shape(self):
        """Compute the concat output shape ``[b, h, w, f]`` according to ``self._axis``."""
        if self._axis == ConcatAxis.spatial_w:
            return self._calc_spatial_w_concat()
        if self._axis == ConcatAxis.spatial_h:
            return self._calc_spatial_h_concat()

        # Feature concat: b/h/w are shared by all inputs; features are summed.
        if self.is_from_dense(validate=False):
            concat_b, concat_h, concat_w = self._calc_output_shape_dense_inputs()
        else:
            concat_b, concat_h, concat_w = self._calc_output_shape_convlike_inputs()
        concat_f = sum(shape[-1] for shape in self.input_shapes)
        return [concat_b, concat_h, concat_w, concat_f]

    def _calc_spatial_w_concat(self):
        """Output shape for a width concat: the first input's b/h/f with the widths (dim #2) summed."""
        first_input_output_shape = self.pred_layer_output_shape(self._input_list[0], True)
        concat_b = first_input_output_shape[0]
        concat_h = first_input_output_shape[1]
        concat_f = first_input_output_shape[3]
        concat_w = sum(self.pred_layer_output_shape(in_item)[2] for in_item in self._input_list)
        return [concat_b, concat_h, concat_w, concat_f]

    def _calc_spatial_h_concat(self):
        """Output shape for a height concat: the first input's shape with the heights (dim #1) summed."""
        first_input_output_shape = self.pred_layer_output_shape(self._input_list[0], True)
        # list() copies the shape and also accepts a tuple (a [:] copy would stay immutable).
        output_shape = list(first_input_output_shape)
        output_shape[1] = sum(self.pred_layer_output_shape(in_item)[1] for in_item in self._input_list)
        return output_shape

    def _calc_output_shape_convlike_inputs(self):
        """Validate the non-dense inputs and return the shared ``(b, h, w)`` from the first input."""
        self._validate_concat_spatial_dims(validate=True)
        first_input_output_shape = self.pred_layer_output_shape(self._input_list[0])
        concat_b = first_input_output_shape[0]
        concat_h = first_input_output_shape[1]
        concat_w = first_input_output_shape[2]
        return concat_b, concat_h, concat_w

    def _calc_output_shape_dense_inputs(self):
        """Validate the dense inputs and return ``(-1, 1, 1)`` for ``(b, h, w)``."""
        self._validate_concat_from_dense_spatial_dims()
        concat_b = -1
        concat_h = 1
        concat_w = 1
        return concat_b, concat_h, concat_w

    @classmethod
    def from_pb(cls, pb, pb_wrapper):
        """Deserialize a concat layer from a protobuf node (axis and optional group sizes)."""
        layer = super().from_pb(pb, pb_wrapper)
        layer._axis = pb_wrapper.CONCAT_AXIS_PB_TO_TYPE[pb.concat_axis]
        if pb.group_sizes:
            # Copy out of the protobuf repeated container so the layer holds a plain list.
            # to_hn() puts group_sizes straight into the hn dict.
            layer.group_sizes = list(pb.group_sizes)
        return layer

    @classmethod
    def from_layer(cls, old_layer):
        """Build a concat layer from ``old_layer``, copying axis/group sizes where the source has them."""
        layer = super().from_layer(old_layer)
        if old_layer.op == LayerType.matmul:
            layer.dynamic_weights = False
        elif old_layer.op not in [LayerType.rnn, LayerType.lstm]:
            # rnn/lstm sources are skipped here; presumably they have no concat axis/group sizes to copy.
            layer.axis = old_layer.axis
            layer.group_sizes = old_layer.group_sizes
        return layer

    def _get_output_shape(self, validate=True, layer_name=None, layer_index=None):
        """Return the first stored output shape, or None if none is stored.

        When ``validate`` is true, first check input dimensions. For dense inputs only
        dim #0 is checked; otherwise every non-concatenated dim is checked.
        ``layer_name`` and ``layer_index`` are accepted for interface compatibility and unused here.
        """
        if len(self._output_shapes) == 0:
            return None

        if validate:
            if self.is_from_dense(validate=False):
                self._validate_concat_from_dense_spatial_dims()
            else:
                self._validate_concat_spatial_dims()
        return self._output_shapes[0]

    def _serialize_shapes_pb(self, shapes_to_serialize, shapes_pb):
        """Append each shape to ``shapes_pb``.

        Rank-2 shapes ``[_, features]`` are stored with height = width = 1. Other shapes are
        unpacked as ``[_, height, width, features]``.
        """
        for shape in shapes_to_serialize:
            shape_pb = shapes_pb.add()
            if len(shape) == 2:
                shape_pb.width = 1
                shape_pb.height = 1
                _, shape_pb.features = shape
            else:
                _, shape_pb.height, shape_pb.width, shape_pb.features = shape

    @classmethod
    def from_hn(cls, hn):
        """Deserialize a concat layer from its hn dict.

        Reads the legacy ``spatial_w_concat`` flag, ``concat_axis`` and ``group_sizes`` from ``hn["params"]``.

        Raises:
            UnsupportedModelError: in any of these cases:
                * the legacy flag conflicts with ``concat_axis``;
                * a ``spatial_w`` concat has more than one group;
                * the grouped input dimension is not divisible by the sum of the group sizes.
        """
        layer = super().from_hn(hn, validate_params_exist=False)

        if "params" in hn:
            params = hn["params"]
            legacy_spatial_w = params.get("spatial_w_concat") is True
            if legacy_spatial_w:
                layer._axis = ConcatAxis.spatial_w
            if "concat_axis" in params:
                hn_axis = ConcatAxis(params["concat_axis"])
                if legacy_spatial_w and hn_axis != ConcatAxis.spatial_w:
                    raise UnsupportedModelError(
                        f"At {layer.full_name_msg} spatial_w_concat is True but concat axis is {hn_axis.value}",
                    )
                layer.axis = hn_axis
            layer.group_sizes = params.get("group_sizes")
            if layer.axis == ConcatAxis.spatial_w and layer.group_sizes is not None and len(layer.group_sizes) > 1:
                raise UnsupportedModelError(
                    f"At {layer.full_name_msg} spatial_w_concat is True but group_sizes > 1, which isn't supported",
                )
            # An empty group_sizes list means "no groups" (as in to_pb/to_hn). Validating it would
            # divide by zero.
            if layer.group_sizes:
                group_sizes_sum = sum(layer.group_sizes)
                for input_shape in layer.input_shapes:
                    # Check the axis first, so only the dimension that is actually grouped is indexed.
                    # Features are the last dim, matching _calc_output_shape.
                    if layer.axis == ConcatAxis.features and input_shape[-1] % group_sizes_sum != 0:
                        raise UnsupportedModelError(
                            f"At {layer.full_name_msg} input features can not be divided into "
                            f"the given groups sizes {layer.group_sizes}",
                        )
                    if layer.axis == ConcatAxis.spatial_h and input_shape[1] % group_sizes_sum != 0:
                        raise UnsupportedModelError(
                            f"At {layer.full_name_msg} input height can not be divided into "
                            f"the given groups sizes {layer.group_sizes}",
                        )
        return layer

    def to_hn(self, should_get_default_params=False):
        """Serialize to an hn dict, adding ``concat_axis`` and any non-empty ``group_sizes`` to its params."""
        result = copy.deepcopy(super().to_hn(should_get_default_params))
        result["params"]["concat_axis"] = self._axis.value
        if self._group_sizes:
            # Copy so the returned dict does not alias the layer's own list.
            result["params"]["group_sizes"] = list(self._group_sizes)
        return result

    def _features_concat_equiv_classification(self):
        """Non-source classification: ``cc_aggregator`` for a feature concat, else ``unsupported``."""
        handler_type = (
            LayerHandlerType.cc_aggregator if self.axis == ConcatAxis.features else LayerHandlerType.unsupported
        )
        return EquivClassification(handler_type, is_source=False)

    def get_equalization_handler_type(self, predecessor=None):
        return self._features_concat_equiv_classification()

    def get_params_sorter_handler_type(self, predecessor=None):
        return self._features_concat_equiv_classification()

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        return self._features_concat_equiv_classification()

    def ibc_supported(self):
        return LayerSupportStatus.unsupported

    def is_zippable(self, other):
        """Allow zipping two concat layers only on the feature axis and only when neither has group sizes."""
        if not (self.axis == other.axis == ConcatAxis.features):
            return False
        if self.group_sizes is not None or other.group_sizes is not None:
            return False
        return super().is_zippable(other)
