#!/usr/bin/env python

from enum import Enum

from hailo_sdk_common.hailo_nn.tools_params import (
    AutoInt,
    ToolsParams,
    allowed_str_for_enumerate,
    convert_str_to_bool,
    convert_str_to_enumerate,
    convert_to_auto_int,
    convert_to_float,
)
from hailo_sdk_common.logger.logger import DeprecationVersion


class ContextSwitchMode(Enum):
    """Values accepted by the ``"mode"`` parameter of ``ContextSwitchParams``.

    The default is ``ALLOWED``. ``ContextSwitchParams.set`` warns that
    ``MANUAL`` will be deprecated. ``AUTOMATIC`` is deprecated and is
    remapped to ``ENABLED`` by ``ContextSwitchParams.set``.
    """

    MANUAL = "manual"  # slated for future deprecation (warned on use)
    AUTOMATIC = "automatic"  # deprecated alias; remapped to "enabled"
    ALLOWED = "allowed"
    ENABLED = "enabled"
    DISABLED = "disabled"


class Partitioner(Enum):
    """Values accepted by the ``"partitioner"`` parameter of ``ContextSwitchParams``.

    The default is ``SLOTTER``.
    """

    SLOTH = "sloth"
    SLOTTER = "slotter"


class ToposortMode(Enum):
    """Values accepted by the ``"toposort_mode"`` parameter of ``ContextSwitchParams``.

    The default is ``DEPTHWISE``.
    """

    DFS = "dfs"  # presumably depth-first-search ordering — confirm
    DEPTHWISE = "depthwise"


class ContextSwitchParams(ToolsParams):
    """Context-switch configuration parameters.

    Defaults, accepted string values and per-parameter value converters are
    declared in the class-level tables below. They are presumably consumed by
    ``ToolsParams`` when parameters are set and validated (confirm against
    ``ToolsParams``). ``set`` additionally handles deprecated options before
    delegating to ``ToolsParams.set``.
    """

    DEFAULT_PARAMS = {
        "mode": ContextSwitchMode.ALLOWED,
        "partitioner": Partitioner.SLOTTER,
        "max_utilization": 1,
        "max_control_utilization": 0.3,
        "max_compute_utilization": 0.3,
        "max_compute_16bit_utilization": 0.3,
        "max_memory_utilization": 0.3,
        "max_input_aligner_utilization": 0.7,
        "max_apu_utilization": 0.7,
        "goal_network_control_utilization": 1,
        "allow_auto_merge_in_multicontext": False,
        "slotter_chances": AutoInt("automatic"),
        "toposort_mode": ToposortMode.DEPTHWISE,
    }
    ALLOWED_VALUES = {
        "mode": allowed_str_for_enumerate(ContextSwitchMode),
        "partitioner": allowed_str_for_enumerate(Partitioner),
        "toposort_mode": allowed_str_for_enumerate(ToposortMode),
    }
    VALUE_CONVERTERS = {
        "mode": convert_str_to_enumerate(ContextSwitchMode),
        "partitioner": convert_str_to_enumerate(Partitioner),
        "max_utilization": convert_to_float,
        "max_control_utilization": convert_to_float,
        "max_compute_utilization": convert_to_float,
        "max_compute_16bit_utilization": convert_to_float,
        "max_memory_utilization": convert_to_float,
        "max_input_aligner_utilization": convert_to_float,
        "max_apu_utilization": convert_to_float,
        "goal_network_control_utilization": convert_to_float,
        "allow_auto_merge_in_multicontext": convert_str_to_bool,
        "slotter_chances": convert_to_auto_int,
        "toposort_mode": convert_str_to_enumerate(ToposortMode),
    }
    # Deprecated "mode" value -> replacement value (rewritten by ``set``).
    DEPRECATED_MODES = {"automatic": "enabled"}
    # Parameters that are no longer supported; ``set`` drops them with a warning.
    DEPRECATED_UTIL_PARAMS = [
        "goal_network_compute_utilization",
        "goal_network_memory_utilization",
        "goal_network_weights_utilization",
        "goal_network_input_aligner_utilization",
        "goal_network_apu_utilization",
    ]
    # Parameters serialized to/from protobuf without a value mapping, listed in
    # the order they are written/read (shared by to_pb and from_pb so the two
    # cannot drift apart).
    _PLAIN_PB_FIELDS = (
        "max_utilization",
        "max_control_utilization",
        "max_compute_utilization",
        "max_compute_16bit_utilization",
        "max_memory_utilization",
        "max_input_aligner_utilization",
        "max_apu_utilization",
        "goal_network_control_utilization",
        "allow_auto_merge_in_multicontext",
    )

    def __init__(self):
        super().__init__()
        self.set_default_params()

    def set(self, input_dict):
        """Apply user-supplied parameters after handling deprecated options.

        The caller's mapping is left untouched. A shallow copy has its
        deprecated ``mode`` remapped and its deprecated utilization keys
        removed, with a deprecation warning for each. The copy is then passed
        to ``ToolsParams.set``.

        Args:
            input_dict: Mapping of parameter name to raw value.
        """
        # Work on a copy: the deprecation handling below rewrites and deletes
        # keys, which must not leak back into the caller's dict.
        input_dict = dict(input_dict)

        if "mode" in input_dict:
            self._apply_deprecated_mode(input_dict)

        for param in self.DEPRECATED_UTIL_PARAMS:
            if param in input_dict:
                del input_dict[param]
                # TODO: https://hailotech.atlassian.net/browse/SDK-40692
                self._logger.deprecation_warning(f"{param} is deprecated, reducing it.", DeprecationVersion.FUTURE)

        super().set(input_dict)

    def _apply_deprecated_mode(self, input_dict):
        """Warn about deprecated ``"mode"`` values and remap them in place.

        ``input_dict`` must contain ``"mode"``. It is modified in place when
        the mode is listed in ``DEPRECATED_MODES``.
        """
        raw_mode = input_dict["mode"]
        # Normalize once through the registered "mode" converter so both checks
        # below see the canonical enum member, whatever form the caller used.
        mode = self.VALUE_CONVERTERS["mode"](raw_mode)

        # TODO: https://hailotech.atlassian.net/browse/SDK-34319
        if mode == ContextSwitchMode.MANUAL:
            self._logger.deprecation_warning(
                f"'{raw_mode}' context switch mode will be deprecated. "
                "Please use 'allowed', 'enabled' or 'disabled' instead.",
                DeprecationVersion.FUTURE,
            )

        # Compare on the normalized string value so an enum member (e.g.
        # ContextSwitchMode.AUTOMATIC) is remapped just like "automatic".
        mode_key = mode.value if isinstance(mode, ContextSwitchMode) else raw_mode
        if mode_key in self.DEPRECATED_MODES:
            replacement = self.DEPRECATED_MODES[mode_key]
            # TODO: https://hailotech.atlassian.net/browse/SDK-34319
            self._logger.deprecation_warning(
                f"'{mode_key}' context switch mode is deprecated, changing it to '{replacement}'",
                DeprecationVersion.FUTURE,
            )
            input_dict["mode"] = replacement

    def to_pb(self, pb, pb_wrapper):
        """Write every parameter into the protobuf message ``pb``.

        ``pb_wrapper`` supplies the enum-to-protobuf mappings for ``mode``,
        ``partitioner`` and ``toposort_mode``, and is passed through for
        ``slotter_chances``.
        """
        self.set_pb_field(pb, "mode", pb_wrapper.CONTEXT_SWITCH_MODE_TYPE_TO_PB)
        self.set_pb_field(pb, "partitioner", pb_wrapper.CONTEXT_SWITCH_PARTITIONER_TYPE_TO_PB)
        for name in self._PLAIN_PB_FIELDS:
            self.set_pb_field(pb, name)
        self.set_auto_pb_field(pb, pb_wrapper, "slotter_chances")
        self.set_pb_field(pb, "toposort_mode", pb_wrapper.TOPOSORT_MODE_TYPE_TO_PB)

    def from_pb(self, pb, pb_wrapper):
        """Read every parameter back from the protobuf message ``pb``.

        Mirrors ``to_pb``. ``pb_wrapper`` supplies the protobuf-to-enum
        mappings, and ``slotter_chances`` is read as an ``AutoInt``.
        """
        self.get_pb_field(pb, "mode", pb_wrapper.CONTEXT_SWITCH_MODE_PB_TO_TYPE)
        self.get_pb_field(pb, "partitioner", pb_wrapper.CONTEXT_SWITCH_PARTITIONER_PB_TO_TYPE)
        for name in self._PLAIN_PB_FIELDS:
            self.get_pb_field(pb, name)
        self.get_auto_pb_field(pb, "slotter_chances", AutoInt)
        self.get_pb_field(pb, "toposort_mode", pb_wrapper.TOPOSORT_MODE_PB_TO_TYPE)
