from typing import TYPE_CHECKING, Any, Generator, Optional, Set

import networkx as nx
import numpy as np
import tensorflow as tf
from pydantic.v1 import BaseModel, Field

if TYPE_CHECKING:
    from hailo_model_optimization.acceleras.model.hailo_model import HailoModel

from contextlib import contextmanager
from functools import wraps

from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    DistributionStrategy,
    GPUAvailabilityMode,
    GPUInfo,
)


class DistContextInfo(BaseModel):
    """State describing how work is distributed across GPUs.

    Produced by ``get_strategy`` and consumed by ``gpu_distributed_context``.
    """

    # GPU description (availability mode and GPU count).
    gpu_info: GPUInfo = Field(default=GPUInfo())
    # Resolved distribution strategy; SINGLE unless set otherwise.
    dist_strategy: DistributionStrategy = Field(default=DistributionStrategy.SINGLE)
    # TensorFlow distribution strategy object; None until one is created
    # (gpu_distributed_context creates a MirroredStrategy for DATA_P).
    tf_strategy: Optional[Any] = None
    # Usage counter; must stay non-negative.
    call_counter: int = Field(default=0, ge=0)

    class Config:
        # Allow field types that pydantic does not know how to validate.
        arbitrary_types_allowed = True
        # Re-run validation on attribute assignment (e.g. the ge=0 check on call_counter).
        validate_assignment = True


def get_strategy(
    gpu_info: GPUInfo,
    gpu_policy: DistributionStrategy,
    default_gpu_policy: DistributionStrategy,
    *,
    supported_gpu_info: Optional[Set[DistributionStrategy]] = None,
    current_context_info: Optional[DistContextInfo] = None,
) -> DistContextInfo:
    """Resolve a requested GPU policy into a ``DistContextInfo``.

    Args:
        gpu_info: GPU description (availability mode and number of GPUs).
        gpu_policy: Requested strategy. DATA_P / MODEL_P are validated against
            ``gpu_info``. AUTO selects ``default_gpu_policy`` when the conditions
            below hold. Any other value resolves to SINGLE.
        default_gpu_policy: Strategy AUTO resolves to. It is only used when it is
            a multi-GPU strategy, ``gpu_info.gpu_availability`` is NOT_IN_USE and
            at least two GPUs are reported.
        supported_gpu_info: Strategies the caller supports. If the resolved
            strategy is not in this set, it falls back to SINGLE. This includes
            the case where the set is None or empty.
        current_context_info: Previously returned context. It is reused (with its
            ``gpu_info`` replaced in place) when its ``dist_strategy`` equals the
            resolved strategy and it already holds a ``tf_strategy``.

    Returns:
        The reused context, or a new ``DistContextInfo`` for the resolved strategy.

    Raises:
        ValueError: If DATA_P / MODEL_P is requested explicitly while GPUs are
            NOT_AVAILABLE or fewer than two GPUs are reported.
    """
    multi_gpu_strategies = {
        DistributionStrategy.DATA_P,
        DistributionStrategy.MODEL_P,
    }

    if gpu_policy in multi_gpu_strategies:
        # An explicit multi-GPU request cannot be satisfied: fail loudly rather
        # than silently downgrading.
        if gpu_info.gpu_availability is GPUAvailabilityMode.NOT_AVAILABLE or gpu_info.num_gpus < 2:
            raise ValueError(
                f"Can't use distributed strategy {gpu_policy}: at least 2 available GPUs are required "
                f"(gpu_availability={gpu_info.gpu_availability}, num_gpus={gpu_info.num_gpus})"
            )
        res = gpu_policy

    elif gpu_policy is DistributionStrategy.AUTO:
        # AUTO only upgrades to the default multi-GPU strategy when it is safe to;
        # otherwise it stays on a single device.
        if (
            gpu_info.gpu_availability is GPUAvailabilityMode.NOT_IN_USE
            and gpu_info.num_gpus >= 2
            and default_gpu_policy in multi_gpu_strategies
        ):
            res = default_gpu_policy
        else:
            res = DistributionStrategy.SINGLE

    else:
        res = DistributionStrategy.SINGLE

    # Strategies the caller does not support fall back to SINGLE.
    res = res if supported_gpu_info and res in supported_gpu_info else DistributionStrategy.SINGLE

    # If the strategy did not change and a TF strategy already exists, keep using
    # that context so the existing tf_strategy is preserved.
    if current_context_info and current_context_info.dist_strategy == res and current_context_info.tf_strategy:
        current_context_info.gpu_info = gpu_info
        context_info = current_context_info
    else:
        context_info = DistContextInfo(gpu_info=gpu_info, dist_strategy=res)
    return context_info


@contextmanager
def gpu_distributed_context(context_info: DistContextInfo) -> Generator[DistContextInfo, None, None]:
    """Run the enclosed code under the distribution strategy in ``context_info``.

    For DATA_P, a ``tf.distribute.MirroredStrategy`` is created on first use. It
    is cached on ``context_info.tf_strategy``, and the body runs inside its
    ``scope()``. For any other strategy, the body runs without a TF strategy scope.

    ``context_info.call_counter`` is incremented once the body completes
    normally. If the body raises, the counter is left unchanged.

    Yields:
        The same ``context_info`` object that was passed in.
    """
    uses_data_parallel = context_info.dist_strategy is DistributionStrategy.DATA_P
    if not uses_data_parallel:
        yield context_info
    else:
        if context_info.tf_strategy is None:
            context_info.tf_strategy = tf.distribute.MirroredStrategy()
        with context_info.tf_strategy.scope():
            yield context_info
    context_info.call_counter += 1


@contextmanager
def tf_device_wrapper(device: int):
    """Context manager for the build/call methods of a HailoLayer.

    Args:
        device: GPU index. When >= 0, ops created inside the block are placed
            under ``tf.device("/gpu:<device>")``. When -1, no device scope is
            entered.

    Raises:
        ValueError: For any other value of ``device``.
    """
    if device >= 0:
        with tf.device(f"/gpu:{device}"):
            yield
    elif device == -1:
        # No explicit device: placement is left to the surrounding context.
        yield
    else:
        # The exception must actually be raised. Otherwise the generator ends
        # without yielding, and contextmanager reports "generator didn't yield".
        raise ValueError("Unknown configuration for GPU distribution")


def manage_layers_devices(func):
    """Decorator for the build-model method of HailoModel.

    Before calling ``func``, assigns a GPU index to every layer through
    ``ParallelModel``:
      - ``use_external_gpu_policy`` set: layers are left untouched.
      - ``dist_info.dist_strategy == MODEL_P``: layers are split across
        ``dist_info.gpu_info.num_gpus`` partitions.
      - otherwise: a single partition is used.
    """

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        self: "HailoModel"
        if not self.use_external_gpu_policy:
            if self.dist_info.dist_strategy == DistributionStrategy.MODEL_P:
                partitions = self.dist_info.gpu_info.num_gpus
            else:
                partitions = 1
            ParallelModel(self, partitions).run()
        return func(self, *args, **kwargs)

    return wrapper


class ParallelModel:
    """
    Allocate the layers of a HailoModel to GPU indices.

    ``run`` builds a graph of the model's layers from ``model.flow``. It assigns
    each layer a ``device`` index and writes that index to the layer's
    ``gpu_index``. The HailoModel passed in is modified in place.
    """

    # Below this many partitions, every layer gets device -1 instead of being split.
    MIN_DIST_GPU = 2

    def __init__(self, model: "HailoModel", num_partitions: int):
        self._model = model
        self.num_partitions = num_partitions
        self.graph = nx.DiGraph()

    def run(self):
        """Load the layer graph, assign devices, and copy them back to the model."""
        self.load_graph(self._model)
        self.greedy_find(self.num_partitions)
        self.copy_info()

    def load_graph(self, model: "HailoModel"):
        """Build ``self.graph`` from ``model.flow`` and record the total weight.

        Each node carries ``weights``: the element count of the layer's kernel,
        or 1 for layers without a ``kernel`` attribute. Each edge carries
        ``num_values``: the product of the destination layer's input shape for
        that edge, excluding the first dimension.
        """
        graph = nx.DiGraph()
        total = 0
        for name in model.flow.toposort():
            current_layer = model.layers[name]
            if hasattr(current_layer, "kernel"):
                layer_weights = np.prod(current_layer.kernel.shape)
            else:
                layer_weights = 1
            total += layer_weights
            graph.add_node(name, weights=layer_weights)

        for src, dst in model.flow.edges():
            dst_layer = model.layers[dst]
            input_index = model.flow.get_edge_data(src, dst)["input_index"]
            edge_values = np.prod(dst_layer.input_shapes[input_index][1:])
            graph.add_edge(src, dst, num_values=edge_values)

        self.number_weights = total
        self.graph = graph

    def greedy_find(self, free_gpu: int):
        """Set a ``device`` attribute on every node for ``free_gpu`` partitions."""
        if free_gpu >= self.MIN_DIST_GPU:
            self._find_partitions(free_gpu)
            return
        # Too few GPUs to split: mark every layer with -1.
        nx.set_node_attributes(self.graph, -1, "device")

    def _find_partitions(self, n_partitions):
        """Greedily cut the topological order into ``n_partitions`` weight-balanced chunks."""
        # The +1 keeps floor(cumulative / capacity) strictly below n_partitions,
        # because cumulative never exceeds number_weights.
        capacity = self.number_weights / n_partitions + 1
        cumulative = 0
        for name in nx.lexicographical_topological_sort(self.graph):
            attrs = self.graph.nodes[name]
            cumulative += attrs["weights"]
            attrs["device"] = int(np.floor(cumulative / capacity))

    def copy_info(self):
        """Write each node's ``device`` to the matching model layer's ``gpu_index``."""
        for name, attrs in self.graph.nodes(data=True):
            self._model.layers[name].gpu_index = attrs["device"]
