import json
from pathlib import Path, PosixPath
from typing import TYPE_CHECKING, Tuple, Union

import yaml
from pydantic.v1 import Field

from hailo_model_optimization.acceleras.hailo_layers.hailo_postprocess import PostProcessConfig
from hailo_model_optimization.acceleras.model.hailo_model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.model_optimization_config.mo_config import ModelOptimizationConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import OpStates, OptimizationTarget
from hailo_model_optimization.acceleras.utils.modification_tracker.modification_tracker_utils import (
    AlgoModificationTracker,
)
from hailo_model_optimization.acceleras.utils.params_loader import load_params, save_params
from hailo_model_optimization.algorithms.algorithm_base import AlgoResults
from hailo_model_optimization.tools.base_memento import BaseMemento

if TYPE_CHECKING:
    from hailo_model_optimization.flows.optimization_flow import OptimizationFlow


class AccelerasMemento(BaseMemento):
    """Snapshot of a serialized acceleras ``HailoModel`` with enough information to rebuild it.

    Produced by ``AccelerasOriginator.save`` and consumed by ``AccelerasOriginator.restore``.
    """

    # JSON HN of the model (``acceleras_model.hn`` when written by ``AccelerasOriginator.save``).
    hn: PosixPath
    # Params file written with ``save_params`` (``acceleras.hdf5`` when written by ``AccelerasOriginator.save``).
    acceleras_params: PosixPath
    # The model's ``nms_config.dict()``; ``{}`` when the model has no NMS config (restored as ``None``).
    nms_config: dict = Field({}, description="NMS configuration")
    optimization_target: OptimizationTarget = Field(
        OptimizationTarget.SAGE, description="OptimizationTarget describes the HW architecture"
    )
    # NOTE(review): annotated ``str`` but defaults to ``None``; pydantic v1 treats a field whose
    # default is ``None`` as optional, so ``None`` is accepted here.
    lora_adapter_name: str = Field(None, description="Lora adapter name")


class FlowMemento(BaseMemento):
    """Index of the files ``OptFlowOriginator.save`` writes when checkpointing an optimization flow.

    Each ``PosixPath`` field points at one artifact inside the checkpoint directory.
    ``OptFlowOriginator.restore`` only loads an artifact when its path exists, so optional artifacts
    (statistics, HW params, quantized HN, ...) may legitimately be missing on disk.

    NOTE(review): the ``PosixPath`` default factory yields ``PosixPath('.')`` (the current directory),
    and ``.exists()`` reports it as present. A memento built without one of these fields would make
    ``restore`` try to load ``.`` as a params/JSON/YAML file.
    ``default_factory=AccelerasMemento`` also looks unable to succeed on its own, because ``hn`` and
    ``acceleras_params`` are required fields there. Both only matter when a caller omits the field;
    ``OptFlowOriginator.save`` sets every field except ``flow_info``.
    """

    # ``model_fp.hn``: JSON dump of ``flow.get_fp_hn()``.
    fp_hn: PosixPath = Field(default_factory=PosixPath)
    # ``quantize_model.hn``: JSON dump of ``flow._hn`` (written only when it is set).
    hn: PosixPath = Field(default_factory=PosixPath)
    # ``hw_params.hdf5``: ``flow._quant_params`` (written only when non-empty).
    hw_params: PosixPath = Field(default_factory=PosixPath)
    # ``fp.hdf5``: ``flow.get_fp_params()``.
    fp_params: PosixPath = Field(default_factory=PosixPath)
    # ``acceleras_params.hdf5``: ``flow._acceleras_params`` (written only when non-empty).
    acceleras_params: PosixPath = Field(default_factory=PosixPath)
    # ``results.json``: list of ``AlgoResults`` dicts. Required (no default).
    flow_results: PosixPath
    # Memento of the flow's working acceleras model.
    models: AccelerasMemento = Field(default_factory=AccelerasMemento)
    # Not written or read by ``OptFlowOriginator`` in this module.
    flow_info: PosixPath = Field(default_factory=PosixPath)
    # ``new_mo_config.yaml``: YAML dump of the parsed ``ModelOptimizationConfig``.
    mo_config: PosixPath = Field(default_factory=PosixPath)
    # ``stats.hdf5``: ``flow.params_statistics`` (written only when non-empty).
    algorithm_stats: PosixPath = Field(default_factory=PosixPath)
    # Copied verbatim from/to ``flow.original_input_shapes``.
    original_input_shapes: dict = Field(default_factory=dict, description="TODO understand this logic")
    # ``modifications_meta_data.yaml``: YAML dump of the flow's ``AlgoModificationTracker``.
    modifications_meta_data: PosixPath = Field(default_factory=PosixPath)
    # ``modifications_params.hdf5``: tracker modification params (written only when the tracker has any).
    modifications_params: PosixPath = Field(default_factory=PosixPath)


class AccelerasOriginator:
    """Serialize an acceleras ``HailoModel`` to disk and rebuild it from an ``AccelerasMemento``.

    ``save`` writes the model HN (JSON) and its acceleras params (via ``save_params``) into a directory
    and returns a memento pointing at them. ``restore`` reverses the process.
    """

    def save(self, model: HailoModel, path: PosixPath) -> AccelerasMemento:
        """Write ``model`` into the directory ``path`` and return a memento describing it.

        Files written:
            * ``acceleras_model.hn``: ``model.export_hn()`` dumped as JSON.
            * ``acceleras.hdf5``: ``model.export_acceleras()`` saved with ``save_params``.

        Args:
            model: The model to serialize.
            path: Target directory. It is expected to exist already; this method does not create it.

        Returns:
            An ``AccelerasMemento`` holding both file paths plus the model's NMS config (as a dict,
            ``{}`` when the model has none), optimization target and LoRA adapter name.
        """
        acceleras_path = path.joinpath("acceleras.hdf5")
        model_hn = path.joinpath("acceleras_model.hn")

        # Export before opening the file so a failing export does not leave an empty HN behind.
        hn = model.export_hn()
        with model_hn.open("w") as fp:
            json.dump(hn, fp)

        save_params(acceleras_path, model.export_acceleras())

        return AccelerasMemento(
            base_path=path,
            hn=model_hn,
            acceleras_params=acceleras_path,
            nms_config=model.nms_config.dict() if model.nms_config else {},
            optimization_target=model.optimization_target,
            lora_adapter_name=model.lora_adapter_name,
        )

    @staticmethod
    def load_params(memento: AccelerasMemento) -> Tuple[dict, dict, PostProcessConfig, OptimizationTarget, str]:
        """Read everything ``restore`` needs from the files referenced by ``memento``.

        Returns:
            ``(hn, params, nms_config, optimization_target, lora_adapter_name)``. ``nms_config`` is a
            ``PostProcessConfig``, or ``None`` when the memento stores an empty NMS config.

        Raises:
            FileNotFoundError: if the memento's HN file does not exist. The model cannot be rebuilt
                without it, so this fails before any params are loaded.
        """
        if not memento.hn.exists():
            raise FileNotFoundError(f"Acceleras HN file not found: {memento.hn}")

        nms_config = PostProcessConfig(**memento.nms_config) if memento.nms_config else None
        # Module-level ``load_params`` helper (name lookup inside a method body skips class scope).
        npz = load_params(memento.acceleras_params)
        with memento.hn.open("r") as fd:
            hn = json.load(fd)
        return hn, npz, nms_config, memento.optimization_target, memento.lora_adapter_name

    def restore(self, memento: AccelerasMemento) -> HailoModel:
        """Rebuild a ``HailoModel`` from ``memento``.

        The model is constructed from the saved HN, NMS config, optimization target and LoRA adapter
        name, and its acceleras params are imported. It is switched to lossy mode when it supports the
        quantized state. Finally it is built for its input shapes with a leading ``None`` dimension
        prepended (presumably the batch axis).

        Raises:
            FileNotFoundError: if the memento's HN file is missing (see ``load_params``).
        """
        hn, npz, nms_config, optimization_target, lora_adapter_name = self.load_params(memento)
        model = HailoModel(
            hn_dict=hn,
            nms_config=nms_config,
            optimization_target=optimization_target,
            lora_adapter_name=lora_adapter_name,
        )
        model.import_acceleras(npz)
        if OpStates.QUANTIZED in model.supported_states:
            model.set_lossy()
        shapes = [(None,) + shape for shape in model.get_input_shapes()]
        model.compute_output_shape(shapes)
        model.build(shapes)

        return model


class OptFlowOriginator:
    """Checkpoint an ``OptimizationFlow`` to a directory and restore it from a ``FlowMemento``.

    ``save`` writes the flow state as individual files and returns a ``FlowMemento`` pointing at them.
    The state covers the FP model, MO config, algorithm results, statistics, quantized params/HN,
    modification tracking and the working acceleras model. ``restore`` reads back whichever of those
    files exist.
    """

    def __init__(self, flow: "OptimizationFlow") -> None:
        self.flow = flow
        # Serializer for the flow's working acceleras model (``flow._model``).
        self.model_serializer = AccelerasOriginator()

    def save(self, path: Path) -> FlowMemento:
        """Write the flow state into the directory ``path`` and return a memento indexing it.

        Optional artifacts are only written when the flow has them, but the memento always records
        their paths; ``restore`` skips paths that do not exist. The optional artifacts are statistics,
        HW params, acceleras params, quantized HN and modification params.

        NOTE(review): because skipped artifacts are still referenced, re-saving into a directory
        that holds files from an earlier checkpoint can make ``restore`` pick up those stale files.

        Args:
            path: Target directory. It is expected to exist already; it is not created here.

        Returns:
            A ``FlowMemento`` with the path of every artifact.
        """
        path = Path(path)

        # FP params and HN.
        fp_params = path.joinpath("fp.hdf5")
        save_params(fp_params, self.flow.get_fp_params())

        fp_hn = path.joinpath("model_fp.hn")
        with fp_hn.open("w") as fp:
            json.dump(self.flow.get_fp_hn(), fp)

        # Working model: reuse a still-alive memento (e.g. one kept by ``restore(tf_safe=True)``)
        # instead of re-serializing the model.
        model_memento = (
            self.flow._model_memento
            if self.flow._model_memento and self.flow._model_memento.alive
            else self.model_serializer.save(self.flow._model, path)
        )

        # MO config, round-tripped through JSON so the YAML only contains plain types.
        mo_config = path.joinpath("new_mo_config.yaml")
        with mo_config.open("w") as fp:
            yaml.safe_dump(json.loads(self.flow._parsed_config.json()), fp, width=80, indent=4)

        # Algorithm results.
        algorithm_results = path.joinpath("results.json")
        with algorithm_results.open("w") as fp:
            json.dump([res.dict() for res in self.flow._flow_results], fp)

        # Flow statistics (only when present).
        algorithm_stats = path.joinpath("stats.hdf5")
        if self.flow.params_statistics:
            save_params(algorithm_stats, self.flow.params_statistics)

        # HW (quantized) params; per the original comment these exist only after step 1.
        hw_params = path.joinpath("hw_params.hdf5")
        if self.flow._quant_params:
            save_params(hw_params, self.flow._quant_params)

        acceleras_params = path.joinpath("acceleras_params.hdf5")
        if self.flow._acceleras_params:
            save_params(acceleras_params, self.flow._acceleras_params)

        quantize_hn = path.joinpath("quantize_model.hn")
        if self.flow._hn:
            with quantize_hn.open("w") as fp:
                json.dump(self.flow._hn, fp)

        # Modification tracking: metadata always, params only when the tracker has some.
        modifications_meta_data = path.joinpath("modifications_meta_data.yaml")
        with modifications_meta_data.open("w") as fp:
            yaml.safe_dump(json.loads(self.flow.modifications_meta_data.json()), fp, width=80, indent=4)
        modifications_params = path.joinpath("modifications_params.hdf5")
        if self.flow.modifications_meta_data.has_modification_params():
            save_params(modifications_params, self.flow.modifications_meta_data.get_modification_params())

        return FlowMemento(
            base_path=path,
            fp_params=fp_params,
            fp_hn=fp_hn,
            hn=quantize_hn,
            hw_params=hw_params,
            acceleras_params=acceleras_params,
            models=model_memento,
            mo_config=mo_config,
            flow_results=algorithm_results,
            algorithm_stats=algorithm_stats,
            original_input_shapes=self.flow.original_input_shapes,
            modifications_meta_data=modifications_meta_data,
            modifications_params=modifications_params,
        )

    def restore(self, memento: FlowMemento, tf_safe: bool = False) -> None:
        """Load the flow state described by ``memento`` back into ``self.flow``.

        Each file-backed artifact is restored only if its file exists. ``original_input_shapes`` is
        copied unconditionally.

        The working model is rebuilt immediately via ``AccelerasOriginator.restore``, except in two
        cases: when ``tf_safe`` is true, or when ``memento.models`` is falsy. In those cases the model
        memento is stored on ``flow._model_memento`` instead (presumably to defer building the model);
        ``save`` reuses such a memento while it is alive.

        Args:
            memento: Memento previously returned by ``save``.
            tf_safe: If true, do not rebuild the working model now.
        """
        # FP model parameters and HN.
        if memento.fp_params.exists():
            self.flow.set_fp_params(load_params(memento.fp_params))

        if memento.fp_hn.exists():
            with memento.fp_hn.open("r") as fp:
                hn_fp = json.load(fp)
            self.flow.set_fp_hn(hn_fp)

        # MO configuration.
        if memento.mo_config.exists():
            with memento.mo_config.open("r") as fp:
                mo_config_data = yaml.safe_load(fp)
            self.flow._parsed_config = ModelOptimizationConfig(**mo_config_data)

        # Algorithm results belong to the flow (``save`` reads ``self.flow._flow_results``), not to
        # this originator. Assigning them to ``self`` would silently drop them.
        if memento.flow_results.exists():
            with memento.flow_results.open("r") as fp:
                flow_results = json.load(fp)
            self.flow._flow_results = [AlgoResults(**res) for res in flow_results]

        # Flow statistics.
        if memento.algorithm_stats.exists():
            self.flow.params_statistics = load_params(memento.algorithm_stats)

        # Original input shapes.
        self.flow.original_input_shapes = memento.original_input_shapes

        # Final (HW / acceleras) params and quantized HN.
        if memento.hw_params.exists():
            self.flow.set_quant_params(load_params(memento.hw_params))

        if memento.acceleras_params.exists():
            self.flow._acceleras_params = load_params(memento.acceleras_params)

        if memento.hn.exists():
            with memento.hn.open("r") as fp:
                hn = json.load(fp)
            self.flow.set_hn(hn)

        # Modification tracking; params are only applied when the metadata was restored too.
        if memento.modifications_meta_data.exists():
            with memento.modifications_meta_data.open("r") as fp:
                modifications_meta_data = yaml.safe_load(fp)
            self.flow._modifications_meta_data = AlgoModificationTracker(**modifications_meta_data)
            if memento.modifications_params.exists():
                modifications_params = load_params(memento.modifications_params)
                self.flow._modifications_meta_data.set_modification_params(modifications_params)

        # Working model.
        if not tf_safe and memento.models:
            self.flow._model = self.model_serializer.restore(memento.models)
        else:
            self.flow._model_memento = memento.models


def save_acceleras_model(model: HailoModel, dir_path: Union[Path, str]) -> Path:
    """Serialize ``model`` into ``dir_path`` and write a ``memento.json`` index next to it.

    The directory (and any missing parents) is created if needed. The model files are written by
    ``AccelerasOriginator.save``, and the resulting memento is dumped as JSON with 4-space indentation.

    Args:
        model: Model to save.
        dir_path: Target directory.

    Returns:
        Path of the written ``memento.json``. This path, or ``dir_path`` itself, can be passed to
        ``load_acceleras_model``. (The previous annotation claimed an ``AccelerasMemento``, but the
        function has always returned this path.)
    """
    dir_path = Path(dir_path)
    dir_path.mkdir(parents=True, exist_ok=True)
    memento = AccelerasOriginator().save(model, dir_path)
    memento_path = dir_path.joinpath("memento.json")
    memento_path.write_text(memento.json(indent=4))
    return memento_path


def load_acceleras_model(memento_path: Union[Path, str]) -> HailoModel:
    """Rebuild a ``HailoModel`` from a saved memento.

    Args:
        memento_path: Either the ``memento.json`` file itself (any path with a ``.json`` suffix), or
            the directory that contains ``memento.json``.

    Raises:
        ValueError: if the path has no ``.json`` suffix and is not a directory.
    """
    target = Path(memento_path)
    if target.suffix != ".json":
        # Not a JSON file: only a directory holding memento.json is accepted.
        if not target.is_dir():
            raise ValueError("Path should be a directory or a json file")
        target = target / "memento.json"

    raw_json = target.read_text()
    memento = AccelerasMemento.parse_raw(raw_json)
    return AccelerasOriginator().restore(memento)
