from copy import deepcopy
from enum import Enum
from typing import List, Tuple

from pydantic.v1 import BaseModel
from pydantic.v1.utils import deep_update

from hailo_model_optimization.acceleras.model_optimization_config.mo_config import (
    ModelOptimizationConfig,
    ModelOptimizationFlavor,
    ModelOptimizationLayerConfig,
)
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_base import CommandMeta
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerConfigBaseModel,
    LayerPrecisionConfig,
    LayerTranslationConfig,
)
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    RECOMMENDED_DATASET_SIZE,
    GPUAvailabilityMode,
    ModelOptimizationCommand,
    OptimizationTarget,
)
from hailo_model_optimization.acceleras.utils.dataset_util import DatasetContianer
from hailo_model_optimization.acceleras.utils.logger import default_logger
from hailo_model_optimization.acceleras.utils.tf_utils import get_gpu_availability_mode, get_tf_dataset_length
from hailo_model_optimization.tools.simple_alls_parser import CommandInfo, parse_model_script


class LayersFieldType(Enum):
    """
    Where the 'layers' keyword of a feature command is placed in the mo_config.
    """

    # feature -> layers -> config: the per-layer config is nested under the feature
    NESTED_LAYERS_CONFIG = "nested"
    # 'layers' is a plain (flat) value of the feature config, e.g. a list of layer names
    FLAT_LAYERS_CONFIG = "flat"
    # layers -> feature -> config: the per-layer config is under the main context of the mo_config
    MAIN_LAYERS_CONFIG = "main"


# Using pydantic and not dataclass for serialization
class OptimizationFlavorsInfo(BaseModel):
    """
    Optimization and compression levels resolved from the model_optimization_flavor command.
    """

    optimization_level: int
    compression_level: int
    # additional debug info that will be added separately
    # has_gpu: Optional[bool] = None
    # parameters_count: Optional[int] = None
    # dataset_length: Optional[int] = None


class MOScriptParser:
    """
    This class parses a model script for model optimization.
    It ignores unrelated commands (e.g., compiler command, allocation commands).

    Args:
        model_script: The model script as a pythonic string
        hn_defaults: ModelOptimizationConfig as dict with default values for layers.
            Currently it mostly includes values for translation config and precision config
            It hasn't been removed yet because many compiler tests specify the precision
            config through the hn instead of a model script. I hope it will be removed in the near future.
            # TODO: remove this field once the tests are fixed
        data_continer: dataset container as received by the optimize method. Its ``data`` and
            ``data_type`` are used to compute the dataset length, which determines the default
            optimization flavor
        model_size: parameters count of the model. Used to determine the default compression flavor.
            When None, a fixed estimation (30e6) is used
        logger: logger. When None, the default logger is used
        optimization_target: value convertible to OptimizationTarget

    Attributes:
        results: OptimizationFlavorsInfo with the resolved flavor levels; set by ``run``.

    """

    def __init__(
        self,
        model_script: str,
        hn_defaults,
        data_continer: DatasetContianer,
        model_size,
        logger,
        optimization_target,
    ) -> None:
        self._raw_script = model_script
        self._data_continer = data_continer
        self._logger = logger if logger is not None else default_logger()
        self._hn_defaults = hn_defaults
        self._model_size = model_size
        self._mo_config: ModelOptimizationConfig = None
        # Populated by run() (via _parse_flavor_to_default_cfg) with an OptimizationFlavorsInfo
        self.results = None
        self._target = OptimizationTarget(optimization_target)

    def run(self) -> ModelOptimizationConfig:
        """
        Parse the model script and build the model optimization config.

        The dict config is merged with deep_update in this order: flavor defaults
        (plus a default precision target), hn defaults, then each mo command in script order.

        Returns:
            The resulting ModelOptimizationConfig.
        """
        flavors_commands, mo_commands = self._parse_script_to_commands()
        default_cfg = self._parse_flavor_to_default_cfg(flavors_commands)
        default_cfg.setdefault("precision_config", dict())
        default_cfg["precision_config"].setdefault("target", self._target)

        if self._hn_defaults:
            self._logger.info("Model received quantization params from the hn")
            # Merge only when there is something to merge: deep_update iterates the
            # mapping's items, so a None hn_defaults would crash, and an empty one is a no-op.
            default_cfg = deep_update(default_cfg, self._hn_defaults)

        dict_config = self._parse_commands_to_dict(mo_commands, default_cfg)
        return ModelOptimizationConfig(**dict_config)

    def _parse_script_to_commands(self):
        """
        Parse the script to a list of CommandInfo objects.
        Filters the mo commands and flavors commands separately.

        Returns:
            Tuple of (flavors_commands, mo_commands).
        """
        parsed_commands = parse_model_script(self._raw_script)
        parsed_commands = self._force_normalization_input_range(parsed_commands)
        flavors_commands, mo_commands = self._filter_mo_commands(parsed_commands)
        return flavors_commands, mo_commands

    def _parse_flavor_to_default_cfg(self, flavors_commands):
        """
        Parse flavors commands to get the default configuration of the model.

        Also sets ``self.results`` to the resolved OptimizationFlavorsInfo.

        Raises:
            ValueError: if a flavor command has positional arguments (only kwargs are accepted).
        """
        gpu_info = get_gpu_availability_mode()
        data_length = get_tf_dataset_length(
            self._data_continer.data,
            self._data_continer.data_type,
            RECOMMENDED_DATASET_SIZE,
            gpu_info.gpu_availability,
        )
        has_gpu = gpu_info.gpu_availability != GPUAvailabilityMode.NOT_AVAILABLE
        kwargs = {}
        for command in flavors_commands:
            if command.args:
                raise ValueError(
                    f"Command {command.command} accepts only keyword arguments, received positional args: "
                    f"{command.args}"
                )
            # Later flavor commands override keys set by earlier ones
            kwargs.update(command.kwargs)
        mo_flavor = ModelOptimizationFlavor(**kwargs)
        # TODO: get a parameters count estimation
        model_size = 30e6 if self._model_size is None else self._model_size

        default_cfg = mo_flavor.get_flavor_config(has_gpu, data_length, model_size, self._target, self._logger)
        # NOTE: OptimizationFlavorsInfo does not declare a dataset_length field; pydantic v1
        # ignores unknown kwargs by default, so that value is currently dropped.
        self.results = OptimizationFlavorsInfo(
            optimization_level=mo_flavor.optimization_level,
            compression_level=mo_flavor.compression_level,
            # has_gpu=has_gpu,
            # parameters_count=self._model_size,
            dataset_length=data_length,
        )
        return default_cfg

    def _parse_commands_to_dict(self, mo_commands, default_cfg):
        """
        Parse the CommandInfo objects of the mo_commands into a formatted dict for the mo_config.

        Args:
            mo_commands: mo CommandInfo objects, applied in order (later commands override earlier ones).
            default_cfg: base dict config; it is deep-copied, not mutated. None means an empty config.
        """
        config = deepcopy(default_cfg) if default_cfg is not None else dict()
        command_handlers = {
            ModelOptimizationCommand.model_optimization_config: self._handle_feature_based_command,
            ModelOptimizationCommand.pre_quantization_optimization: self._handle_feature_based_command,
            ModelOptimizationCommand.post_quantization_optimization: self._handle_feature_based_command,
            ModelOptimizationCommand.quantization_param: self._handle_quantization_param_command,
            ModelOptimizationCommand.compression_params: self._handle_compression_param_command,
        }
        for command_info in mo_commands:
            command = ModelOptimizationCommand(command_info.command)
            command_handler = command_handlers[command]
            new_config = command_handler(command_info)
            config = deep_update(config, new_config)
        return config

    def _force_normalization_input_range(self, parsed_commands):
        """
        If we have a normalization command, we want to force the input range of the quantization.
        This is done by adding a quantization_param command with the input_normalization flag,
        right after each normalization command.
        """
        parsed_commands_extended = []
        for command_info in parsed_commands:
            parsed_commands_extended.append(command_info)
            if isinstance(command_info, CommandInfo) and command_info.command == "normalization":
                # A third positional arg, when given, is used as the target layer(s);
                # otherwise the "{*}" pattern is used.
                inputs_layers = command_info.args[2] if len(command_info.args) == 3 else "{*}"
                parsed_commands_extended.append(
                    CommandInfo(
                        loc=command_info.loc,
                        length=command_info.length,
                        command=ModelOptimizationCommand.quantization_param.value,
                        args=[inputs_layers],
                        kwargs={"input_normalization": "enabled"},
                    )
                )
        return parsed_commands_extended

    def _filter_mo_commands(self, parsed_commands) -> Tuple[List[CommandInfo], List[CommandInfo]]:
        """
        Filter mo_commands and flavor commands.
        Entries that are not CommandInfo, or whose command is not a ModelOptimizationCommand, are dropped.
        Missing kwargs are normalized to an empty dict (in place).

        Returns:
            Tuple[List[flavor_command], List[mo_command]]
        """
        flavors_commands = []
        mo_commands = []
        for command_info in parsed_commands:
            if not isinstance(command_info, CommandInfo) or not self.is_in_enum(
                command_info.command,
                ModelOptimizationCommand,
            ):
                continue
            command = ModelOptimizationCommand(command_info.command)
            if not command_info.kwargs:
                command_info.kwargs = dict()
            if command == ModelOptimizationCommand.model_optimization_flavor:
                flavors_commands.append(command_info)
            else:
                mo_commands.append(command_info)
        return flavors_commands, mo_commands

    def _handle_feature_based_command(self, command: CommandInfo):
        """
        Convert a feature command into a formatted dict for mo_config.

        The single positional arg is the feature name. Where the 'layers' kwarg goes depends on
        the LayersFieldType of that feature (see _get_layers_field_type).
        """
        self._verify_command(command, args_count=1)
        feature = command.args[0]
        cfg = deepcopy(command.kwargs)
        meta_info = self._get_command_meta(command)

        if "layers" in cfg:
            layers_field_type = self._get_layers_field_type(command)
        else:
            layers_field_type = None

        if layers_field_type is None or layers_field_type == LayersFieldType.FLAT_LAYERS_CONFIG:
            # If layers is a primitive value in the config
            self._assign_meta_info(cfg, meta_info)
            cfg = self._handle_custom_layer_values(feature, cfg)
            formatted_cfg = {feature: cfg}
        else:
            layers = cfg.pop("layers")
            if not isinstance(layers, list):
                layers = [layers]
            self._assign_meta_info(cfg, meta_info)
            per_layer_cfg = dict()
            for layer in layers:
                per_layer_cfg[layer] = cfg
            if layers_field_type == LayersFieldType.MAIN_LAYERS_CONFIG:
                # if the config should be in the main context of the mo_config
                conf = {layer: {feature: layer_cfg} for layer, layer_cfg in per_layer_cfg.items()}
                formatted_cfg = {"layers": conf}
            elif layers_field_type == LayersFieldType.NESTED_LAYERS_CONFIG:
                # if the config should be in the feature context of the mo_config
                formatted_cfg = {feature: {"layers": per_layer_cfg}}
            else:
                raise ValueError(f"Unexpected value for layer field type {layers_field_type}")
        return formatted_cfg

    def _handle_custom_layer_values(self, feature, cfg):
        """
        Wrap scalar values of the feature's flat layers fields into single-element lists (in place).
        """
        config_class = ModelOptimizationConfig.__fields__[feature].type_
        keys = config_class.flat_layers_fields().keys()
        for key in keys:
            value = cfg.get(key)
            if value is not None and not isinstance(value, list):
                cfg[key] = [value]
        return cfg

    def _handle_quantization_param_command(self, command: CommandInfo):
        """
        Convert a quantization_param command into a formatted dict for mo_config.

        Each kwarg is routed to precision_config or translation_config (per layer) according to
        the keys of LayerPrecisionConfig / LayerTranslationConfig.

        Raises:
            ValueError: if a kwarg belongs to neither config.
        """
        self._verify_command(command, args_count=1)
        layers = command.args[0]
        if not isinstance(layers, list):
            layers = [layers]
        cfg = command.kwargs
        meta_info = self._get_command_meta(command)
        precision_formatted_cfg = dict()
        translation_formatted_cfg = dict()
        for key, value in cfg.items():
            if key in LayerPrecisionConfig.keys():
                prec_cfg = {key: value, "meta": {key: meta_info}}
                precision_formatted_cfg = deep_update(precision_formatted_cfg, prec_cfg)
            elif key in LayerTranslationConfig.keys():
                translation_cfg = {key: value, "meta": {key: meta_info}}
                translation_formatted_cfg = deep_update(translation_formatted_cfg, translation_cfg)
            else:
                raise ValueError(f'Unexpected key - "{key}" in quantization_param command')
        formatted_cfg = dict()
        for layer in layers:
            item_cfg = {
                "precision_config": {"layers": {layer: precision_formatted_cfg}},
                "translation_config": {"layers": {layer: translation_formatted_cfg}},
            }
            formatted_cfg = deep_update(formatted_cfg, item_cfg)
        return formatted_cfg

    def _handle_compression_param_command(self, command: CommandInfo):
        """
        Convert a compression_params command (kwargs only) into {command_name: kwargs + meta}.
        """
        # TODO: add deprecation warning
        self._verify_command(command, args_count=0)
        feature = command.command
        cfg = deepcopy(command.kwargs)
        meta_info = self._get_command_meta(command)
        self._assign_meta_info(cfg, meta_info)
        formatted_cfg = {feature: cfg}
        return formatted_cfg

    def _verify_command(self, command: CommandInfo, args_count):
        """
        Verify misc fields of command.

        Raises:
            ValueError: if the number of positional args differs from ``args_count``, or the
                command has a return value or a command object (context).
        """
        if len(command.args) != args_count:
            raise ValueError(
                f"Command {command.command} recieved the following args: {command.args}. Expected {args_count}"
            )
        if command.return_val:
            raise ValueError(f"Command {command.command} did not expected return val, recieved {command.return_val}")
        if command.command_object:
            raise ValueError(f"Command {command.command} did not expected context, recieved {command.command_object}")

    def _get_layers_field_type(self, command_info: CommandInfo) -> LayersFieldType:
        """
        Get the type of a layers field in a specific command.
        The 'layers' keyword can be represented in 3 different ways in the mo_config:
        - Flat representation, simple list of layers
        - Main -> layers -> config, when the config is under the main context of the config (will be removed in the future)
        - Nested, feature -> layers -> config when the config is under the relevant feature

        Raises:
            ValueError: if the feature is unknown, or its config has no usable 'layers' field.
        """
        feature = command_info.args[0]
        if feature in ModelOptimizationLayerConfig.__fields__:
            return LayersFieldType.MAIN_LAYERS_CONFIG
        if feature not in ModelOptimizationConfig.__fields__:
            raise ValueError(f"{feature} not in mo_config")

        feature_config_type = ModelOptimizationConfig.__fields__[feature].type_
        layers_field = feature_config_type.__fields__.get("layers")
        # A missing field previously surfaced as AttributeError on `None.type_`.
        if layers_field is None:
            raise ValueError(f"layers key was not found in mo_config for {feature}")
        layers_type = layers_field.type_
        if layers_type is str:
            return LayersFieldType.FLAT_LAYERS_CONFIG
        # issubclass raises TypeError for non-class types (e.g. typing constructs); guard it.
        if isinstance(layers_type, type) and issubclass(layers_type, LayerConfigBaseModel):
            return LayersFieldType.NESTED_LAYERS_CONFIG
        raise ValueError(f"layers key was not found in mo_config for {feature}")

    def _get_command_meta(self, command_info: CommandInfo):
        """
        Build a CommandMeta (1-based line number, raw command text) for a command.
        Returns None when there is no raw script.
        """
        if not self._raw_script:
            return None
        # TODO: alternatively, reconstruct from fields?
        command_string = self._raw_script[command_info.loc : command_info.loc + command_info.length]
        command_line = self._raw_script[: command_info.loc].count("\n") + 1
        return CommandMeta(command_line, command_string, False)

    def _assign_meta_info(self, cfg, meta_info):
        """
        Attach ``meta_info`` to every non-meta key of ``cfg`` under cfg["meta"] (in place).
        No-op when meta_info is None.
        """
        if meta_info is None:
            return
        cfg.setdefault("meta", dict())
        for key in cfg:
            if key == "meta":
                continue
            cfg["meta"][key] = meta_info

    @staticmethod
    def is_in_enum(value, enum_class):
        """Return True if ``value`` is the value of one of ``enum_class``'s members."""
        return value in enum_class._value2member_map_
