#!/usr/bin/env python

from builtins import str
from pprint import pformat

import numpy as np

from hailo_sdk_common.exceptions.exceptions import SDKCommonException
from hailo_sdk_common.logger.logger import default_logger
from hailo_sdk_common.serialization.numpy_serialization import hailo_np_load
from hailo_sdk_common.targets.inference_targets import ParamsKinds


class ModelParamsException(SDKCommonException):
    """Raised by ModelParams for illegal operations, e.g. an unknown params_kind or an unsupported operand to `+`."""


class InnerParams:
    """Dict-like, attribute-accessible view over a flat mapping of '/'-separated keys.

    * ``view[name]`` and ``name in view`` first try ``"<name>:0"`` and then ``name``.
    * Keys without a '/' are also exposed as instance attributes, named after the
      key with any ``:<suffix>`` removed.
    * For every first segment ``seg`` of a key containing '/', attribute ``seg``
      returns a nested InnerParams over the remainders of the keys starting with ``seg/``.

    NOTE(review): the nested-segment attributes are properties installed on the class,
    so they are shared by all instances; each reads the mapping of the instance it is
    accessed through.
    """

    def __init__(self, partial_params):
        self._load_params(partial_params)

    def _set_inner_property(self, inner_prop, parted_d):
        """Return a property yielding the sub-view of the keys under ``inner_prop/``.

        The property lives on the class, so it must read the mapping of the instance it
        is accessed through (``owner._params``). Capturing ``parted_d`` would make every
        instance see the data of whichever instance installed the property last.
        ``parted_d`` is accepted only to keep the signature unchanged.
        """
        prefix = f"{inner_prop}/"

        def inner_property(owner):
            return InnerParams(
                {"/".join(k.split("/")[1:]): v for k, v in owner._params.items() if k.startswith(prefix)},
            )

        return property(inner_property)

    def _load_params(self, params):
        """Store ``params`` and expose its keys as attributes (see the class docstring)."""
        self._params = params
        parted_d = {k: v for k, v in params.items() if "/" in k}
        # One property per distinct top-level path segment.
        for key_name in {k.split("/")[0] for k in parted_d}:
            setattr(type(self), key_name, self._set_inner_property(key_name, parted_d))
        short_d = {k: v for k, v in params.items() if "/" not in k}
        # "bias:0" -> attribute "bias"
        self.__dict__.update({k.split(":")[0]: v for k, v in short_d.items()})
        self.keys, self.values, self.items = self._params.keys, self._params.values, self._params.items

    def get(self, item, default_value=None):
        """Return ``self[item]``, or ``default_value`` when neither ``item:0`` nor ``item`` exists."""
        try:
            return self.__getitem__(item)
        except KeyError:
            return default_value

    def __contains__(self, item):
        if item in self._params:
            return True
        io_item = f"{item}:0"
        return io_item in self._params

    def __getitem__(self, item):
        # Prefer the io-index-0 entry so that view["bias"] finds "bias:0".
        io_item = f"{item}:0"
        if io_item in self._params:
            return self._params[io_item]
        return self._params[item]

    def __iter__(self):
        return iter(self._params)

    def __str__(self):
        return str(self._params)

    def __repr__(self):
        return repr(self._params)


class ModelParams:
    """Dict-like class that contains all parameters used by a model such as weights, biases, etc.

    Values live in a flat dict keyed by '/'-separated strings such as
    ``<network>/<layer>/<param>:<io_index>``, plus a few top-level metadata keys
    (``params_kind``, ``mode``, ``supported_states``, ``optimization_target``).
    Every network name, layer (first two key segments) and parameter name (last key
    segment without its ``:<io_index>`` suffix) is also reachable as an attribute or
    item returning an :class:`InnerParams` view.

    NOTE(review): those attributes are properties installed on the *class*, so they are
    shared by all instances. Each reads the ``_params`` of the instance it is accessed
    through, so attribute access to a name that only exists on another instance yields
    an empty view rather than AttributeError.
    """

    CONSTS_NAMES = [
        "epsilon",
        "leaky_alpha",
        "limvals_out",
        "limvals_in",
        "scale_bias",
        "bias_factor",
        "residue",
        "accumulator_offset",
        "input_activation_bits",
        "output_activation_bits",
        "weight_bits",
        "activation_threshold",
        "activation_delta_bias",
        "activation_clipping_percentile_values",
        "size_splits",
        "swish_beta",
        "activation_less_values",
        "padding_const_value",
        "hardsigmoid_alpha",
        "hardsigmoid_beta",
        "clip_min",
        "clip_max",
        "activation_greater_values",
        "pow_exponent",
        "zp_x_vals",
        "precision_split_zp",
    ]

    PARAMS_KIND_STR = "params_kind"
    PARAMS_MODE = "mode"
    SUPPORTED_STATES = "supported_states"
    OPTIMIZATION_TARGET = "optimization_target"

    # Top-level bookkeeping keys: never treated as layers/properties/networks and never renamed.
    _METADATA_KEYS = (PARAMS_MODE, PARAMS_KIND_STR, SUPPORTED_STATES, OPTIMIZATION_TARGET)

    VALUE_TO_PARAMS_KIND_DICT = {
        0: ParamsKinds.NATIVE,
        1: ParamsKinds.NATIVE_FUSED_BN,
        2: ParamsKinds.TRANSLATED,
        3: ParamsKinds.FP_OPTIMIZED,
        4: ParamsKinds.HAILO_OPTIMIZED,
        5: ParamsKinds.STATISTICS,
    }
    PARAMS_KIND_TO_VALUE_DICT = {val: key for key, val in VALUE_TO_PARAMS_KIND_DICT.items()}

    def __init__(self, params, executable_model_suffix=None, names_mapping=None):
        """Build the container from a flat params mapping.

        Args:
            params: Mapping of flat parameter keys to values (e.g. the result of ``hailo_np_load``).
            executable_model_suffix: Optional suffix; ``_<suffix>`` is appended to the first
                segment of every non-metadata key.
            names_mapping: Optional mapping from ``<network>/<layer>`` key prefixes to replacement prefixes.

        Raises:
            ModelParamsException: if ``params`` contains an illegal ``params_kind`` value.
        """
        self._params = {}
        self._network_names = set()
        self._params_kind = None
        self._layers = None
        self._load_params(params, executable_model_suffix, names_mapping)

    def _set_dict_attr(self):
        # Expose the underlying dict's keys/values/items views directly on the instance.
        self.keys, self.values, self.items = self._params.keys, self._params.values, self._params.items

    def _collect_network_names(self, keys):
        """Return the first path segment of every non-metadata key in ``keys``."""
        return {key.split("/")[0] for key in keys if key not in self._METADATA_KEYS}

    def _load_params(self, params, executable_model_suffix=None, names_mapping=None):
        """Merge ``params`` into this object, normalizing keys and values.

        '.' in keys is replaced by '_'. A ``<network>/<layer>`` prefix found in
        ``names_mapping`` is replaced by its mapped value. ``_<executable_model_suffix>``
        is appended to the first key segment. Metadata keys are left untouched.
        Afterwards None values become NaN, float arrays are cast to float32, and the
        attribute properties are refreshed.

        Raises:
            ModelParamsException: if ``params`` contains an illegal ``params_kind`` value.
        """
        self._params.update(params)
        if self.PARAMS_KIND_STR in params:
            self._set_params_kind_from_val()

        # NOTE(review): collected from the keys *before* renaming, so with
        # executable_model_suffix / names_mapping these may not match the stored keys.
        network_names = self._collect_network_names(params.keys())
        updated_keys = {}
        for key in list(params.keys()):
            # Metadata keys are not layer params. Suffixing them would create bogus
            # "<name>_<suffix>/<name>" entries that then show up as layers.
            if key in self._METADATA_KEYS:
                continue

            updated_key = key
            if "." in key:
                updated_key = key.replace(".", "_")
                updated_keys[key] = updated_key

            key_split = updated_key.split("/", 2)
            key_prefix = "/".join(key_split[:2])

            if names_mapping and key_prefix in names_mapping:
                updated_key = f"{names_mapping[key_prefix]}/{key_split[-1]}"

            if executable_model_suffix:
                key_parts = updated_key.split("/", 1)
                updated_key = f"{key_parts[0]}_{executable_model_suffix}/{key_parts[-1]}"

            if updated_key != key:
                self._params[updated_key] = self._params.pop(key)

        if updated_keys:
            first_item = next(iter(updated_keys.items()))
            default_logger().debug(
                f"Model params contained keys with '.' which isn't supported. {len(updated_keys)} "
                f"keys were modified, for example: {first_item[0]} was replaced with "
                f"{first_item[1]}.",
            )

        self._network_names.update(network_names)
        self._set_dict_attr()
        self._replace_none_with_nan()
        self._fix_params_dtype()
        self._set_properties_from_params()

    def _replace_none_with_nan(self):
        """Replace ``None`` values (bare, or wrapped in a 0-d array) with ``np.nan``."""
        for key, val in self._params.items():
            if (val is None) or (isinstance(val, np.ndarray) and val.shape == () and val == np.array(None)):
                self._params[key] = np.nan
                default_logger().debug(f"Found None value in params: {key} replacing it with NaN")

    def _fix_params_dtype(self):
        """Cast float16/float64/long-double arrays to float32."""
        # np.longdouble instead of np.float128: the latter is missing on platforms whose
        # long double is not 128-bit (e.g. Windows, macOS arm64) and would raise AttributeError.
        float_dtypes_to_fix = (np.float16, np.float64, np.longdouble)
        for key, val in self._params.items():
            if isinstance(val, np.ndarray) and val.dtype in float_dtypes_to_fix:
                self._params[key] = val.astype(np.float32)

    def _layer_property_generator(self, layer):
        """Return a property viewing ``layer``'s params, keyed by the key parts after ``<network>/<layer>/``."""

        def layer_property(self):
            layer_name = f"{layer}/"
            return InnerParams(
                {"/".join(k.split("/")[2:]): self._params[k] for k in self._params.keys() if k.startswith(layer_name)},
            )

        return property(layer_property)

    def _param_property_generator(self, param):
        """Return a property viewing every ``<param>:<idx>`` value, keyed by ``<network>/<layer>``.

        NOTE(review): if a layer has several io indices of ``param``, the later one in
        dict order overwrites the earlier one in the view.
        """

        def param_property(self):
            key = f"{param}:"
            return InnerParams(
                {
                    "/".join(k.split("/")[:2]): self._params[k]
                    for k in self._params.keys()
                    if k.split("/")[-1].startswith(key)
                },
            )

        return property(param_property)

    def _network_name_property_generator(self, network_name):
        """Return a property viewing a network's params, keyed by the key parts after ``<network>/``."""

        def network_name_property(self):
            network_name_slash = f"{network_name}/"
            return InnerParams(
                {
                    "/".join(k.split("/")[1:]): self._params[k]
                    for k in self._params.keys()
                    if k.startswith(network_name_slash)
                },
            )

        return property(network_name_property)

    def _set_properties_from_params(self):
        """Recompute property/layer names from the current keys and install the class properties."""
        # Last key segment with its ":<io_index>" suffix removed (e.g. "bias", "kernel"), deduplicated.
        self._properties = list(
            {key.split("/")[-1].split(":")[0] for key in self._params.keys() if key not in self._METADATA_KEYS},
        )

        for param in self._properties:
            setattr(type(self), param, self._param_property_generator(param))

        # The key set may have changed (update/add/remove/set_layer_*); drop the cached
        # layer list so `layers` is recomputed from the current params.
        self._layers = None
        for layer in self.layers:
            setattr(type(self), layer, self._layer_property_generator(layer))

        for network_name in self._network_names:
            setattr(type(self), network_name, self._network_name_property_generator(network_name))

    def _params_kind_to_param(self):
        """Encode the current params kind as the stored value ``[<int>]``.

        Raises:
            ModelParamsException: if the current params kind is not a known ParamsKinds member.
        """
        if self._params_kind in self.PARAMS_KIND_TO_VALUE_DICT.keys():
            return [self.PARAMS_KIND_TO_VALUE_DICT[self._params_kind]]

        raise ModelParamsException(f"Illegal params_kind: {self._params_kind}")

    def _get_params_kind_from_val(self, value):
        """Decode a stored ``[<int>]`` value to its ParamsKinds member.

        Raises:
            ModelParamsException: if ``value`` is not a single known code.
        """
        if len(value) == 1 and value[0] in self.VALUE_TO_PARAMS_KIND_DICT.keys():
            return self.VALUE_TO_PARAMS_KIND_DICT[value[0]]
        else:
            raise ModelParamsException(f"Illegal value to set for params kind: {value}")

    def _set_params_kind_from_val(self):
        value = self._params[self.PARAMS_KIND_STR]
        params_kind = self._get_params_kind_from_val(value)
        self.set_params_kind(params_kind)

    def set_layer_kernel(self, layer, kernel, io_index=0):
        """Store ``kernel`` under ``<layer>/kernel:<io_index>`` and refresh the attribute properties."""
        key = f"{layer}/kernel:{io_index}"
        self._params[key] = kernel
        self._set_dict_attr()
        self._set_properties_from_params()

    def set_layer_bias(self, layer, bias, io_index=0):
        """Store ``bias`` under ``<layer>/bias:<io_index>`` and refresh the attribute properties."""
        key = f"{layer}/bias:{io_index}"
        self._params[key] = bias
        self._set_dict_attr()
        self._set_properties_from_params()

    def set_params_kind(self, params_kind):
        """Set the params kind and store its encoded value under ``params_kind``.

        Raises:
            ModelParamsException: if ``params_kind`` is not a known ParamsKinds member.
        """
        self._params_kind = params_kind
        self._params[self.PARAMS_KIND_STR] = self._params_kind_to_param()
        self._set_dict_attr()
        setattr(self, self.PARAMS_KIND_STR, self._params[self.PARAMS_KIND_STR])

    def update(self, params):
        """Merge ``params`` in with the same normalization as the constructor (no renaming options)."""
        self._load_params(params)

    def get(self, item, default_value=None):
        """Return ``self[item]``, or ``default_value`` on KeyError."""
        try:
            return self.__getitem__(item)
        except KeyError:
            return default_value

    def get_consts(self):
        """Return ``{name: view}`` for each name in CONSTS_NAMES that is present as a property."""
        result = {}
        for const_name in type(self).CONSTS_NAMES:
            if const_name in self.properties:
                result[const_name] = getattr(self, const_name)
        return result

    def get_points_count(self):
        """Map each key prefix owning ``.../output_stage/piecewise/x_points:0`` to that array's last-axis size."""
        x_points_key_suffix = "/output_stage/piecewise/x_points:0"
        return {
            k[: -len(x_points_key_suffix)]: v.shape[-1]
            for k, v in self._params.items()
            if k.endswith(x_points_key_suffix)
        }

    def remove(self, layer):
        """Drop every param whose ``<network>/<layer>`` prefix equals ``layer``; ``params_kind`` is always kept."""
        updated_dict = {
            k: v for k, v in self._params.items() if k == self.PARAMS_KIND_STR or "/".join(k.split("/")[:2]) != layer
        }

        # mutate model params object with update dict
        self._params = updated_dict
        self._set_dict_attr()
        self._set_properties_from_params()

    def add(self, layer, layer_dict):
        """Insert ``layer_dict``'s entries under ``layer`` (each key becomes ``<layer>/<key>``)."""
        # add network & layer to the key before inserting in the dict
        layer_dict = {f"{layer}/{k}": layer_dict[k] for k in layer_dict.keys()}
        params = self.params
        params.update(layer_dict)

        # mutate model params object with update dict
        self._params = params
        self._network_names.update(self._collect_network_names(layer_dict))
        self._set_dict_attr()
        self._set_properties_from_params()

    @property
    def params_kind_enum(self):
        return self._params_kind

    @property
    def layers(self):
        """Distinct ``<network>/<layer>`` prefixes of all non-metadata keys (cached until the params change)."""
        if self._layers is None:
            self._layers = list(
                {"/".join(key.split("/")[:2]) for key in self._params if key not in self._METADATA_KEYS},
            )
        return self._layers

    @property
    def properties(self):
        return list(self._properties)

    @property
    def params(self):
        """Shallow copy of the underlying flat params dict."""
        return dict(self._params)

    @property
    def network_names(self):
        return set(self._network_names)

    def __iter__(self):
        # Iterates layer names, not raw keys.
        return iter(self.layers)

    def __contains__(self, item):
        # supporting everything in __getitem__ except slice
        if isinstance(item, tuple):
            item = "/".join(item)
        if item in self._params:
            return True
        if item in self.layers:
            return True
        if item in self.properties:
            return True
        if item in self.network_names:
            return True
        return False

    def __getitem__(self, item):
        """Look up a layer, property or network view, or a raw key.

        A slice returns a new ModelParams holding np.copy'd values for that slice of the key order.
        """
        if isinstance(item, slice):
            # No renaming options are passed, so the copied keys (and network names) are kept as-is.
            return ModelParams({k: np.copy(self._params[k]) for k in list(self._params.keys())[item]})
        if isinstance(item, tuple):
            item = "/".join(item)
        if item in self.layers:
            return getattr(self, item)
        if item in self.properties:
            return getattr(self, item)
        if item in self.network_names:
            return getattr(self, item)
        return self._params[item]

    def __setitem__(self, item, val):
        # NOTE(review): stores the raw value only; layer/property caches are not refreshed here.
        self._params[item] = val

    def __add__(self, other):
        """Return a copy of ``self`` merged with ``other`` (``other``'s values win on key clashes).

        Raises:
            ModelParamsException: if ``other`` is neither a ModelParams nor a dict.
        """
        if not isinstance(other, (ModelParams, dict)):
            raise ModelParamsException(f"Can't add {type(other)} to ModelParams")

        new = self[:]
        # Merge the raw mapping. Going through ModelParams.__getitem__ (as dict.update would)
        # can return an InnerParams view instead of the stored value when a key also names a
        # layer, property or network.
        other_params = other.params if isinstance(other, ModelParams) else other
        new._params.update(other_params)
        new._network_names.update(new._collect_network_names(other_params.keys()))
        new._set_properties_from_params()
        return new

    def __str__(self):
        return pformat(self._params)


def get_params_from_npz_path(params_path):
    """Load the params stored at ``params_path`` via ``hailo_np_load`` and wrap them in a ModelParams."""
    loaded_params = hailo_np_load(params_path)
    return ModelParams(loaded_params)


def get_param_key(layer, param, io_index=0):
    """Build the flat params key ``<layer>/<param>:<io_index>``.

    ``param`` may be a single name or a non-string iterable of path segments,
    which are joined with '/'.
    """
    is_segment_list = not isinstance(param, str) and hasattr(param, "__iter__")
    param_path = "/".join(param) if is_segment_list else param
    return "{}/{}:{}".format(layer, param_path, io_index)
