from typing import Dict, List, Literal, Union

from numpy.typing import ArrayLike
from pydantic.v1 import BaseModel, Field, PrivateAttr, ValidationError, validator

from hailo_model_optimization.acceleras.utils.acceleras_definitions import TrackerStage, TrackerType

# Mapping from a param name to its array value; used as the type of AlgoModificationTracker's
# modification params. The name presumably mirrors NumPy's ``.npz`` archive layout (name -> array).
NPZ = Dict[str, ArrayLike]


class BaseModificationTracker(BaseModel):
    """
    Common base for the records describing one modification applied to a layer.

    Concrete subclasses pin ``tracker_type`` to a single ``TrackerType`` member and declare the
    fields specific to their modification.
    """

    tracker_type: TrackerType = None
    stage: TrackerStage = None  # assigned in bulk by AlgoModificationTracker.set_stage
    lname: str = None  # owning layer name; assigned by AlgoModificationTracker.append

    class Config:
        # Reject unknown fields instead of silently ignoring them.
        extra = "forbid"

    def get_keys(self) -> Dict[str, str]:
        """
        Return the subset of this tracker's fields (name -> value) whose values are keys into the
        modification params.

        The base tracker references no params, so the mapping is empty.
        """
        referenced: Dict[str, str] = {}
        return referenced


class FoldNormalizationTracker(BaseModificationTracker):
    """
    Tracker for the ``TrackerType.FOLD_NORMALIZATION`` modification.

    ``kernel_key`` and ``bias_key`` name entries in the owning tracker's modification params.
    """

    tracker_type: Literal[TrackerType.FOLD_NORMALIZATION] = TrackerType.FOLD_NORMALIZATION
    kernel_key: str
    bias_key: str
    apply_on_input: bool  # NOTE(review): presumably selects input-side vs output-side folding — confirm

    def get_keys(self) -> Dict[str, str]:
        """Return the fields that reference modification params: ``kernel_key`` then ``bias_key``."""
        return dict(kernel_key=self.kernel_key, bias_key=self.bias_key)


class MatrixMultiplicationTracker(BaseModificationTracker):
    """
    Tracker for the ``TrackerType.MATRIX_MULTIPLICATION`` modification.

    ``kernel_key`` names an entry in the owning tracker's modification params.
    """

    tracker_type: Literal[TrackerType.MATRIX_MULTIPLICATION] = TrackerType.MATRIX_MULTIPLICATION
    kernel_key: str
    transpose: bool  # NOTE(review): presumably whether the kernel is used transposed — confirm
    apply_on_input: bool  # NOTE(review): presumably input-side vs output-side application — confirm

    def get_keys(self) -> Dict[str, str]:
        """Return the single field that references a modification param: ``kernel_key``."""
        return dict(kernel_key=self.kernel_key)


class GatherTracker(BaseModificationTracker):
    """
    Tracker for the ``TrackerType.GATHER`` modification.

    ``indices_key`` names an entry in the owning tracker's modification params.
    """

    tracker_type: Literal[TrackerType.GATHER] = TrackerType.GATHER
    indices_key: str
    apply_on_input: bool  # NOTE(review): presumably input-side vs output-side application — confirm

    def get_keys(self) -> Dict[str, str]:
        """Return the single field that references a modification param: ``indices_key``."""
        return dict(indices_key=self.indices_key)


class SplitTracker(BaseModificationTracker):
    """
    Tracker for the ``TrackerType.SPLIT`` modification.

    It does not override ``get_keys``, so it references no modification params.
    """

    tracker_type: Literal[TrackerType.SPLIT] = TrackerType.SPLIT
    mini_convs: List[str]  # presumably the names of the layers produced by the split — TODO confirm
    ew_add_after: bool = False  # presumably whether an elementwise add follows the split — TODO confirm


# Field type for a single tracker entry. Pydantic v1 tries the Union members in this order; the
# ``tracker_type`` Literal on each model is what tells them apart.
modification_tracker_types = Union[FoldNormalizationTracker, MatrixMultiplicationTracker, GatherTracker, SplitTracker]


# Serialized ``tracker_type`` value -> tracker model class. AlgoModificationTracker.validate_layers
# uses it to pre-check raw tracker entries before pydantic parses them.
tracker_types_map = {
    tracker_member.value: tracker_model
    for tracker_member, tracker_model in (
        (TrackerType.FOLD_NORMALIZATION, FoldNormalizationTracker),
        (TrackerType.MATRIX_MULTIPLICATION, MatrixMultiplicationTracker),
        (TrackerType.GATHER, GatherTracker),
        (TrackerType.SPLIT, SplitTracker),
    )
}


class AlgoModificationTracker(BaseModel):
    """
    Per-layer record of modification trackers together with the params they reference.

    ``layers`` maps a layer name to the trackers recorded for that layer, in insertion order.
    Trackers refer to entries of the private ``_modification_params`` mapping by key (the values
    returned by their ``get_keys``). Keys created by :meth:`add_modification_param` have the
    form ``"<base_key>:<index>"``.
    """

    layers: Dict[str, List[modification_tracker_types]] = Field(dict())
    _modification_params: NPZ = PrivateAttr({})

    class Config:
        extra = "forbid"

    @validator("layers", pre=True)
    def validate_layers(cls, v):
        """
        Drop tracker entries that cannot be parsed, before pydantic validates ``layers``.

        For every list value in ``v``:

        * instances of a known tracker model are kept as-is;
        * dicts are kept when their ``tracker_type`` is a known ``TrackerType`` and the entry
          validates against the matching model (``tracker_type`` is normalized to the enum member);
        * everything else is dropped silently.

        A non-dict ``v`` or a non-list value is passed through untouched so pydantic reports it.
        New containers are returned; the caller's dict, lists and tracker dicts are not modified.
        """
        if not isinstance(v, dict):
            return v

        known_models = tuple(tracker_types_map.values())

        def normalize(tracker):
            # Return the entry to keep, or None when it should be dropped.
            if isinstance(tracker, known_models):
                return tracker
            if not isinstance(tracker, dict):
                return None
            try:
                # Accept both the raw value and an enum member (e.g. ``.dict()`` output), regardless
                # of whether TrackerType members compare equal to their plain values.
                tracker_type = TrackerType(tracker.get("tracker_type"))
            except (ValueError, TypeError):
                return None
            tracker_cls = tracker_types_map.get(tracker_type.value)
            if tracker_cls is None:
                return None
            candidate = {**tracker, "tracker_type": tracker_type}
            try:
                tracker_cls.validate(candidate)
            except (ValidationError, TypeError):
                # TypeError: e.g. non-string keys cannot be passed as model keyword arguments.
                return None
            return candidate

        cleaned = {}
        for lname, trackers in v.items():
            if isinstance(trackers, list):
                kept = (normalize(tracker) for tracker in trackers)
                cleaned[lname] = [tracker for tracker in kept if tracker is not None]
            else:
                cleaned[lname] = trackers
        return cleaned

    def get_modification_params(self) -> NPZ:
        """Return the modification params mapping itself (not a copy)."""
        return self._modification_params

    def set_modification_params(self, modification_params: NPZ) -> None:
        """Replace the modification params; the mapping is stored by reference, not copied."""
        self._modification_params = modification_params

    def has_modification_params(self) -> bool:
        """Return whether any modification params are stored."""
        return bool(self._modification_params)

    def update(self, other: "AlgoModificationTracker") -> None:
        """
        Merge ``other``'s trackers, and the params they reference, into this tracker.

        Every param referenced by one of ``other``'s trackers is re-added here under a fresh key
        (see :meth:`add_modification_param`), so keys from the two sides cannot collide. A param
        referenced by several of ``other``'s trackers is copied once and shared. ``other`` is not
        modified: the appended trackers are copies carrying the remapped keys.

        Raises:
            KeyError: if one of ``other``'s trackers references a key missing from its params.
        """
        other_params = other.get_modification_params()
        remapped_keys: Dict[str, str] = {}
        # Iterate snapshots so that ``x.update(x)`` terminates instead of walking the lists it grows.
        for lname, trackers in list(other.layers.items()):
            if not trackers:
                continue
            target = self.layers.setdefault(lname, [])
            for tracker in list(trackers):
                new_keys = {}
                for attr, key in tracker.get_keys().items():
                    if key not in remapped_keys:
                        # Inverse of the "<base_key>:<index>" format produced by add_modification_param.
                        base_key = key.rsplit(":", 1)[0]
                        remapped_keys[key] = self.add_modification_param(base_key, other_params[key])
                    new_keys[attr] = remapped_keys[key]
                # Copy rather than setattr, so ``other``'s trackers keep pointing at ``other``'s params.
                target.append(tracker.copy(update=new_keys, deep=True))

    def set_stage(self, stage: TrackerStage) -> None:
        """Set ``stage`` on every tracker of every layer."""
        for trackers in self.layers.values():
            for tracker in trackers:
                tracker.stage = stage

    def append(self, lname: str, tracker: modification_tracker_types) -> None:
        """Set ``tracker.lname`` and append the tracker itself to layer ``lname`` (created if missing)."""
        tracker.lname = lname
        self.layers.setdefault(lname, []).append(tracker)

    def add_modification_param(self, base_key: str, value: ArrayLike) -> str:
        """
        Store ``value`` under a new key ``"<base_key>:<index>"`` and return that key.

        ``index`` is one more than the largest index already stored for ``base_key`` (``0`` if
        none). Existing keys whose suffix after ``"<base_key>:"`` is not a plain decimal integer
        are ignored, so arbitrary keys and base keys containing ``":"`` are handled safely.
        """
        prefix = f"{base_key}:"
        used_indices = [
            int(k[len(prefix):])
            for k in self._modification_params
            if k.startswith(prefix) and k[len(prefix):].isdecimal()
        ]
        key = f"{prefix}{max(used_indices, default=-1) + 1}"
        self._modification_params[key] = value
        return key
