import copy

from hailo_sdk_common.hailo_nn.hailo_nn import HailoNN
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.hailo_nn.hn_layers import InputLayer, OutputLayer


def _get_input_output_data_types(layer):
    """Return the ``(input, output)`` precision modes configured on ``layer``.

    The values come from ``layer.precision_config.precision_mode``'s
    ``input_precision_mode()`` / ``output_precision_mode()`` methods. When the
    layer has no precision mode set, ``(None, None)`` is returned instead.
    """
    mode = layer.precision_config.precision_mode
    if mode is not None:
        return mode.input_precision_mode(), mode.output_precision_mode()
    return None, None


def _make_input_layer(hailo_nn, pred, node, data_type):
    """Build the boundary InputLayer that stands in for ``pred`` feeding ``node``."""
    # Select the output of ``pred`` that goes to ``node``. This assumes
    # ``pred.output_shapes`` is ordered like ``pred``'s successors.
    successor_index = list(hailo_nn.successors(pred)).index(node)
    shape = pred.output_shapes[successor_index]
    if len(shape) != 4:
        # NOTE(review): any non-rank-4 shape is treated as rank 2 and lifted to
        # [shape[0], 1, 1, shape[1]]; other ranks would lose dimensions here.
        shape = [shape[0], 1, 1, shape[1]]

    input_layer = InputLayer()
    input_layer.name = pred.name
    input_layer.index = pred.index
    input_layer.input_shapes = [shape]
    input_layer.output_shapes = [shape]
    input_layer.precision_config.precision_mode = data_type
    # Only a boundary that replaces a genuine network input is "real" IO.
    input_layer.is_real_io = pred.op == LayerType.input_layer
    return input_layer


def _make_output_layer(node, succ, position, data_type):
    """Build the boundary OutputLayer that stands in for ``succ`` consuming ``node``."""
    # ``position`` is the index of ``succ`` among ``node``'s successors and is
    # used directly as an index into ``node.output_shapes``.
    # NOTE(review): unlike the input side, no rank-4 lifting is applied here.
    shape = node.output_shapes[position]

    output_layer = OutputLayer()
    output_layer.name = succ.name
    output_layer.index = succ.index
    output_layer.input_shapes = [shape]
    output_layer.output_shapes = [shape]
    output_layer.precision_config.precision_mode = data_type
    # Only a boundary that replaces a genuine network output is "real" IO.
    output_layer.is_real_io = succ.op == LayerType.output_layer
    return output_layer


def get_subgraphs(hailo_nn):
    """Split ``hailo_nn`` into one single-layer ``HailoNN`` per non-IO layer.

    Layers are visited in ``hailo_nn.stable_toposort()`` order. Layers returned
    by ``get_all_input_layers()`` / ``get_output_layers()`` are skipped, as is
    any layer whose name was already processed.

    Each subgraph contains the following:

    * a deep copy of the layer itself;
    * one ``InputLayer`` per predecessor, named and indexed after that
      predecessor and connected to the copied layer;
    * one ``OutputLayer`` per successor, named and indexed after that
      successor and fed by the copied layer;
    * a deep copy of the original network's ``net_params.net_scopes``.

    The boundary layers take the copied layer's input / output precision modes.

    Args:
        hailo_nn: The source network. It is only read, never modified.

    Returns:
        dict mapping each processed layer name to its single-layer ``HailoNN``.
    """
    subgraphs = {}
    visited_names = set()
    io_layers = hailo_nn.get_all_input_layers() + hailo_nn.get_output_layers()

    for node in hailo_nn.stable_toposort():
        if node in io_layers or node.name in visited_names:
            continue

        subgraph = HailoNN()
        subgraph.net_params.net_scopes = copy.deepcopy(hailo_nn.net_params.net_scopes)
        preds = list(hailo_nn.predecessors(node))
        succs = list(hailo_nn.successors(node))

        # Work on a copy so the subgraph never aliases layers of ``hailo_nn``.
        clone_node = copy.deepcopy(node)
        subgraph.add_node(clone_node)

        input_data_type, output_data_type = _get_input_output_data_types(clone_node)

        for pred in preds:
            input_layer = _make_input_layer(hailo_nn, pred, node, input_data_type)
            subgraph.add_node(input_layer)
            subgraph.add_edge(input_layer, clone_node)

        for position, succ in enumerate(succs):
            output_layer = _make_output_layer(node, succ, position, output_data_type)
            subgraph.add_node(output_layer)
            subgraph.add_edge(clone_node, output_layer)

        subgraphs[node.name] = subgraph
        visited_names.add(node.name)

    return subgraphs
