import pickle
from enum import Enum
from pathlib import Path
from typing import BinaryIO, Dict, Union

import h5py
import numpy as np


def _save_npz(fo, dict_):
    """Save ``dict_`` as an uncompressed ``.npz`` archive, one array per key.

    Args:
        fo: Destination path (``str``/``Path``) or writable binary file object.
        dict_: Mapping of array names to array-like values. The names are
            passed as keyword arguments to ``np.savez``.
            NOTE: a key named ``file`` would collide with ``np.savez``'s own
            first parameter.
    """
    if isinstance(fo, (str, Path)) and Path(fo).suffix.lower() == ".npz":
        # np.savez appends ".npz" to any filename that does not end with the
        # exact lowercase ".npz". Without this branch, "w.NPZ" would be written
        # as "w.NPZ.npz" and load_params("w.NPZ") would not find it. Opening
        # the file ourselves writes to exactly the requested path.
        with open(fo, "wb") as fh:
            np.savez(fh, **dict_)
    else:
        np.savez(fo, **dict_)


def _load_npz(fo, **kwargs) -> dict:
    """Load every array in an ``.npz`` archive into a plain dict.

    Args:
        fo: Source path or readable binary file object.
        **kwargs: Forwarded to ``np.load`` (e.g. ``allow_pickle``).

    Returns:
        Dict mapping archive member names to NumPy arrays.
    """
    # dict() reads every member eagerly, so the archive can be closed right
    # away instead of leaving its file handle open until garbage collection.
    # A caller-supplied file object is not closed by NumPy.
    with np.load(fo, **kwargs) as archive:
        return dict(archive)


def _save_pkl(fo, dict_):
    """Pickle ``dict_`` into ``fo``.

    Args:
        fo: Destination path (``str``/``Path``), which is opened in ``"wb"``
            mode and closed here, or a writable binary file object, which is
            left open for the caller.
        dict_: Object to serialize.
    """
    if isinstance(fo, (str, Path)):
        # pickle.dump only accepts file objects. save_params infers the pickle
        # type from a path suffix, so paths must be handled here.
        with open(fo, "wb") as fh:
            pickle.dump(dict_, fh)
    else:
        pickle.dump(dict_, fo)


def _load_pkl(fo, **kwargs):
    """Unpickle and return the object stored in ``fo``.

    WARNING: unpickling can execute arbitrary code. Only load files from
    trusted sources.

    Args:
        fo: Source path (``str``/``Path``), which is opened in ``"rb"`` mode
            and closed here, or a readable binary file object, which is left
            open for the caller.
        **kwargs: Accepted for signature parity with the other loaders and
            currently ignored.

    Returns:
        The unpickled object.
    """
    if isinstance(fo, (str, Path)):
        # pickle.load only accepts file objects. load_params infers the pickle
        # type from a path suffix, so paths must be handled here.
        with open(fo, "rb") as fh:
            return pickle.load(fh)
    return pickle.load(fo)


def _save_hdf5(fo, dict_):
    """Write each entry of ``dict_`` as a top-level dataset of a new HDF5 file.

    h5py cannot store NumPy's unicode (``U``) dtype directly, so unicode
    arrays are converted first. Pure-ASCII arrays become fixed-length byte
    strings; arrays containing non-ASCII text become variable-length UTF-8
    strings. Other values are handed to ``create_dataset`` unchanged.

    Args:
        fo: Destination path or writable binary file object. The file is
            created or truncated (mode ``"w"``).
        dict_: Mapping of dataset names to values.
    """
    with h5py.File(fo, "w") as fp:
        for name, value in dict_.items():
            dtype = None  # None lets h5py infer the dtype from ``value``
            # Check the dtype kind rather than the "<U" string so non-native
            # (big-endian, ">U") unicode arrays are converted too.
            if isinstance(value, np.ndarray) and value.dtype.kind == "U":
                try:
                    value = value.astype("S")
                except UnicodeEncodeError:
                    # astype("S") only handles ASCII. Store non-ASCII text as
                    # variable-length UTF-8 strings instead of failing.
                    value = value.astype(object)
                    dtype = h5py.string_dtype("utf-8")
            fp.create_dataset(name, data=value, dtype=dtype)


def _load_hdf5(hdf5: Union[str, Path, BinaryIO, h5py.File], **kwargs) -> dict:
    """Read every dataset in an HDF5 file into a flat dict of arrays.

    Args:
        hdf5: Path, readable binary file object, or an already-open
            ``h5py.File``/``h5py.Group``. An open object is read as-is and
            left open, since the caller owns it. Anything else is opened
            read-only here and closed before returning.
        **kwargs: Extra keyword arguments for ``h5py.File`` when this function
            opens the file. They are ignored for an already-open object.

    Returns:
        Dict mapping slash-separated dataset paths to NumPy arrays.
    """
    if isinstance(hdf5, h5py.Group):  # h5py.File is a Group subclass
        return _load_hdf5_recursive(hdf5)
    # Datasets are materialized with np.array() during the walk, so closing
    # the file afterwards does not invalidate the returned arrays.
    with h5py.File(hdf5, "r", **kwargs) as fp:
        return _load_hdf5_recursive(fp)


def _load_hdf5_recursive(hdf5: h5py.Group, prefix="") -> dict:
    """Flatten an HDF5 group into a ``{path: ndarray}`` dict.

    Datasets are read eagerly into NumPy arrays. Sub-groups are walked
    recursively, and their members are keyed as ``"<group>/<name>"``.
    Byte-string (``S``) and object (``O``, e.g. variable-length string)
    datasets are converted to NumPy unicode arrays. Any other member, such as
    a committed datatype or a dangling link (``None``), is skipped.

    Args:
        hdf5: Open file or group to read from.
        prefix: Key prefix for members of this group. A trailing ``/`` is
            added if missing.

    Returns:
        Flat dict mapping slash-joined member paths to arrays.
    """
    h5_as_dict = {}
    if prefix and not prefix.endswith("/"):
        prefix += "/"
    for name, val in hdf5.items():
        key = f"{prefix}{name}"
        if isinstance(val, h5py.Dataset):
            v = np.array(val)
            if v.dtype.kind in ("S", "O"):
                v = _hdf5_strings_to_unicode(v)
            h5_as_dict[key] = v
        elif isinstance(val, h5py.Group):
            h5_as_dict.update(_load_hdf5_recursive(val, prefix=key))
        # Anything else carries no array data and is skipped.
    return h5_as_dict


def _hdf5_strings_to_unicode(v: np.ndarray) -> np.ndarray:
    """Convert a byte-string or object array read from HDF5 to a ``U`` array.

    The plain ``astype("U")`` cast is tried first. NumPy's bytes-to-str cast
    assumes ASCII, so it raises for the UTF-8 text that h5py writes for
    Python ``str`` values. In that case each ``bytes`` element is decoded as
    UTF-8 and the cast is retried.
    """
    try:
        return v.astype("U")
    except UnicodeDecodeError:
        decoded = np.array(v, dtype=object)  # copy; S elements become bytes
        for idx, item in np.ndenumerate(decoded):
            if isinstance(item, bytes):
                decoded[idx] = item.decode("utf-8")
        return decoded.astype("U")


class ParamSerializationType(str, Enum):
    """Supported on-disk parameter formats; each value is the file suffix."""

    HDF5 = ".hdf5"
    NPZ = ".npz"
    PICKLE = ".pkl"

    def clean_suffix(self):
        """Return the suffix without its leading dot, e.g. ``"npz"``."""
        return self.value.lstrip(".")


# Format -> writer, called as writer(fo, dict_).
PARAM_SERIALIZER_BY_TYPE = dict(
    (
        (ParamSerializationType.HDF5, _save_hdf5),
        (ParamSerializationType.NPZ, _save_npz),
        (ParamSerializationType.PICKLE, _save_pkl),
    )
)


# Format -> reader, called as reader(fo, **kwargs).
PARAM_DESERIALIZER_BY_TYPE = dict(
    (
        (ParamSerializationType.HDF5, _load_hdf5),
        (ParamSerializationType.NPZ, _load_npz),
        (ParamSerializationType.PICKLE, _load_pkl),
    )
)


def _get_type_from_filename(
    fo: Union[str, Path, BinaryIO],
) -> Union[ParamSerializationType, None]:
    """Infer the serialization type from a path's file extension.

    The check is case-insensitive. ``.h5`` is accepted as an alias for HDF5;
    for saving this matches save_params' HDF5 fallback.

    Args:
        fo: Path or file object.

    Returns:
        The matching ``ParamSerializationType`` member, or ``None`` when ``fo``
        is not a ``str``/``Path`` (e.g. an open file object) or its suffix is
        not recognised.
    """
    if not isinstance(fo, (str, Path)):
        return None
    suffix = Path(fo).suffix.lower()
    if suffix == ".h5":
        return ParamSerializationType.HDF5
    try:
        # Return the enum member rather than the raw suffix string so callers
        # receive the type the public API is annotated with.
        return ParamSerializationType(suffix)
    except ValueError:
        return None


def load_params(
    fo: Union[str, Path, BinaryIO],
    type_: Union[ParamSerializationType, None] = None,
    **kwargs,
):
    """Load serialized parameters from ``fo``.

    Args:
        fo: Path or readable binary file object.
        type_: Serialization format. When omitted, it is inferred from the
            suffix of a ``str``/``Path`` ``fo``. It must be given explicitly
            for file objects.
        **kwargs: Forwarded to the format-specific loader registered in
            ``PARAM_DESERIALIZER_BY_TYPE``.

    Returns:
        Whatever the format-specific loader returns.

    Raises:
        ValueError: If ``type_`` is omitted and cannot be inferred.
        KeyError: If ``type_`` is not a registered serialization type.
    """
    if type_ is None:
        type_ = _get_type_from_filename(fo)
    if type_ is None:
        if isinstance(fo, (str, Path)):
            # A path was given, but its suffix is not a known format.
            raise ValueError(
                f"Cannot infer serialization type from the suffix of {str(fo)!r}; "
                "pass type_ explicitly"
            )
        raise ValueError("If filename is a file object, type must be provided")
    return PARAM_DESERIALIZER_BY_TYPE[type_](fo, **kwargs)


def save_params(
    fo: Union[str, Path, BinaryIO],
    dict_: Dict[str, np.ndarray],
    type_: Union[ParamSerializationType, None] = None,
) -> None:
    """Serialize ``dict_`` to ``fo``.

    Args:
        fo: Destination path or writable binary file object.
        dict_: Parameters to save, keyed by name.
        type_: Serialization format. When omitted, it is inferred from the
            suffix of a ``str``/``Path`` ``fo``. If that also fails (file
            object or unrecognised suffix), HDF5 is used.

    Raises:
        KeyError: If ``type_`` is not a registered serialization type.
    """
    if type_ is None:
        type_ = _get_type_from_filename(fo)
    if type_ is None:
        # Silent fallback: unknown or missing suffixes are written as HDF5.
        type_ = ParamSerializationType.HDF5
    PARAM_SERIALIZER_BY_TYPE[type_](fo, dict_)
