from dataclasses import dataclass
from enum import Enum
from typing import Callable, Dict, Iterator, List, Optional, Set, TypeVar, Union

import numpy as np
import torch
import torch.nn as nn

# Type of a single tag: either an Enum member or a plain str.
T = TypeVar("T", Enum, str)
# A tag specification: a single tag or a set of tags. Code in this module
# normalizes it to a set via ``_wrap_handlers`` before comparing tags.
Tags = Union[Set[T], T]


def KeyHandlerFuncProto(*args: np.ndarray, **kwargs) -> np.ndarray:
    """Signature prototype for ``KeyHandler.handler`` callables.

    A handler receives the fetched acceleras values as positional arguments
    plus keyword arguments (``TranslatorUtils.create_state_dict`` passes
    ``hn_element``), and returns the value stored under the saitama key.

    This stub exists only to document that shape; calling it returns ``None``.
    """


@dataclass
class KeyHandler:
    """Describes how one saitama state-dict entry is built from acceleras values."""

    # Destination key in the produced saitama state dict.
    saitama_key: str
    # Source key(s) looked up in the acceleras reference dict; a single str is
    # treated as a one-element list.
    acceleras_keys: Union[str, List[str]]
    # Callable shaped like ``KeyHandlerFuncProto``: receives the fetched values
    # positionally plus ``hn_element`` as a keyword argument. ``None`` means the
    # single fetched value is passed through unchanged.
    handler: Optional[Callable] = None
    # Tag(s) this entry belongs to. The entry is selected only when these tags
    # are a subset of the requested tags, so an empty set means "always selected".
    tags: Optional[Tags] = frozenset()
    # Zero-argument callable that supplies a value when a source key is missing.
    default_factory: Optional[Callable] = None


class TranslatorUtils:
    """Helpers that translate an acceleras parameter dict into a saitama state dict.

    Every method is a ``classmethod`` or ``staticmethod``; the class only serves
    as a namespace.
    """

    @classmethod
    def create_state_dict(
        cls,
        reference_dict,
        mapping: List[KeyHandler],
        hn_element: dict,
        *,
        tags: Optional[Tags] = None,
        device=None,
        dtype=None,
    ) -> Dict[str, torch.Tensor]:
        """Build a saitama state dict from ``reference_dict`` according to ``mapping``.

        Args:
            reference_dict: Source values, indexed as ``reference_dict[key]``.
            mapping: Key handlers describing how each saitama key is produced.
            hn_element: Forwarded to every handler as the ``hn_element`` keyword.
            tags: A tag or set of tags. Only handlers whose own tags are a subset
                of these are used. ``None`` is wrapped as ``{None}``, which
                selects only handlers with empty (or ``None``) tags.
            device: Forwarded to ``torch.tensor``.
            dtype: Forwarded to ``torch.tensor``.

        Returns:
            Dict from saitama key to the tensor built from the handler output.

        Raises:
            ValueError: If the selected handlers contain duplicate saitama keys.
            Exception: Any error raised while fetching, converting or tensorizing
                a key is re-raised with ``"Error in key: <key>"`` prepended to
                its ``args``.
        """
        tags = _wrap_handlers(tags)
        state_dict = {}
        for key_mapping in cls.filter_mapping(mapping, tags):
            # Bound before the try so the error handler can always name the key.
            saitama_key = key_mapping.saitama_key
            try:
                params = cls.fetch_acceleras_params(
                    key_mapping.acceleras_keys, reference_dict, key_mapping.default_factory
                )
                saitama_value = cls.get_saitama_value(saitama_key, key_mapping.handler, *params, hn_element=hn_element)
                state_dict[saitama_key] = torch.tensor(saitama_value, device=device, dtype=dtype)
            except Exception as e:
                e.args = (f"Error in key: {saitama_key}",) + e.args
                raise
        return state_dict

    @staticmethod
    def get_saitama_value(saitama_key, handler_func, *args, **kwargs):
        """Apply ``handler_func`` to the fetched acceleras values.

        When ``handler_func`` is ``None``, the single positional value is
        returned unchanged and keyword arguments are ignored.

        Raises:
            ValueError: If ``handler_func`` is ``None`` but several values were
                given, since there is no way to combine them.
            Exception: Whatever the handler raises, re-raised with its message
                replaced by ``"KeyHandler's handler function: <original>"``.
        """
        if handler_func is None:
            # Check the *original* argument: the identity fallback below cannot
            # combine multiple acceleras values.
            if len(args) > 1:
                raise ValueError(f"Handler is required for {saitama_key}, because it uses multiple acceleras keys")
            handler_func = lambda x, **kw: x  # noqa: E731 - identity pass-through
        try:
            return handler_func(*args, **kwargs)
        except Exception as e:
            # ``args`` must be a tuple; assigning a bare str would split it into chars.
            e.args = (f"KeyHandler's handler function: {e}",)
            raise

    @staticmethod
    def fetch_acceleras_params(
        acceleras_keys: Union[str, List[str]],
        reference_dict: dict,
        default_factory: Optional[Callable] = None,
    ) -> List[np.ndarray]:
        """Look up ``acceleras_keys`` in ``reference_dict``.

        Args:
            acceleras_keys: A single key or a list of keys.
            reference_dict: Source mapping.
            default_factory: Zero-argument callable used when any key is missing.
                Its single result replaces *all* requested values, so the
                returned list then has exactly one element.

        Returns:
            The looked-up values, in key order.

        Raises:
            KeyError: If a key is missing and no ``default_factory`` is given.
                The message lists every missing key.
        """
        if isinstance(acceleras_keys, str):
            acceleras_keys = [acceleras_keys]
        try:
            return [reference_dict[k] for k in acceleras_keys]
        except KeyError as e:
            if default_factory is None:
                missing_keys = set(acceleras_keys) - set(reference_dict.keys())
                e.args = (f"Missing keys: {missing_keys}",)
                raise
            # NOTE(review): consider logging that the default value was used.
            return [default_factory()]

    @staticmethod
    def filter_mapping(mapping: List[KeyHandler], tags: Tags) -> Iterator[KeyHandler]:
        """Yield the handlers whose tags are a subset of ``tags``.

        Each saitama key may be produced by at most one selected handler. After
        the first duplicate is seen, no further handlers are yielded. Once the
        whole mapping has been scanned, a ``ValueError`` listing every duplicated
        key is raised. Because this is a generator, handlers yielded before the
        duplicate have already been handed to the caller by then.

        Args:
            mapping: Handlers to filter.
            tags: A tag or set of active tags. A bare tag is wrapped into a set so
                that e.g. a str is not compared character by character.
        """
        tags = _wrap_handlers(tags)
        # Must live outside the loop, otherwise duplicates are never detected.
        seen_keys = set()
        duplicates = set()
        for m in mapping:
            if not _wrap_handlers(m.tags).issubset(tags):
                continue
            if m.saitama_key in seen_keys:
                duplicates.add(m.saitama_key)
            elif not duplicates:
                yield m
            seen_keys.add(m.saitama_key)
        if duplicates:
            raise ValueError(f"The following keys has duplicate entries: {duplicates}")

    @staticmethod
    def validate_broadcast_and_load(module: nn.Module, state_dict: Dict[str, torch.Tensor]) -> nn.Module:
        """Merge ``state_dict`` into ``module.state_dict()`` and load the result strictly.

        For each entry of ``state_dict``:

        * If the key exists in the module's state dict, the value replaces the
          reference value. A 0-dim value is broadcast to the reference shape, and
          a value that differs from the reference only by size-1 dimensions is
          reshaped to it.
        * Otherwise, if the module has ``<prefix>._extra_state`` entries, the key
          is split as ``<prefix>.<name>`` and ``name`` is overwritten inside that
          extra state, which must already contain ``name``.
        * Otherwise, the key is added as-is, and the strict load reports it.

        Returns:
            ``module``, after ``load_state_dict(..., strict=True)``.

        Raises:
            ValueError: On a shape that cannot be matched, or on a key that matches
                neither a state-dict entry nor an existing extra-state field.
        """
        reference = module.state_dict()
        extra_state_suffix = "._extra_state"
        # The values are the very objects held in ``reference``. Writing into them
        # below therefore changes what gets loaded (assumes each extra state is a
        # mutable mapping - TODO confirm for all modules used here).
        extra_state = {
            k[: -len(extra_state_suffix)]: v for k, v in reference.items() if k.endswith(extra_state_suffix)
        }
        for key, val in state_dict.items():
            if key in reference:
                ref_param = reference[key]
                if ref_param.shape == val.shape:
                    reference[key] = val
                elif val.dim() == 0:
                    # Broadcast a scalar over the whole reference tensor.
                    reference[key] = val.expand_as(ref_param)
                elif torch.squeeze(ref_param).shape == torch.squeeze(val).shape:
                    # Same elements in the same order; only size-1 dims differ.
                    # reshape (not expand_as) also handles e.g. (C,) -> (C, 1).
                    reference[key] = val.reshape(ref_param.shape)
                else:
                    raise ValueError(f"Shape mismatch: {key} expected {ref_param.shape} got {val.shape}")
            elif extra_state:
                # The prefix selects the extra-state entry; the name is the field inside it.
                prefix, _, name = key.rpartition(".")
                if prefix not in extra_state:
                    raise ValueError(f"Key {key} not found in the module state dict or any extra state")
                if name not in extra_state[prefix]:
                    raise ValueError(f"Key {name} not found in the extra state dictionary")
                extra_state[prefix][name] = val
            else:
                # Unknown key: left for load_state_dict(strict=True) to reject.
                reference[key] = val
        module.load_state_dict(reference, strict=True)
        return module


def _wrap_handlers(tags: Tags[T]) -> Set[T]:
    """Normalize a tag specification into a set of tags.

    Despite its name, this wraps *tags*, not handlers.

    * ``set`` / ``frozenset`` inputs are returned unchanged (not copied).
    * ``list`` / ``tuple`` inputs are treated as collections of tags and
      converted to a ``set``. Previously a list crashed as unhashable and a
      tuple was silently treated as one tag.
    * Anything else, including a bare ``str``, an ``Enum`` member or ``None``,
      is treated as a single tag and returned as ``{tags}``.
    """
    if isinstance(tags, (set, frozenset)):
        return tags
    if isinstance(tags, (list, tuple)):
        return set(tags)
    return {tags}
