from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.exceptions import UnsupportedModelError
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.hailo_nn.hn_layers.layer import Layer
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class MergedLayer(Layer):
    """Layer wrapping a mini-graph of sub-layers, built only from a protobuf node.

    The mini-graph is taken from ``pb.mini_graph`` in ``from_pb``. Each node
    that has a Python layer class is parsed into a sub-layer. The merged layer
    aggregates MACs/ops/weights over its sub-layers and forwards elementwise-add
    connections to the sub-layers whose pb node has ``ew_add`` set.
    """

    _REQUIRES_NATIVE_WEIGHTS = False
    _REQUIRES_QUANTIZED_WEIGHTS = False

    def __init__(self):
        super().__init__()
        self._op = LayerType.merged_layer
        # Replaced by the protobuf mini-graph in from_pb; empty until then.
        self._mini_graph = {}
        self._number_of_inputs_supported = 2
        self._ew_connections = []
        self._is_concat_first = False
        # Parsed sub-layers. The attribute keeps its historical spelling for
        # compatibility with any code that accesses it directly.
        self._sub_layres = []
        # pb node that produced each entry of self._sub_layres (same index).
        # Kept separately because mini-graph nodes with no layer class are
        # skipped, so mini_graph.nodes and _sub_layres are not index-aligned.
        self._sub_layer_pbs = []

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize via the base class and tag the node as a network merge."""
        node = super().to_pb(pb_wrapper, is_multi_scope)
        node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_MERGE
        return node

    @classmethod
    def from_pb(cls, pb, pb_wrapper):
        """Build a merged layer from ``pb``, parsing each mini-graph node into a sub-layer.

        Nodes whose type maps to ``None`` in ``pb_wrapper.PB_TYPE_TO_CLASS`` are
        skipped.

        Args:
            pb: protobuf node that carries a ``mini_graph`` with ``nodes``.
            pb_wrapper: wrapper exposing ``PB_TYPE_TO_CLASS`` and
                ``integrated_hw_graph_base_pb2``.

        Returns:
            The constructed ``MergedLayer``.
        """
        layer = super().from_pb(pb, pb_wrapper)
        layer._mini_graph = pb.mini_graph
        nodes = pb.mini_graph.nodes
        for layer_pb in nodes:
            layer_class = pb_wrapper.PB_TYPE_TO_CLASS[layer_pb.type]
            if layer_class is None:
                continue
            layer._sub_layres.append(layer_class.from_pb(layer_pb, pb_wrapper))
            # Record the originating node so ew_add lookups stay aligned.
            layer._sub_layer_pbs.append(layer_pb)
        concat_type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_CONCAT
        # Guard against an empty mini-graph instead of raising IndexError.
        layer._is_concat_first = len(nodes) > 0 and nodes[0].type == concat_type
        return layer

    @classmethod
    def from_hn(cls, hn):
        """Always raise: merged layers cannot be created from an HN description.

        Raises:
            UnsupportedModelError: unconditionally.
        """
        raise UnsupportedModelError("Merged layer is not allowed from hn")

    @property
    def mini_graph(self):
        """The mini-graph assigned in ``from_pb`` (an empty dict before that)."""
        return self._mini_graph

    @property
    def sub_layers(self):
        """Sub-layers parsed from the mini-graph, in node order (skipped nodes omitted)."""
        return self._sub_layres

    def _calc_output_shape(self):
        """Return the first output shape of the last node in the mini-graph."""
        output_node = self.mini_graph.nodes[-1]
        output_shapes = self._parse_shapes_from_pb(output_node.output_shapes)
        return output_shapes[0]

    def ew_add_connections(self):
        """Return the list of layers registered via ``add_ew_connection``."""
        return self._ew_connections

    @property
    def ew_add_enabled(self):
        """True when at least one elementwise-add connection is registered."""
        return len(self._ew_connections) > 0

    def add_ew_connection(self, other_layer):
        """Register ``other_layer`` and forward it to sub-layers whose pb node has ``ew_add`` set."""
        self._ew_connections.append(other_layer)
        # Pair each sub-layer with the pb node it was parsed from, not with
        # mini_graph.nodes by position (skipped nodes would misalign those).
        for layer_pb, sub_layer in zip(self._sub_layer_pbs, self._sub_layres):
            if layer_pb.ew_add:
                sub_layer.add_ew_connection(other_layer)

    def clear_ew_connections(self):
        """Drop all elementwise-add connections here and on every sub-layer."""
        self._ew_connections = []
        for layer in self._sub_layres:
            layer.clear_ew_connections()

    @property
    def is_concat_first(self):
        """True when the first mini-graph node is a concat node (set in ``from_pb``)."""
        return self._is_concat_first

    @property
    def input_features(self):
        """Size of dim 3 of the input shapes; cross-input validation is skipped when ew-add is enabled."""
        return self._get_shape_single_dim(self._input_shapes, 3, validate=(not self.ew_add_enabled))

    def get_equalization_handler_type(self, predecessor=None):
        """Merged layers are unsupported for equalization."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        """Merged layers are unsupported for params sorting."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Merged layers are unsupported for dead-channel removal."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def ibc_supported(self):
        """Merged layers do not support IBC."""
        return LayerSupportStatus.unsupported

    @property
    def macs(self):
        """Sum of ``macs`` over all sub-layers."""
        return sum(sub_layer.macs for sub_layer in self._sub_layres)

    @property
    def ops(self):
        """Sum of ``ops`` over all sub-layers."""
        return sum(sub_layer.ops for sub_layer in self._sub_layres)

    @property
    def weights(self):
        """Sum of ``weights`` over all sub-layers."""
        return sum(sub_layer.weights for sub_layer in self._sub_layres)
