"""
This module contains block heuristics (ways to partition a model flow into blocks) for different algorithms.
"""

from bisect import bisect
from collections import deque
from itertools import chain
from typing import Callable, Dict, List, Tuple

import networkx as nx
import numpy as np

from hailo_model_optimization.acceleras.hailo_layers.base_acceleras_layer import BaseAccelerasLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_const import HailoConst
from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.model.hailo_model.model_flow import ModelFlow


def naive_blocks(model: HailoModel, block_end_filter: Callable[[BaseAccelerasLayer], bool]) -> List[ModelFlow]:
    """
    Naive block heuristics.

    Starting from every node without predecessors, the flow is walked breadth-first. A walk keeps following
    successors until it reaches a layer accepted by ``block_end_filter``; such a layer closes the block and its
    successors (except already-handled start layers and model output nodes) become the start of new walks.
    Assumes a single output for each block; the output blocks are the modifiable blocks.

    Args:
        model: The model to partition.
        block_end_filter: Predicate that returns True for layers that terminate a block.

    Returns:
        One finalized block per block output layer, ordered by the topological position of that output layer.
    """
    blocks_start = deque(node for node, degree in model.flow.in_degree if degree == 0)
    blocks: Dict[str, ModelFlow] = dict()
    handled_block_start_layers = set()
    model_output_nodes = model.flow.output_nodes
    # Create blocks, iterate the graph in BFS
    while blocks_start:
        first_layer = blocks_start.popleft()
        handled_block_start_layers.add(first_layer)
        nodes = set()
        new_block_layers = deque([first_layer])
        output_layers = []
        while new_block_layers:
            lname = new_block_layers.popleft()
            if lname in nodes:
                continue
            nodes.add(lname)
            if block_end_filter(model.layers[lname]):
                # lname closes the current block; its successors seed new blocks.
                for succ in model.flow.successors_sorted(lname):
                    if (succ not in handled_block_start_layers) and (succ not in model_output_nodes):
                        blocks_start.append(succ)
                output_layers.append(lname)
            else:
                new_block_layers.extend(model.flow.successors_sorted(lname))
        # Intersect blocks based on output nodes (mostly for conv&add)
        if len(nodes) == 1:
            current_block = model.flow.subgraph(nodes)
        else:
            # Only edges of non-terminal layers are kept, so edges leaving block-end layers are not included.
            edges = set()
            for node in nodes:
                if not block_end_filter(model.layers[node]):
                    edges.update(model.flow.edges(node))
            current_block = model.flow.edge_subgraph(edges)

        current_block.set_output_order(output_layers)
        # A node is a block output if some of its model successors are outside the block, if it is a sink
        # in the model, or if it is a block-end layer.
        block_output_nodes = sorted(
            node
            for node in current_block.nodes
            if (model.flow.out_degree(node) > current_block.out_degree(node))
            or model.flow.out_degree(node) == 0
            or block_end_filter(model.layers[node])
        )
        for output_node in block_output_nodes:
            sublock_nodes = set(current_block.ancestors(output_node))
            sublock_nodes.add(output_node)
            current_sublock = current_block.subgraph(sublock_nodes)
            # Merge with the block already collected for this output node (if any).
            previous_block = blocks.get(output_node, current_sublock)
            merged_block = model.flow.subgraph(current_sublock.nodes | previous_block.nodes).copy()
            merged_block.set_output_order([output_node])
            blocks[output_node] = merged_block

    # Ensure that the block flow is similar to a model flow (regarding inputs and outputs)
    blocks = {lname: _finalize_block_as_model_flow(model, block, True) for lname, block in blocks.items()}

    # Order blocks by the topological position of their output layer (dict lookup instead of list.index).
    toposort_position = {lname: index for index, lname in enumerate(model.flow.toposort())}
    return [block for _, block in sorted(blocks.items(), key=lambda item: toposort_position[item[0]])]


def get_communities(model: HailoModel, data_weight=True, resolution=2) -> List[frozenset]:
    """
    Partition the model's internal layers into modularity-based communities.

    Model input and output nodes are excluded before running networkx's greedy modularity maximization.

    Args:
        model: The model to partition.
        data_weight: If True, the edges of ``model.flow`` are annotated in place with a ``data_weight``
            attribute (see ``_add_data_weight_attribute``), which is then used as the edge weight.
        resolution: Modularity resolution passed to networkx; values above 1 favor smaller communities.

    Returns:
        A list of frozensets of layer names, one per community. Empty when the model has no internal layers.
    """
    weight_attr = _add_data_weight_attribute(model) if data_weight else None

    boundary_nodes = set(model.flow.output_nodes) | set(model.flow.input_nodes)
    sub_flow = model.flow.subgraph(model.flow.nodes - boundary_nodes)

    if len(sub_flow.nodes) == 0:
        # Nothing to partition (greedy_modularity_communities fails on an empty graph).
        return []
    if len(sub_flow.nodes) == 1:
        # greedy_modularity_communities fails for a single node; it forms its own community.
        return [frozenset(sub_flow.nodes)]
    return nx.community.greedy_modularity_communities(sub_flow, weight=weight_attr, resolution=resolution)


def block_communities(model: HailoModel, data_weight=True, resolution=2) -> List[ModelFlow]:
    """
    Modularity-based community partitioning of the HailoModel.

    Args:
        model: The model to partition.
        data_weight: If True, a ``data_weight`` (element count) attribute is added to the edges of
            ``model.flow`` and used as the edge weight.
        resolution: Modularity resolution passed to the community detection.

    Returns:
        Finalized blocks (one per community), ordered so that each block's inputs are produced by the model
        inputs or by earlier blocks (blocks may be split to achieve this).

    # TODO: add cutoff and best_n as configurable arguments
    """
    communities = get_communities(model, data_weight=data_weight, resolution=resolution)
    return _blocks_from_communities(model, communities)


def block_communities_dedicated_graph(model: HailoModel, data_weight=True) -> List[ModelFlow]:
    """
    Modularity-based community partitioning of the HailoModel on a dedicated, split-node graph.

    Each layer is split into an ``"<layer>/in"`` node and ``"<layer>/out_<i>"`` nodes (see
    ``build_graph_greedy_modularity``) so the community algorithm can cut between a layer's input and output.
    Only the differentiable part of the flow (see ``differentiable_sub_graphs``) is partitioned; model input
    nodes and non-differentiable layers (and their descendants) are not assigned to any block.

    Args:
        model: The model to partition.
        data_weight: Forwarded to ``build_graph_greedy_modularity``.
            NOTE(review): that function currently attaches weights regardless of this flag.

    Returns:
        Finalized blocks, ordered so that each block's inputs are produced by the model inputs or by earlier
        blocks (blocks may be split to achieve this).

    # TODO: add resolution, cutoff, and best_n as configurable arguments
    # TODO: consider setting higher weight for edges between real nodes (to encourage split between input / output of layer)
    """
    flow = differentiable_sub_graphs(model.layers, model.flow)[0]
    split_graph, weight_attr = build_graph_greedy_modularity(model.layers, flow, data_weight)
    split_communities = nx.community.greedy_modularity_communities(split_graph, weight=weight_attr)
    communities = []
    for split_community in split_communities:
        # A layer belongs to the community holding its "<layer>/in" node; rpartition keeps "/" in layer names.
        community = {
            layer_name
            for layer_name, _, role in (node.rpartition("/") for node in split_community)
            if role == "in"
        }
        # A community made only of "<layer>/out_<i>" nodes owns no layer; an empty block cannot be ordered.
        if community:
            communities.append(community)
    return _blocks_from_communities(model, communities)


def _blocks_from_communities(model: HailoModel, communities) -> List[ModelFlow]:
    """Finalize one block per community (a collection of layer names) and order the blocks by their inputs."""
    blocks: List[ModelFlow] = [
        _finalize_block_as_model_flow(model, model.flow.subgraph(community).copy()) for community in communities
    ]

    # Topological position of every layer (dict lookup instead of list.index).
    toposort_position = {lname: index for index, lname in enumerate(model.flow.toposort())}

    def get_block_key(block: ModelFlow):
        # Key = the latest topological position among the layers fed by the block's input nodes.
        # Blocks without input nodes get -1 (placed first) instead of failing on an empty max().
        return max(
            (
                toposort_position[real_input]
                for inp_node in block.input_nodes
                for real_input in block.successors_sorted(inp_node)
            ),
            default=-1,
        )

    keyed_blocks = sorted(((get_block_key(block), block) for block in blocks), key=lambda item: item[0])
    sorted_blocks = [block for _, block in keyed_blocks]
    sorted_blocks_keys = [key for key, _ in keyed_blocks]
    return _resolve_invalid_blocks_order(model, sorted_blocks, sorted_blocks_keys, get_block_key)


def _resolve_invalid_blocks_order(
    model: HailoModel,
    sorted_blocks: List[ModelFlow],
    sorted_blocks_keys: List[int],
    block_key_cb: Callable[[ModelFlow], int],
):
    """
    Reorder (and split where needed) blocks so every block's inputs are available before it runs.

    Walks ``sorted_blocks`` in order while tracking the available node names: the model inputs plus the
    outputs of the blocks accepted so far. When a block needs an input that is not yet available, every block
    from the current position onward that produces one of the missing inputs is removed and split in two.
    The part not reachable from its "unrelated" inputs (it contains the missing outputs) is re-inserted before
    the current position. The rest is re-inserted at or after it. Both positions are chosen by key via
    ``bisect``. The walk then restarts from the earliest re-inserted position.

    Args:
        model: The model the blocks were cut from.
        sorted_blocks: Finalized blocks sorted by ``sorted_blocks_keys``. Mutated in place.
        sorted_blocks_keys: Sort key of each block, parallel to ``sorted_blocks``. Mutated in place.
        block_key_cb: Computes the sort key of a newly created sub-block.

    Returns:
        The (same, mutated) ``sorted_blocks`` list.

    Raises:
        ValueError: If a block needs inputs that no block at or after its position produces.
    """
    existing_inputs = set(model.flow.input_nodes)
    i = 0
    while i < len(sorted_blocks):
        block = sorted_blocks[i]
        if set(block.input_nodes).issubset(existing_inputs):
            # Every input is available: accept the block and expose its outputs to the following blocks.
            existing_inputs |= set(block.output_nodes)
            i += 1
        else:
            missing_inputs = set(block.input_nodes) - existing_inputs
            # Positions where the producing halves were re-inserted; the walk restarts from the earliest.
            new_loc = set()
            # Iterate over a snapshot: sorted_blocks is popped / inserted into inside the loop.
            for remaining_block in sorted_blocks[i:]:
                required_nodes = set(remaining_block.output_nodes) & missing_inputs
                if required_nodes:
                    pop_index = sorted_blocks.index(remaining_block)
                    sorted_blocks.pop(pop_index)
                    sorted_blocks_keys.pop(pop_index)
                    # Inputs of remaining_block that have a path to one of the missing outputs.
                    new_block_inputs = set()
                    for input_node in remaining_block.input_nodes:
                        for req_node in required_nodes:
                            if nx.has_path(remaining_block, input_node, req_node):
                                new_block_inputs.add(input_node)
                    other_block_inputs = set(remaining_block.input_nodes) - new_block_inputs
                    # Everything reachable from the other inputs (plus those inputs) goes to sub_block2.
                    other_block_nodes = (
                        set().union(*[remaining_block.descendants(node) for node in other_block_inputs])
                        | other_block_inputs
                    )
                    new_node_blocks = set(remaining_block.nodes) - other_block_nodes
                    sub_block1 = remaining_block.subgraph(new_node_blocks).copy()

                    sub_block1 = _finalize_block_as_model_flow(model, sub_block1)

                    # sub_block1 holds the missing outputs, so it goes before position i.
                    block1_index = block_key_cb(sub_block1)
                    j1 = bisect(sorted_blocks_keys[:i], block1_index)
                    sorted_blocks.insert(j1, sub_block1)
                    sorted_blocks_keys.insert(j1, block1_index)

                    new_loc.add(j1)
                    if len(other_block_nodes) == 0:
                        # The whole block was only moved, there is no second half (why did it happen?).
                        continue
                    sub_block2 = remaining_block.subgraph(other_block_nodes).copy()
                    sub_block2 = _finalize_block_as_model_flow(model, sub_block2)

                    # sub_block2 goes at or after position i, by its key.
                    block2_index = block_key_cb(sub_block2)
                    j2 = bisect(sorted_blocks_keys[i:], block2_index)
                    sorted_blocks.insert(j2 + i, sub_block2)
                    sorted_blocks_keys.insert(j2 + i, block2_index)
            if not new_loc:
                # Previously this surfaced as an opaque "min() arg is an empty sequence".
                raise ValueError(
                    "Could not resolve blocks order: no remaining block produces the missing inputs "
                    f"{sorted(missing_inputs)}"
                )
            i = min(new_loc)
            # Recompute the available names from the blocks that now precede position i.
            existing_inputs = set().union(*[block.output_nodes for block in sorted_blocks[:i]])
            existing_inputs |= set(model.flow.input_nodes)
    return sorted_blocks


def _add_data_weight_attribute(model: HailoModel, attr_name="data_weight"):
    """
    Annotate every edge of ``model.flow`` (in place) with the amount of data it carries.

    The weight of edge (u, v) is the element count of u's output tensor that feeds the edge (leading
    dimension dropped, presumably the batch), divided by ``out_degree(u) / num_outputs(u)``.

    Returns:
        The name of the edge attribute that holds the weight.
    """
    edge_weights = {}
    for src, dst in model.flow.edges:
        src_layer = model.layers[src]
        resolved_index = src_layer.resolve_output_index(model.flow.get_edge_output_index(src, dst))
        element_count = np.prod(src_layer.output_shapes[resolved_index][1:])
        fan_out_per_output = model.flow.out_degree[src] / src_layer.num_outputs
        edge_weights[(src, dst)] = {attr_name: element_count / fan_out_per_output}
    nx.set_edge_attributes(model.flow, edge_weights)
    return attr_name


def _get_input_layers(model, block):
    input_layers = []
    for node in block.nodes:
        is_input = False
        if block.nodes[node].get("is_input", False):
            is_input = True
        elif not block.nodes[node].get("is_output", False):
            if model.flow.in_degree(node) > block.in_degree(node):
                is_input = True
            elif model.flow.in_degree(node) == 0:
                # Const input
                is_input = True
        if is_input:
            input_layers.append(node)
    return sorted(input_layers)


def _get_output_layers(model, block):
    output_layers = []
    for node in block.nodes:
        is_output = False
        if block.nodes[node].get("is_output", False):
            is_output = True
        elif not block.nodes[node].get("is_input", False):
            if model.flow.out_degree(node) > block.out_degree(node):
                is_output = True
            elif model.flow.out_degree(node) == 0:
                # Const output (?)
                is_output = True
        if is_output:
            output_layers.append(node)
    return sorted(output_layers)


def _finalize_block_as_model_flow(model: HailoModel, block: ModelFlow, existing_outputs=False) -> ModelFlow:
    """
    Adds input and output nodes to the block, and sets the output order of the block (aka finalize).

    Every predecessor outside the block of a block input layer is attached as an input node:
    - model input nodes keep their name (``is_input=True``);
    - const predecessors (no predecessors in the model) keep their name (``is_input=False``);
    - any other predecessor becomes an ``"<pred>_output_<index>"`` node (``is_input=True``).
    Every successor outside the block of a block output layer is attached as an ``"<layer>_output_<index>"``
    node flagged ``is_output``. The block is modified in place.

    Args:
        model (HailoModel): A (block) model that contains the block as a sub-model.
        block (ModelFlow): A sub-model in model. Mutated in place and returned.
        existing_outputs (bool, optional): If True, the output layers are taken from
            ``block.output_layer_order`` instead of being detected from the model flow. Defaults to False.

    Returns:
        block (ModelFlow): The same block with input/output nodes attached and its output order set
        (i.e., a finalized block).

    """
    # Finalize the input nodes
    input_layers = _get_input_layers(model, block)
    for inp_node in input_layers:
        # Nodes already flagged as inputs need no extra input node.
        if block.nodes[inp_node].get("is_input", False):
            continue
        for pred in model.flow.predecessors_sorted(inp_node):
            if pred in block.nodes:
                continue
            inp_ind = model.flow.get_edge_input_index(pred, inp_node)
            out_ind = model.flow.get_edge_output_index(pred, inp_node)
            data_output_index = model.layers[pred].resolve_output_index(out_ind)
            if pred in model.flow.input_nodes:
                inp_node_name = pred
                is_input = True
            elif model.flow.in_degree[pred] == 0:
                # For const input
                inp_node_name = pred
                is_input = False
            else:
                inp_node_name = f"{pred}_output_{data_output_index}"
                is_input = True
            block.add_node(inp_node_name, is_input=is_input)
            block.add_edge(inp_node_name, inp_node, input_index=inp_ind, output_index=0)

    # Finalize the output nodes
    if existing_outputs:
        output_layers = block.output_layer_order
    else:
        output_layers = _get_output_layers(model, block)
    for out_node in output_layers:
        # Nodes already flagged as outputs need no extra output node.
        if block.nodes[out_node].get("is_output", False):
            continue
        for succ in model.flow.successors_sorted(out_node):
            if succ in block.nodes:
                continue
            inp_ind = model.flow.get_edge_input_index(out_node, succ)
            out_ind = model.flow.get_edge_output_index(out_node, succ)
            data_output_index = model.layers[out_node].resolve_output_index(out_ind)
            # Outside successors consuming the same (resolved) output share one output node name.
            out_node_name = f"{out_node}_output_{data_output_index}"
            block.add_node(out_node_name, is_output=True)
            block.add_edge(out_node, out_node_name, input_index=0, output_index=out_ind)

    # Output order: for each output node detected now, its first predecessor inside the block.
    # NOTE(review): if ModelFlow is a simple (non-multi) DiGraph, a layer feeding several outside successors
    # through the same output index gets a single output edge, so its block out_degree stays below the model's
    # and the layer itself is reported by _get_output_layers here (adding its predecessor to the order).
    # Confirm that this cannot happen or is tolerated by set_output_order.
    block.set_output_order([block.predecessors_sorted(i)[0] for i in _get_output_layers(model, block)])
    return block


def single_center_cut(model: HailoModel, flow: ModelFlow) -> Tuple[ModelFlow, ModelFlow]:
    """
    Cut ``flow`` into two finalized blocks around its topological midpoint.

    The first half (by count) of the non-input/non-output nodes in topological order forms ``block1``; the
    remaining ones form ``block2``. A const layer is kept out of ``block1`` unless all of its successors are
    in ``block1``. ``block1``'s output order follows the order of ``block2``'s input nodes.

    Args:
        model: The model whose layers and flow the blocks are taken from.
        flow: The flow to cut.

    Returns:
        (block1, block2)
    """
    inp_nodes = set(flow.input_nodes)
    out_nodes = set(flow.output_nodes)
    real_nodes = set(flow.nodes) - inp_nodes - out_nodes
    half_count = len(real_nodes) / 2

    cut_nodes = set()
    for node in flow.toposort():
        if node in real_nodes:
            cut_nodes.add(node)
        if len(cut_nodes) >= half_count:
            break
    # remove const nodes from cut
    const_nodes = [node for node in flow.nodes if isinstance(model.layers[node], HailoConst)]
    for node in const_nodes:
        if not all(succ in cut_nodes for succ in flow.successors_sorted(node)):
            # discard (not remove): the const may not be in the cut at all, and remove() would raise KeyError.
            cut_nodes.discard(node)
    # cut the nodes from the graph
    block1 = flow.subgraph(cut_nodes).copy()
    block2 = flow.subgraph(real_nodes - cut_nodes).copy()
    block1 = _finalize_block_as_model_flow(model, block1)
    block2 = _finalize_block_as_model_flow(model, block2)
    # Emit block1's outputs in the order block2 consumes them.
    block2_inputs = list(block2.input_nodes)
    sorted_output = sorted(block1.output_nodes, key=block2_inputs.index)
    output_order = list(chain.from_iterable(block1.predecessors(outnode) for outnode in sorted_output))
    block1.set_output_order(output_order)
    return block1, block2


def build_graph_greedy_modularity(
    layers: Dict[str, BaseAccelerasLayer],
    flow: ModelFlow,
    data_weight=False,
    attr_name="data_weight",
):
    """
    Build a split-node ``nx.DiGraph`` of ``flow`` for community detection.

    Every layer ``L`` becomes an ``"L/in"`` node plus one ``"L/out_<i>"`` node per output.
    - ``L/in -> L/out_<i>`` edges are weighted by the element count of output ``i`` (leading dimension dropped).
    - Each flow edge (u, v) becomes ``u/out_<k> -> v/in``, where ``k`` is the edge's ``output_index`` modulo
      u's number of outputs. Its weight is that output's element count times u's out-degree.

    NOTE(review): the ``data_weight`` argument is currently unused; weights are always attached and
    ``attr_name`` is always returned. Confirm whether ``data_weight=False`` was meant to disable weights.

    Returns:
        Tuple of (the split-node graph, the name of the edge weight attribute).
    """
    split_graph = nx.DiGraph()
    for src, dst in flow.edges:
        src_layer = layers[src]
        out_idx = flow.edges[(src, dst)]["output_index"] % src_layer.num_outputs
        weight = np.prod(src_layer.output_shapes[out_idx][1:]) * flow.out_degree[src]
        split_graph.add_edge(f"{src}/out_{out_idx}", f"{dst}/in", **{attr_name: weight})
    for lname in flow.nodes:
        layer = layers[lname]
        for out_idx in range(layer.num_outputs):
            weight = np.prod(layer.output_shapes[out_idx][1:])
            split_graph.add_edge(f"{lname}/in", f"{lname}/out_{out_idx}", **{attr_name: weight})
    return split_graph, attr_name


def differentiable_sub_graphs(layers: Dict[str, BaseAccelerasLayer], flow: ModelFlow):
    """
    Split ``flow`` into a differentiable part and a blocking part.

    A layer is blocking when it is not differentiable or when it descends from a non-differentiable layer.

    Returns:
        ``[differentiable_flow, blocking_flow]``, both copies of ``flow`` subgraphs. The differentiable flow also
        excludes ``flow.input_nodes`` (to prevent community split in the input layers).
    """
    blocking_nodes = set()
    for layer in flow.toposort():
        if layer in blocking_nodes:
            # Already a descendant of an earlier non-differentiable layer, so all of its descendants are
            # blocked too; skip the redundant traversal.
            continue
        if not layers[layer].is_differentiable():
            blocking_nodes.add(layer)
            blocking_nodes.update(flow.descendants(layer))
    # Removing blocking nodes and input nodes (to prevent community split in the input layers)
    nodes = flow.nodes - blocking_nodes - set(flow.input_nodes)
    differentiable_flow = flow.subgraph(nodes).copy()
    blocking_flow = flow.subgraph(blocking_nodes).copy()
    return [differentiable_flow, blocking_flow]
