from collections import defaultdict
from dataclasses import dataclass
from functools import wraps
from typing import Dict, List, Optional

import torch.nn as nn
from torch import Tensor

from hailo_model_optimization.acceleras.model.hailo_model.model_flow import (
    ModelFlow,
    StaticMappings,
    create_static_mapping,
)
from hailo_model_optimization.saitama.framework.common.saitama_definitions import (
    Encoding,
    ModelOutput,
    TensorDict,
    TensorList,
)
from hailo_model_optimization.saitama.framework.common.saitama_module import SaitamaModule


@dataclass
class InferResults:
    """Per-layer bookkeeping entry used while running an ``SModel``."""

    # Tensors currently held for the layer: seeded with the input for sources, then
    # replaced by the layer's outputs. None until set or once released.
    tensor: Optional[List[Tensor]] = None
    # Encodings produced by the layer's ``forward_encoding``. None until set or once released.
    encoding: Optional[List[Encoding]] = None
    # Remaining number of consumers of this layer; initialized from the static mapping's out-degree.
    call_count: int = 0


def skip_if_checks_disabled(func):
    """Turn a ``ModelChecks`` method into a no-op unless its instance has checks enabled.

    The decorated method is called only when ``self.enable_checks`` is truthy; otherwise
    the wrapper returns ``None`` without invoking it. ``functools.wraps`` keeps the
    original method's name and docstring on the wrapper.
    """

    @wraps(func)
    def guarded(self: "ModelChecks", *args, **kwargs):
        if self.enable_checks:
            return func(self, *args, **kwargs)
        # Checks are disabled: skip the check entirely.
        return None

    return guarded


class ModelChecks:
    """Optional runtime sanity checks for ``SModel``.

    Every check is decorated with ``skip_if_checks_disabled`` and therefore does nothing
    unless ``enable_checks`` is truthy.
    """

    def __init__(self, enable_checks=False):
        # Read by ``skip_if_checks_disabled`` before each decorated check runs.
        self.enable_checks = enable_checks

    @skip_if_checks_disabled
    def check_inputs_consistency(self, static_mapping: StaticMappings, inputs: TensorDict):
        """Verify that ``inputs`` provides exactly the data sources of ``static_mapping``.

        Args:
            static_mapping: Mapping whose ``data_sources`` lists the expected input node names.
            inputs: Dict of input tensors keyed by input node name. A single non-dict input is
                not key-checked, because ``InferUtils.create_inference_results`` feeds such an
                input to every data source.

        Raises:
            ValueError: If input nodes are missing from ``inputs``, or ``inputs`` contains
                names that are not data sources of ``static_mapping``.
        """
        if not isinstance(inputs, dict):
            # Previously this crashed with AttributeError on ``.keys()``.
            return
        expected = set(static_mapping.data_sources)
        provided = set(inputs.keys())
        if provided == expected:
            return
        # Report both directions; the mismatch is not necessarily a missing input.
        missing = sorted(expected - provided, key=str)
        unexpected = sorted(provided - expected, key=str)
        raise ValueError(
            f"Inputs do not match the model's input nodes: missing={missing}, unexpected={unexpected}"
        )


class InferUtils:
    """Stateless helpers for the per-layer bookkeeping used during ``SModel`` inference.

    Results are kept in a ``Dict[str, InferResults]`` keyed by layer name. Each entry's
    ``call_count`` starts at the layer's out-degree and is decremented by
    ``get_layers_inputs`` as consumers read its tensors, so intermediate tensors can be
    released as soon as their last consumer has read them.
    """

    @staticmethod
    def create_inference_results(
        static_mapping: StaticMappings,
        inputs: TensorDict,
    ) -> Dict[str, InferResults]:
        """Create the initial results table for one run.

        Args:
            static_mapping: Provides ``data_sources``, ``constant_sources`` and ``out_degree``.
            inputs: Dict of input tensors keyed by data-source name; a single non-dict input,
                which is fed to every data source; or ``None`` to seed no tensors at all
                (used for encoding-only runs).

        Returns:
            A ``defaultdict(InferResults)``. When ``inputs`` is not None, data sources hold
            ``[input]`` and constant sources hold ``[]``. Every layer listed in ``out_degree``
            gets its ``call_count`` set to that out-degree.
        """
        inference_results = defaultdict(InferResults)
        if inputs is not None:
            inputs_by_name = isinstance(inputs, Dict)
            for name in static_mapping.data_sources:
                inference_results[name].tensor = [inputs[name]] if inputs_by_name else [inputs]
            for name in static_mapping.constant_sources:
                # Constant sources take no inputs: they are later called with ``*[]``.
                inference_results[name].tensor = []

        for lname, times in static_mapping.out_degree.items():
            inference_results[lname].call_count = times
        return inference_results

    @staticmethod
    def order_results(output_mapping: List[str], inference_results: Dict[str, InferResults]):
        """Collect the tensors of the ``output_mapping`` layers, in that order.

        A layer that holds exactly one tensor contributes that tensor; otherwise it
        contributes its whole list. With a single output layer, its value is returned
        directly rather than wrapped in a list. ``output_mapping`` must not be empty.
        """
        results = []
        for succ in output_mapping:
            layer_tensors = inference_results[succ].tensor
            results.append(layer_tensors[0] if len(layer_tensors) == 1 else layer_tensors)
        return results if len(results) > 1 else results[0]

    @staticmethod
    def get_layers_inputs(
        predecessor_map: Dict[str, Dict[str, Dict[str, int]]],
        lname,
        inference_results: Dict[str, InferResults],
        output_mapping: List[str],
    ) -> List[Tensor]:
        """Gather the input tensors of ``lname`` from its predecessors.

        Each predecessor's ``call_count`` is decremented. When it reaches zero and the
        predecessor is not a model output, its tensors are released to free memory.

        Returns:
            One tensor per predecessor, in ``predecessor_map[lname]`` iteration order.
        """
        inputs = []
        for pred, pred_info in predecessor_map[lname].items():
            pred_data = inference_results[pred]
            pred_data.call_count -= 1
            # ``output_index`` is wrapped modulo the number of held tensors.
            inputs.append(pred_data.tensor[pred_info["output_index"] % len(pred_data.tensor)])
            if pred_data.call_count == 0 and pred not in output_mapping:
                # Last consumer has read it: drop the reference so the tensors can be freed.
                pred_data.tensor = None
        return inputs

    @staticmethod
    def get_encoding_inputs(
        predecessor_map: Dict[str, Dict[str, Dict[str, int]]],
        lname,
        inference_results: Dict[str, InferResults],
    ) -> List[Encoding]:
        """Gather the input encodings of ``lname`` from its predecessors.

        ``call_count`` is NOT decremented here. When tensors are also propagated,
        ``get_layers_inputs`` decrements it right afterwards for the same layer, so
        ``call_count - 1 == 0`` identifies the last consumer and the encodings are released.
        NOTE(review): in encoding-only runs nothing decrements ``call_count``, so
        encodings of predecessors with out-degree > 1 are kept until the run ends.
        """
        inputs = []
        for pred, pred_info in predecessor_map[lname].items():
            pred_data = inference_results[pred]
            inputs.append(pred_data.encoding[pred_info["output_index"] % len(pred_data.encoding)])
            if pred_data.call_count - 1 == 0:
                # Drop the reference so the encodings can be freed. (The previous per-item
                # ``del v`` loop only unbound the loop variable and freed nothing.)
                pred_data.encoding = None
        return inputs

    @staticmethod
    def update_inference_results(lname, outputs, inference_results: Dict[str, InferResults]):
        """Store a layer's outputs; a single output is wrapped in a one-element list."""
        outputs = outputs if isinstance(outputs, (list, tuple)) else [outputs]
        inference_results[lname].tensor = outputs

    @staticmethod
    def update_encoding_results(lname, encoding, inference_results: Dict[str, InferResults]):
        """Store a layer's encodings; a single ``Encoding`` is wrapped in a one-element list.

        The extra ``Encoding`` check keeps a single encoding from being treated as a
        sequence when ``Encoding`` itself is a list/tuple subclass.
        """
        encoding = (
            encoding if isinstance(encoding, (list, tuple)) and not isinstance(encoding, Encoding) else [encoding]
        )
        inference_results[lname].encoding = encoding


# Layers are expected to already carry their precision config; SModel does not modify them.
# Any modification should be done by building a new model.
class SModel(SaitamaModule):
    """Graph model that executes ``layers`` in the order described by a ``ModelFlow``.

    The flow is treated as constant: its static mapping is built once in ``__init__``.
    To run only part of the model, build a custom mapping with ``get_custom_static_mapping``.
    """

    def __init__(
        self,
        layers: Dict[str, nn.Module],
        flow: ModelFlow,
        enable_checks: bool = False,
    ):
        """Build the model.

        Args:
            layers: Layer modules keyed by layer name; registered through ``nn.ModuleDict``.
            flow: Graph of the model, used to build the default static mapping.
            enable_checks: Enables the ``ModelChecks`` runtime checks, such as the input
                consistency check in ``forward``. Defaults to False, which keeps the
                previous behavior.
        """
        super().__init__()
        self.layers: Dict[str, SaitamaModule] = nn.ModuleDict(layers)
        # flow is assumed to be constant; any modification requires a new static mapping.
        self._flow = flow
        self.static_mapping = create_static_mapping(flow)
        self.infer_utils = InferUtils
        self.model_checks = ModelChecks(enable_checks=enable_checks)

    @staticmethod
    def inputs_as_dict(input_order: List[str], *inputs: TensorList) -> TensorDict:
        """Pair positional ``inputs`` with the names in ``input_order``.

        Surplus names or surplus inputs are silently dropped (``zip`` stops at the shorter).
        """
        return {source: inp for source, inp in zip(input_order, inputs)}

    def get_custom_static_mapping(self, custom_outputs=None, custom_inputs=None):
        """Build a static mapping of this model's flow with custom outputs and/or inputs.

        The result can be passed as ``static_mapping`` to ``forward`` or
        ``forward_encoding`` to run only part of the model.
        """
        return create_static_mapping(self._flow, custom_outputs, custom_inputs)

    def forward(
        self,
        inputs: TensorDict,
        static_mapping: Optional[StaticMappings] = None,
        forward_encoding: bool = False,
        verify_encoding: bool = False,
    ):
        """Run the model with the given inputs.

        Args:
            inputs (TensorDict): Keys are the input nodes names and values are the tensors.
            static_mapping (StaticMappings, optional): Can be passed down to run partial portions of the model. Defaults to None.
                Use get_custom_static_mapping to create a custom static mapping with desired inputs & outputs.
            forward_encoding (bool, optional): Indicates if the model should run the forward encoding. Defaults to False.
            verify_encoding (bool, optional): Indicates if the model should verify the encoding. Verify will trigger graph break on compilation. Defaults to False.

        Returns:
            The tensors of the layers in ``static_mapping.output_mapping``, as built by
            ``InferUtils.order_results``. With a single output layer the value is returned
            directly; otherwise a list in ``output_mapping`` order is returned. Each entry is
            the layer's tensor if it produced exactly one, or its list of tensors otherwise.

        Raises:
            ValueError: If checks are enabled and ``inputs`` does not match the data sources.
        """
        if static_mapping is None:
            # We need to use primitive types for the static mapping in dynamo
            static_mapping = self.static_mapping

        # Check consistency of the inputs (no-op unless checks are enabled).
        self.model_checks.check_inputs_consistency(static_mapping, inputs)

        # Seed the per-layer results table with the input tensors.
        inference_results = self.infer_utils.create_inference_results(static_mapping, inputs)

        # Torch Dynamo safe
        return self._run_main_flow(
            static_mapping,
            inference_results,
            forward_encoding=forward_encoding,
            forward_data=True,
            verify_encoding=verify_encoding,
        )

    def _run_main_flow(
        self,
        static_mapping: StaticMappings,
        inference_results,
        forward_encoding: bool = False,
        forward_data: bool = True,
        verify_encoding: bool = False,
    ) -> Optional[ModelOutput]:
        """Run the source layers and then the ``static_mapping.main_flow`` layers in order.

        Args:
            static_mapping: Mapping that provides the sources, execution order, predecessors
                and outputs.
            inference_results: Results table from ``InferUtils.create_inference_results``;
                updated in place.
            forward_encoding: Propagate encodings through each layer's ``forward_encoding``.
            forward_data: Propagate tensors by calling each layer.
            verify_encoding: Forwarded to every ``forward_encoding`` call.

        Returns:
            The ordered output tensors when ``forward_data`` is True, otherwise None.
        """
        encoding_kwargs = {"verify_encoding": verify_encoding}
        for source in (*static_mapping.data_sources, *static_mapping.constant_sources):
            if forward_encoding:
                source_encoding = self.layers[source].forward_encoding(**encoding_kwargs)
                self.infer_utils.update_encoding_results(source, source_encoding, inference_results)
            if forward_data:
                source_outputs = self.layers[source](*inference_results[source].tensor)
                self.infer_utils.update_inference_results(source, source_outputs, inference_results)

        for lname in static_mapping.main_flow:
            # Encodings must be gathered BEFORE tensors. get_encoding_inputs releases a
            # predecessor's encoding when ``call_count - 1 == 0`` and relies on
            # get_layers_inputs performing the decrement afterwards.
            if forward_encoding:
                encoding_inputs = self.infer_utils.get_encoding_inputs(
                    static_mapping.pred_mapping, lname, inference_results
                )
                encoding_outputs = self.layers[lname].forward_encoding(*encoding_inputs, **encoding_kwargs)
                self.infer_utils.update_encoding_results(lname, encoding_outputs, inference_results)
            if forward_data:
                current_inputs = self.infer_utils.get_layers_inputs(
                    static_mapping.pred_mapping, lname, inference_results, static_mapping.output_mapping
                )
                outputs = self.layers[lname](*current_inputs)
                self.infer_utils.update_inference_results(lname, outputs, inference_results)

        if not forward_data:
            return None
        return self.infer_utils.order_results(static_mapping.output_mapping, inference_results)

    def forward_encoding(
        self,
        static_mapping: Optional[StaticMappings] = None,
        verify_encoding: bool = False,
    ):
        """Propagate encodings through the graph without running any data.

        Args:
            static_mapping: Optional custom mapping; defaults to the model's own mapping.
            verify_encoding: Forwarded to every layer's ``forward_encoding``.

        Returns:
            None. The computed encodings are not returned from this method.
        """
        if static_mapping is None:
            # We need to use primitive types for the static mapping in dynamo
            static_mapping = self.static_mapping

        inference_results = self.infer_utils.create_inference_results(static_mapping, None)
        self._run_main_flow(
            static_mapping,
            inference_results,
            forward_encoding=True,
            forward_data=False,
            verify_encoding=verify_encoding,
        )
