#!/usr/bin/env python

import copy
from enum import Enum

from hailo_sdk_common.hailo_nn.exceptions import ToolsParamsError
from hailo_sdk_common.logger.logger import DeprecationVersion, default_logger


def param_val_str(param_v):
    """Render a parameter value for display.

    Lists are rendered element by element (recursively) inside square
    brackets with the single quotes of the list repr removed, Enum members
    are replaced by their ``value`` (returned as-is, not stringified), and
    everything else goes through ``str``.
    """
    if isinstance(param_v, list):
        rendered_items = list(map(param_val_str, param_v))
        # str() of a list quotes its string elements; drop those quotes.
        list_repr = str(rendered_items)
        return list_repr.replace("'", "")
    if isinstance(param_v, Enum):
        return param_v.value
    return str(param_v)


def param_str(param_k, param_v):
    """Format one parameter as ``key=value``.

    A parameter whose value is an empty list is hidden: an empty string is
    returned for it instead of ``key=[]``.
    """
    is_empty_list = isinstance(param_v, list) and len(param_v) == 0
    return "" if is_empty_list else f"{param_k}={param_val_str(param_v)}"


class ToolsParams:
    """Base container for a tool's named parameters.

    Subclasses configure validation and conversion through the class-level
    tables below and implement ``to_pb`` / ``from_pb``.
    """

    # Known parameter names mapped to their defaults; ``set`` only stores these keys.
    DEFAULT_PARAMS = {}
    # Per-key collection consulted by ``values_are_allowed``.
    ALLOWED_VALUES = {}
    # Per-key callable applied to an incoming value before it is stored.
    VALUE_CONVERTERS = {}
    # Per-key policy table; boolean-like input is looked up under "enabled"/"disabled".
    POLICY_TYPES = {}
    # Names that ``set`` accepts but ignores with a deprecation warning.
    DEPRECATED_PARAMS = []
    # Old parameter name -> current parameter name.
    BACKWARD_COMP_PARAMS = {}

    def __init__(self):
        self._params = {}
        # Private copy so instances never share the class-level defaults.
        self._default_params = copy.deepcopy(self.DEFAULT_PARAMS)

        self._logger = default_logger()

    def __str__(self):
        """Return the parameters as ``k=v`` pairs sorted by key and joined by ``", "``."""
        rendered = (param_str(param_k, param_v) for param_k, param_v in sorted(self.get().items()))
        # param_str() yields "" for hidden parameters (empty lists); skip those
        # so the result carries no dangling ", , " separators.
        return ", ".join(text for text in rendered if text)

    @staticmethod
    def get_param_value_from_policy(value, policies):
        """Resolve ``value`` against a policy table.

        Args:
            value: A key of ``policies``, or a boolean-like value
                (``True``/``"True"``/``"true"`` or the ``False`` counterparts).
            policies: Mapping of policy names to policy values.

        Returns:
            The matching policy value, or ``None`` when ``value`` is neither a
            key of ``policies`` nor boolean-like.

        Raises:
            KeyError: If ``value`` is boolean-like and ``policies`` lacks the
                ``"enabled"`` / ``"disabled"`` entry.
        """
        if value in policies:
            return policies[value]
        elif value in [True, "True", "true"]:
            return policies["enabled"]
        elif value in [False, "False", "false"]:
            return policies["disabled"]
        return None

    def set_param_with_policy(self, key, value, policies):
        """Resolve ``value`` through ``policies``, falling back to the current value of ``key``.

        Returns:
            The resolved policy value. When ``value`` does not name a policy a
            warning is logged and the currently stored value of ``key`` (or
            ``None`` if it is unset) is returned instead.
        """
        resolved = type(self).get_param_value_from_policy(value, policies)
        # Only None means "unknown policy"; a falsy policy value such as 0 is valid.
        if resolved is None:
            # Report the value the caller asked for, not the failed lookup result.
            self._logger.warning(f"Could not set unknown policy {value} for policies {policies!s}")
            resolved = self._params.get(key)
        return resolved

    def set_default_params(self):
        """Replace the current parameters with a deep copy of the defaults."""
        self._params = copy.deepcopy(self._default_params)

    def get_default_params(self):
        """Return this instance's default-parameter dict (the live object, not a copy)."""
        return self._default_params

    def override_params_from_kwargs(self, **kwargs):
        """Store every keyword argument whose value is truthy.

        Falsy values (``None``, ``0``, ``False``, ``""``, empty containers)
        are skipped and leave the stored parameter untouched.
        """
        for name, override in kwargs.items():
            if override:
                self._params[name] = override

    def clear(self):
        """Remove all currently set parameters."""
        self._params.clear()

    def get(self, key=None):
        """Return the value stored for ``key``, or the whole parameter dict when ``key`` is None.

        Raises:
            KeyError: If ``key`` is given but not set.
        """
        if key is not None:
            return self._params[key]
        else:
            return self._params

    def values_are_allowed(self, key, value):
        """Return the allowed-values collection for ``key`` (``True`` when unrestricted).

        For keys listed in ``POLICY_TYPES`` the boolean-like spellings are
        appended to the result. The result is only meaningful by truthiness.

        NOTE(review): ``value`` is not checked against the collection here, so
        any non-empty collection makes every value pass - confirm whether a
        membership test was intended.
        """
        allowed = type(self).ALLOWED_VALUES.get(key, True)
        if key not in type(self).POLICY_TYPES or isinstance(allowed, bool):
            return allowed
        # Build a new list: extending ``allowed`` in place would grow the
        # class-level ALLOWED_VALUES entry on every call.
        return [*allowed, *allowed_str_for_bool()]

    def convert_string_to_value(self, key, value):
        """Convert a raw value for ``key`` into the form that is stored.

        Boolean-like values of policy keys are resolved through the key's
        policy table; otherwise the key's ``VALUE_CONVERTERS`` entry is
        applied, and without one the value is returned unchanged.
        """
        if key in type(self).POLICY_TYPES and value in allowed_str_for_bool():
            return type(self).get_param_value_from_policy(value, type(self).POLICY_TYPES[key])
        if key in type(self).VALUE_CONVERTERS:
            return type(self).VALUE_CONVERTERS[key](value)
        return value

    def set(self, input_dict):
        """Validate, convert and store every entry of ``input_dict``.

        Renamed keys (``BACKWARD_COMP_PARAMS``) are mapped to their new name
        with a deprecation warning; deprecated keys are ignored with a
        deprecation warning.

        Raises:
            ToolsParamsError: If a key is unknown, or ``values_are_allowed``
                rejects the value.
        """
        for key, value in input_dict.items():
            if key in self.BACKWARD_COMP_PARAMS:
                # TODO: https://hailotech.atlassian.net/browse/SDK-34328
                self._logger.deprecation_warning(
                    f"'{key}' param was changed to '{self.BACKWARD_COMP_PARAMS[key]}'",
                    DeprecationVersion.FUTURE,
                )
                key = self.BACKWARD_COMP_PARAMS[key]
            if key in self.DEFAULT_PARAMS:
                if not self.values_are_allowed(key, value):
                    raise ToolsParamsError(f"Could not set parameter {key} with value {value}")
                value = self.convert_string_to_value(key, value)
                self._params[key] = value
            elif key in self.DEPRECATED_PARAMS:
                # TODO: https://hailotech.atlassian.net/browse/SDK-34328
                self._logger.deprecation_warning(
                    f"Ignored deprecated param {self.__class__.__name__} {key}",
                    DeprecationVersion.FUTURE,
                )
            else:
                raise ToolsParamsError(f"Could not set unknown {self.__class__.__name__} {key}")

    def set_pb_field(self, pb, field_name, type_to_pb=None):
        """Copy parameter ``field_name`` onto ``pb`` if it is set.

        When ``type_to_pb`` is given (and non-empty) the stored value is
        mapped through it first.
        """
        if field_name in self._params:
            stored = self._params[field_name]
            value = type_to_pb[stored] if type_to_pb else stored
            setattr(pb, field_name, value)

    def set_auto_pb_field(self, pb, pb_wrapper, field_name, type_to_pb=None):
        """Copy an auto-variable parameter (an object with ``val()``/``policy()``) onto ``pb``.

        The policy is mapped through ``pb_wrapper.AUTO_VARIABLE_POLICY_TO_PB``
        and the value optionally through ``type_to_pb``.
        """
        if field_name in self._params:
            auto_var = self._params[field_name]
            value = type_to_pb[auto_var.val()] if type_to_pb else auto_var.val()
            field_policy = auto_var.policy()

            getattr(pb, field_name).policy = pb_wrapper.AUTO_VARIABLE_POLICY_TO_PB[field_policy]
            getattr(pb, field_name).val = value

    def extend_pb_field(self, pb, field_name, type_to_pb=None):
        """Append the list parameter ``field_name`` to the matching field of ``pb`` if it is set."""
        if field_name in self._params:
            stored = self._params[field_name]
            value = [type_to_pb[val] for val in stored] if type_to_pb else stored
            getattr(pb, field_name).extend(value)

    def get_pb_field(self, pb, field_name, pb_to_type=None):
        """Load ``field_name`` from ``pb`` when ``pb.HasField`` reports it, optionally mapped via ``pb_to_type``."""
        if pb.HasField(field_name):
            value = getattr(pb, field_name)
            if pb_to_type:
                value = pb_to_type[value]
            self._params[field_name] = value

    def get_auto_pb_field(self, pb, field_name, field_type):
        """Load an auto-variable field from ``pb`` as ``field_type(<field>.val)``.

        NOTE(review): only ``val`` is read here; the field's ``policy`` is not
        passed to ``field_type`` - confirm this is intended.
        """
        if pb.HasField(field_name):
            auto_var = getattr(pb, field_name)
            self._params[field_name] = field_type(auto_var.val)

    def get_pb_list_field(self, pb, field_name, pb_to_type=None):
        """Load a list field from ``pb``; an empty list leaves the parameter unset."""
        param_val = getattr(pb, field_name)
        if pb_to_type:
            param_val = [pb_to_type[v] for v in param_val]
        if len(param_val) > 0:
            self._params[field_name] = param_val

    def to_pb(self, pb, pb_wrapper):
        """Serialize the parameters into ``pb``; must be implemented by subclasses."""
        raise NotImplementedError

    def from_pb(self, pb, pb_wrapper):
        """Load the parameters from ``pb``; must be implemented by subclasses."""
        raise NotImplementedError


def convert_to_int(val):
    """Convert ``val`` with the built-in ``int`` (floats are truncated toward zero)."""
    as_int = int(val)
    return as_int


def convert_to_float(val):
    """Convert ``val`` with the built-in ``float``."""
    as_float = float(val)
    return as_float


def convert_str_to_bool(val):
    """Return True only for ``"true"``, ``"True"`` or a value equal to ``True``.

    Every other input, including other spellings such as ``"TRUE"``, yields
    False.
    """
    truthy_spellings = ("true", "True", True)
    is_true = val in truthy_spellings
    return is_true


def convert_str_to_enumerate(enum):
    """Build a converter that turns a raw value into a member of ``enum``."""

    def _to_member(val):
        # Lookup by value; ``enum`` raises if no member matches.
        return enum(val)

    return _to_member


def convert_list_str_to_list_enumerate(enum):
    """Build a converter that maps every element of an iterable to a member of ``enum``."""

    def _to_members(vals):
        return list(map(enum, vals))

    return _to_members


def convert_value_str_to_list_enumerate(enum):
    """Build a converter that always yields a list of ``enum`` members.

    A list input is converted element by element; any other input is treated
    as a single value and wrapped in a one-element list.
    """

    def _to_members(vals):
        if not isinstance(vals, list):
            return [enum(vals)]
        return [enum(item) for item in vals]

    return _to_members


def allowed_str_for_bool():
    """Return a fresh list of every accepted spelling of a boolean."""
    true_spellings = ["True", "true"]
    false_spellings = ["False", "false"]
    return [*true_spellings, *false_spellings, True, False]


def allowed_str_for_enumerate(enum):
    """Return the ``value`` of every member of ``enum``, in definition order."""
    values = []
    for member in enum:
        values.append(member.value)
    return values


def convert_time_to_int(val):
    """Convert a timeout specification to an integer number of seconds.

    Args:
        val: Either a non-zero/non-empty ``int``, ``float`` or ``str`` holding
            a number (taken as seconds; floats are truncated), or a
            two-element list ``[amount, unit]`` where ``unit`` is one of
            ``"s"``, ``"m"``, ``"h"``, ``"d"``.

    Returns:
        The timeout in seconds.

    Raises:
        ValueError: If ``val`` has an unsupported shape, type or unit, or the
            numeric part cannot be parsed by ``int``.
    """
    time_multiplier = {
        "s": 1,
        "m": 60,  # minutes
        "h": 60 * 60,  # hours
        "d": 60 * 60 * 24,  # days
    }
    illegal_msg = f'Timeout value: "{val}" is illegal. should be number + (s|m|h|d)'
    # Exact type match (not isinstance) so that bool is not accepted as a number.
    if val and type(val) in (float, int, str):
        return int(val)
    if isinstance(val, list) and len(val) == 2:
        amount, unit = val
        try:
            multiplier = time_multiplier[unit]
        except (KeyError, TypeError):
            # Unknown (or unhashable) unit: report it like any other illegal value.
            raise ValueError(illegal_msg) from None
        return int(amount) * multiplier
    raise ValueError(illegal_msg)


class AutoVariablePolicy(Enum):
    """How the value of an auto variable (e.g. ``AutoInt``) was determined."""

    # Set by AutoInt when it is constructed from the string "automatic".
    AUTOMATIC = "automatic"
    # Set by AutoInt when it is constructed from an explicit number or numeric string.
    MANUAL = "manual"


class AutoInt:
    """Integer parameter that is either given explicitly or marked "automatic".

    Accepted constructor inputs: the string ``"automatic"``, an ``int`` or
    ``float`` (exact types), a numeric string, or another ``AutoInt`` whose
    policy and value are copied.
    """

    def __init__(self, val):
        if val == "automatic":
            self._policy = AutoVariablePolicy.AUTOMATIC
            self._val = 0  # uninitialized
            return
        if type(val) in (int, float):
            self._policy = AutoVariablePolicy.MANUAL
            self._val = convert_to_int(val)
            return
        if isinstance(val, str):
            self._policy = AutoVariablePolicy.MANUAL
            self._val = int(val)
            return
        if isinstance(val, AutoInt):
            # Copy constructor: take over the other instance's state.
            self._policy = val.policy()
            self._val = val.val()
            return
        raise ValueError("Value must be a number or 'automatic'")

    def __str__(self):
        # Only the number is rendered, so an automatic value prints as "0".
        return str(self._val)

    def policy(self):
        """Return the ``AutoVariablePolicy`` of this value."""
        return self._policy

    def val(self):
        """Return the stored integer (0 when the policy is automatic)."""
        return self._val


def convert_to_auto_int(val):
    """Wrap ``val`` in an ``AutoInt``."""
    auto_int = AutoInt(val)
    return auto_int


def convert_time_to_auto_int(val):
    """Convert a timeout specification to seconds and wrap the result in an ``AutoInt``."""
    seconds = convert_time_to_int(val)
    return AutoInt(seconds)
