from abc import abstractmethod

import numpy as np

from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    PaddingType,
    StrideAlignType,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasInitializationError


def get_hn_padding(params):
    """Translate the HN ``padding`` entry of ``params`` into a ``(padding, stride_align)`` pair.

    ``SAME_TENSORFLOW`` is expressed as ``SAME`` padding with ``SE`` stride alignment;
    any other padding value is returned unchanged together with ``NW`` alignment.
    """
    hn_padding = PaddingType(params["padding"])
    if hn_padding != PaddingType.SAME_TENSORFLOW:
        return hn_padding, StrideAlignType.NW
    return PaddingType.SAME, StrideAlignType.SE


def set_hn_padding_stride_align(params, padding, stride_align):
    """Write the HN ``padding`` entry of ``params`` from a ``(padding, stride_align)`` pair.

    This is the inverse of ``get_hn_padding``: ``SAME`` padding with ``SE`` alignment is stored
    as ``SAME_TENSORFLOW``; otherwise the padding value itself is stored.
    """
    is_tf_same = padding == PaddingType.SAME and stride_align == StrideAlignType.SE
    params["padding"] = PaddingType.SAME_TENSORFLOW.value if is_tf_same else padding.value


class BaseParamsWrap:
    """
    A base params wrapper - that has npz (native params) and qnpz params.
    Handles import (from npz/qnpz to the acceleras model) and export (from the acceleras model to npz/qnpz).

    Params are kept in a flat dict keyed as ``"<scope>/<layer_name>/<param_name>:0"``.
    """

    def __init__(self, params, base_layer_mapping=None):
        """
        Args:
            params: mapping of flat param keys to values (``None`` means empty). It is copied,
                so the caller's mapping is not mutated by this wrapper.
            base_layer_mapping: optional ``{layer_name: base_layer_name}`` mapping. A layer listed
                here also exposes the params of its base layer (following chains), with its own
                params taking precedence.
        """
        if params is None:
            params = dict()
        params = dict(params)
        self.check_params_kind(params)
        self.params: dict = params
        # Lazily built {layer_name: {param_name: value}} view; reset whenever params are written.
        self._per_layer_params = None
        self._base_layer_mapping = base_layer_mapping if base_layer_mapping is not None else dict()

    @staticmethod
    @abstractmethod
    def check_params_kind(params):
        """
        Validate ``params["params_kind"]`` and, if it is missing, add it to ``params`` in place.

        Note: this class does not derive from ``abc.ABC``, so ``@abstractmethod`` is not enforced;
        the base implementation is a no-op.
        """

    def get_param(self, layer_name, param_name, default_val=None):
        """Return the value stored under ``<layer_name>/<param_name>:0``, or ``default_val``."""
        return self.params.get(f"{layer_name}/{param_name}:0", default_val)

    def set_param(self, layer_name, param_name, pval):
        """Store ``pval`` under ``<layer_name>/<param_name>:0``."""
        self.params[f"{layer_name}/{param_name}:0"] = pval
        # The per-layer view is derived from self.params; drop it so get_layer_params
        # does not keep returning values computed before this write.
        self._per_layer_params = None

    def write_params(self, layer_name, params_exported):
        """Used to write all the params exported by layer"""
        for key, param in params_exported.items():
            self.set_param(layer_name, key, param)

    def _get_base_key_mapping(self):
        """
        Build ``{requested_key: stored_key}`` for every param visible to every layer.

        Keys with at most two ``/``-separated parts (e.g. ``"params_kind"``) map to themselves.
        A layer's view also includes the params of every layer reachable through
        ``base_layer_mapping``. The chain stops when a layer has no further base or a cycle
        is detected. Nearer layers override farther ones, so the layer's own params win.
        """
        lname_keys = {}
        key_map = {}
        for key in self.params.keys():
            parts = key.split("/")
            if len(parts) <= 2:
                key_map[key] = key
                continue
            lname = "/".join(parts[:2])
            lname_keys.setdefault(lname, []).append("/".join(parts[2:]))
        for lname in lname_keys.keys() | self._base_layer_mapping.keys():
            # Follow the base-layer chain until it reaches a fixed point or revisits a layer.
            path = [lname]
            base_lname = self._base_layer_mapping.get(lname, lname)
            while base_lname not in path:
                path.append(base_lname)
                base_lname = self._base_layer_mapping.get(base_lname, base_lname)
            # Walk from the farthest base back to lname so nearer layers overwrite farther ones.
            for src_lname in reversed(path):
                for internal_key in lname_keys.get(src_lname, []):
                    key_map[f"{lname}/{internal_key}"] = f"{src_lname}/{internal_key}"
        return key_map

    def get_layer_params(self, layer_name):
        """
        Get all params of a layer, including those inherited through ``base_layer_mapping``.

        Args:
            layer_name: ``<scope>/<layer_name>``.

        Returns:
            dict mapping param name (with the ``:0`` suffix removed) to its value; an empty dict
            if the layer has no params.
        """
        if self._per_layer_params is None:
            per_layer = dict()
            for layer_key, param_key in self._get_base_key_mapping().items():
                parts = layer_key.split("/")
                lname = "/".join(parts[:2])
                internal_key = "/".join(parts[2:])
                if internal_key.endswith(":0"):
                    internal_key = internal_key[:-2]
                per_layer.setdefault(lname, dict())[internal_key] = self.params[param_key]
            # Assign only once fully built, so a failure never leaves a partial cache behind.
            self._per_layer_params = per_layer
        # TODO: add warning when the layer has no params?
        return self._per_layer_params.get(layer_name, dict())


class NpzWrap(BaseParamsWrap):
    """Params wrapper for npz (native) params: accepted ``params_kind`` values are 0, 1, 3 and 4."""

    @staticmethod
    def check_params_kind(params):
        """Reject unsupported ``params_kind`` values; default a missing ``params_kind`` to ``[0]``."""
        params_kind = params.get("params_kind", None)
        if params_kind is None:
            params["params_kind"] = np.array([0])
            return
        if params_kind[0] not in (0, 1, 3, 4):
            raise AccelerasInitializationError(f"Trying to load params of kind {params_kind} to network ")


class QNpzWrap(BaseParamsWrap):
    """Params wrapper for qnpz params: the only accepted ``params_kind`` value is 2."""

    @staticmethod
    def check_params_kind(params):
        """Reject any ``params_kind`` other than 2; default a missing ``params_kind`` to ``[2]``."""
        params_kind = params.get("params_kind", None)
        if params_kind is None:
            params["params_kind"] = np.array([2])
            return
        if params_kind[0] != 2:
            raise AccelerasInitializationError(f"Trying to load params of kind {params_kind} to network ")


class LayerParams:
    """
    This class provides a wrapper that exposes the parameters only of the given layer.
    Helps the code to be more readable and the layer's params to be more accessible in a layer context.

    Wrapped keys look like ``"<scope>/<layer_name>/<param>:0"``; this view exposes them as ``"<param>"``.
    """

    # Top-level keys that are not layer params; ignored when inferring the scope name.
    SPECIAL_KEYS = {"params_kind"}

    def __init__(self, params, layer_name):
        """
        Initialize a LayerParams with npz and layer_name

        Args:
            params: params dictionary with model parameters
            layer_name: should be in <scope>/<layer_name> format.
                        if given <layer_name> without scope, it tries to extract the scope name from the npz

        Raises:
            ValueError: if the scope cannot be determined unambiguously, or layer_name has more than one '/'.
        """
        self._params = params
        self._layer_name = self._get_full_layer_name(layer_name)
        # Lazily computed list of this layer's param names (see keys()).
        self._keys = None

    def _get_full_layer_name(self, layer_name):
        """
        get a full layer name <scope>/<layer_name>. When scope is not given, it is extracted from the npz

        Args:
            layer_name: Either <layer_name> or <scope>/<layer_name>

        Returns: <scope>/<layer_name>

        """
        parts_count = len(layer_name.split("/"))
        # When layer name has 2 parts - we assume <scope/layer_name>
        if parts_count == 2:
            return layer_name
        if parts_count != 1:
            raise ValueError(f"layer_name {layer_name} had too many '/'")
        # Only 1 part: get the scope name from keys that contain the layer name as a whole path
        # component. A bare substring test would let e.g. "conv1" match "scope/conv10/..." and
        # produce the entire key as a bogus scope.
        marker = f"/{layer_name}/"
        scopes = {key.split(marker)[0] for key in self._params.keys() if marker in key}
        if len(scopes) > 1:
            raise ValueError(f"multiple scopes had layer_name {layer_name}. scopes: {scopes}")
        if len(scopes) == 0:
            # The layer didn't have any parameters, check if the entire npz has only 1 scope.
            scopes = {key.split("/")[0] for key in self._params.keys() if key not in self.SPECIAL_KEYS}
            if len(scopes) != 1:
                raise ValueError("Couldn't extract scope from npz")
        scope = scopes.pop()
        return f"{scope}/{layer_name}"

    def __getitem__(self, item):
        """Return the value of param ``item`` of this layer (KeyError if missing)."""
        key = self._resolve_key(item)
        return self._params[key]

    def _is_key(self, item):
        """True if the flat key ``item`` is a ``:0`` param of this layer."""
        return item.startswith(f"{self._layer_name}/") and item.endswith(":0")

    def keys(self):
        """Return this layer's param names (layer prefix and ``:0`` suffix stripped); cached after first call."""
        if self._keys is None:
            self._keys = [key[len(self._layer_name) + 1 : -2] for key in self._params.keys() if self._is_key(key)]
        return self._keys

    def get(self, key, default=None):
        """Return the value of param ``key`` of this layer, or ``default`` if missing."""
        key = self._resolve_key(key)
        return self._params.get(key, default)

    def _resolve_key(self, key):
        """Map a param name to its flat key ``<scope>/<layer_name>/<key>:0``."""
        return f"{self._layer_name}/{key}:0"
