from abc import ABC, abstractmethod

from hailo_model_optimization.acceleras.utils.acceleras_definitions import LAYERS_KEY, PostQuantizationFeature
from hailo_sdk_client.sdk_backend.script_parser.commands import SupportedCommands
from hailo_sdk_client.sdk_backend.script_parser.model_optimization_commands import FeatureOptimizationCommand


class PostQuantizationOptimizationCommand(FeatureOptimizationCommand, ABC):
    """Base class for ``post_quantization_optimization`` model-script commands.

    Each concrete subclass handles exactly one :class:`PostQuantizationFeature`.
    Call :meth:`from_tokens` to parse a command; it returns an instance of the
    subclass registered for the command's feature.
    """

    def __init__(self, feature, cmd_str, cmd_line, kwargs):
        """Initialize the command under ``SupportedCommands.POST_QUANTIZATION_OPTIMIZATION``.

        Args:
            feature: The PostQuantizationFeature this command configures.
            cmd_str: The command text; forwarded unchanged to FeatureOptimizationCommand.
            cmd_line: Forwarded unchanged to FeatureOptimizationCommand
                (presumably the script line of the command — confirm there).
            kwargs: The command's keyword arguments, as produced by ``_build_kwargs``.
        """
        super().__init__(SupportedCommands.POST_QUANTIZATION_OPTIMIZATION, feature, cmd_str, cmd_line, kwargs)

    @classmethod
    def from_tokens(cls, tokens, cmd_str, cmd_line):
        """Build the feature-specific sub-command for a parsed script command.

        Args:
            tokens: Parsed command tokens; only ``tokens.function_args`` is read.
            cmd_str: The command text, passed through to the sub-command.
            cmd_line: Passed through to the sub-command.

        Returns:
            An instance of the sub-command class registered for the command's feature.

        Raises:
            KeyError: If the feature is a valid PostQuantizationFeature that has no
                sub-command class registered below. An unrecognized feature value is
                rejected earlier by ``PostQuantizationFeature(feature_value)``.
        """
        # _build_kwargs (inherited) splits the feature value from the remaining keyword args.
        feature_value, kwargs = cls._build_kwargs(tokens.function_args)
        feature = PostQuantizationFeature(feature_value)
        # Built at call time because the sub-command classes are defined after this class.
        sub_commands_classes = {
            PostQuantizationFeature.bias_correction: BiasCorrectionSubCommand,
            PostQuantizationFeature.adaround: AdaRoundSubCommand,
            PostQuantizationFeature.block_round_training: BlockRoundTrainingSubCommand,
            PostQuantizationFeature.finetune: FineTuneSubCommand,
            PostQuantizationFeature.mix_precision_search: MixPrecisionSearchSubCommand,
            PostQuantizationFeature.train_encoding: TrainEncodingSubCommand,
        }
        return sub_commands_classes[feature].from_kwargs(kwargs, cmd_str, cmd_line)

    @classmethod
    @abstractmethod
    def from_kwargs(cls, kwargs, cmd_str, cmd_line):
        """Create a sub-command instance from already-parsed keyword arguments.

        Declared as a classmethod because ``from_tokens`` invokes it on the class
        itself and every concrete subclass implements it as a classmethod.

        Args:
            kwargs: The command's keyword arguments (feature value already removed).
            cmd_str: The command text.
            cmd_line: Passed through to the constructor.

        Returns:
            A new instance of ``cls``.
        """


class BiasCorrectionSubCommand(PostQuantizationOptimizationCommand):
    """``post_quantization_optimization`` sub-command for the bias_correction feature.

    Layer-targeted configuration is stored in ``function_args`` under ``LAYERS_KEY``.
    """

    def __init__(self, cmd_str, cmd_line, kwargs):
        super().__init__(
            feature=PostQuantizationFeature.bias_correction,
            cmd_str=cmd_str,
            cmd_line=cmd_line,
            kwargs=kwargs,
        )

    @classmethod
    def from_kwargs(cls, kwargs, cmd_str, cmd_line):
        """Construct the sub-command from parsed keyword arguments."""
        return cls(cmd_str, cmd_line, kwargs)

    def expand_glob(self, layers_scope_from_hn, net_scopes):
        """Replace the ``LAYERS_KEY`` argument with the result of ``_expand_glob_for_key``.

        Does nothing when the command has no ``LAYERS_KEY`` argument.
        """
        script_layers = self.function_args.get(LAYERS_KEY)
        if script_layers is None:
            return
        self.function_args[LAYERS_KEY] = self._expand_glob_for_key(
            script_layers,
            layers_scope_from_hn,
            net_scopes,
        )

    def export(self):
        """Export a per-layer config (``new_format=True``) if ``has_layers()``, else the global config."""
        if self.has_layers():
            return self._export_per_layer_config(new_format=True)
        return self._export_global_config()

    def add_scope(self, scope_names, force=False):
        """Apply ``scope_names`` to the ``LAYERS_KEY`` argument via ``_add_scope_to_key``."""
        self._add_scope_to_key(scope_names, LAYERS_KEY, force=force)


class AdaRoundSubCommand(PostQuantizationOptimizationCommand):
    """``post_quantization_optimization`` sub-command for the adaround feature.

    Layer-targeted configuration is stored in ``function_args`` under ``LAYERS_KEY``.
    """

    def __init__(self, cmd_str, cmd_line, kwargs):
        super().__init__(
            feature=PostQuantizationFeature.adaround,
            cmd_str=cmd_str,
            cmd_line=cmd_line,
            kwargs=kwargs,
        )

    @classmethod
    def from_kwargs(cls, kwargs, cmd_str, cmd_line):
        """Construct the sub-command from parsed keyword arguments."""
        return cls(cmd_str, cmd_line, kwargs)

    def expand_glob(self, layers_scope_from_hn, net_scopes):
        """Replace the ``LAYERS_KEY`` argument with the result of ``_expand_glob_for_key``.

        Does nothing when the command has no ``LAYERS_KEY`` argument.
        """
        script_layers = self.function_args.get(LAYERS_KEY)
        if script_layers is None:
            return
        self.function_args[LAYERS_KEY] = self._expand_glob_for_key(
            script_layers,
            layers_scope_from_hn,
            net_scopes,
        )

    def export(self):
        """Export a per-layer config (``new_format=True``) if ``has_layers()``, else the global config."""
        if self.has_layers():
            return self._export_per_layer_config(new_format=True)
        return self._export_global_config()

    def add_scope(self, scope_names, force=False):
        """Apply ``scope_names`` to the ``LAYERS_KEY`` argument via ``_add_scope_to_key``."""
        self._add_scope_to_key(scope_names, LAYERS_KEY, force=force)


class MixPrecisionSearchSubCommand(PostQuantizationOptimizationCommand):
    """``post_quantization_optimization`` sub-command for the mix_precision_search feature.

    Layer-targeted configuration is stored in ``function_args`` under ``LAYERS_KEY``.
    """

    def __init__(self, cmd_str, cmd_line, kwargs):
        super().__init__(
            feature=PostQuantizationFeature.mix_precision_search,
            cmd_str=cmd_str,
            cmd_line=cmd_line,
            kwargs=kwargs,
        )

    @classmethod
    def from_kwargs(cls, kwargs, cmd_str, cmd_line):
        """Construct the sub-command from parsed keyword arguments."""
        return cls(cmd_str, cmd_line, kwargs)

    def expand_glob(self, layers_scope_from_hn, net_scopes):
        """Replace the ``LAYERS_KEY`` argument with the result of ``_expand_glob_for_key``.

        Does nothing when the command has no ``LAYERS_KEY`` argument.
        """
        script_layers = self.function_args.get(LAYERS_KEY)
        if script_layers is None:
            return
        self.function_args[LAYERS_KEY] = self._expand_glob_for_key(
            script_layers,
            layers_scope_from_hn,
            net_scopes,
        )

    def export(self):
        """Export a per-layer config (default format) if ``has_layers()``, else the global config."""
        # Unlike bias_correction/adaround, this feature uses the default per-layer export format.
        if self.has_layers():
            return self._export_per_layer_config()
        return self._export_global_config()

    def add_scope(self, scope_names, force=False):
        """Apply ``scope_names`` to the ``LAYERS_KEY`` argument via ``_add_scope_to_key``."""
        self._add_scope_to_key(scope_names, LAYERS_KEY, force=force)


class BlockRoundTrainingSubCommand(PostQuantizationOptimizationCommand):
    """``post_quantization_optimization`` sub-command for the block_round_training feature.

    Layer-targeted configuration is stored in ``function_args`` under ``LAYERS_KEY``.
    """

    def __init__(self, cmd_str, cmd_line, kwargs):
        super().__init__(
            feature=PostQuantizationFeature.block_round_training,
            cmd_str=cmd_str,
            cmd_line=cmd_line,
            kwargs=kwargs,
        )

    @classmethod
    def from_kwargs(cls, kwargs, cmd_str, cmd_line):
        """Construct the sub-command from parsed keyword arguments."""
        return cls(cmd_str, cmd_line, kwargs)

    def expand_glob(self, layers_scope_from_hn, net_scopes):
        """Replace the ``LAYERS_KEY`` argument with the result of ``_expand_glob_for_key``.

        Does nothing when the command has no ``LAYERS_KEY`` argument.
        """
        script_layers = self.function_args.get(LAYERS_KEY)
        if script_layers is None:
            return
        self.function_args[LAYERS_KEY] = self._expand_glob_for_key(
            script_layers,
            layers_scope_from_hn,
            net_scopes,
        )

    def export(self):
        """Export a per-layer config (default format) if ``has_layers()``, else the global config."""
        # Unlike bias_correction/adaround, this feature uses the default per-layer export format.
        if self.has_layers():
            return self._export_per_layer_config()
        return self._export_global_config()

    def add_scope(self, scope_names, force=False):
        """Apply ``scope_names`` to the ``LAYERS_KEY`` argument via ``_add_scope_to_key``."""
        self._add_scope_to_key(scope_names, LAYERS_KEY, force=force)


class FineTuneSubCommand(PostQuantizationOptimizationCommand):
    """``post_quantization_optimization`` sub-command for the finetune feature.

    Unlike the per-layer sub-commands, the layer arguments here are stored under
    the ``layers_to_freeze``, ``loss_layer_names`` and ``native_layers`` keys, and
    the command always exports a global config.
    """

    def __init__(self, cmd_str, cmd_line, kwargs):
        super().__init__(
            feature=PostQuantizationFeature.finetune,
            cmd_str=cmd_str,
            cmd_line=cmd_line,
            kwargs=kwargs,
        )

    @classmethod
    def from_kwargs(cls, kwargs, cmd_str, cmd_line):
        """Construct the sub-command from parsed keyword arguments."""
        return cls(cmd_str, cmd_line, kwargs)

    def expand_glob(self, layers_scope_from_hn, net_scopes):
        """Expand globs in each layer-list argument and store the result as a plain list.

        For each of ``layers_to_freeze``, ``loss_layer_names`` and ``native_layers``
        (in that order) that is present, the value is passed to
        ``_expand_glob_for_key`` and replaced by the list of the result's keys.
        Missing arguments are left untouched.
        """
        for arg_name in ("layers_to_freeze", "loss_layer_names", "native_layers"):
            script_value = self.function_args.get(arg_name)
            if script_value is None:
                continue
            expanded = self._expand_glob_for_key(script_value, layers_scope_from_hn, net_scopes)
            self.function_args[arg_name] = list(expanded.keys())

    def export(self):
        """Export the global config; finetune has no per-layer export."""
        return self._export_global_config()

    def add_scope(self, scope_names, force=False):
        """Apply ``scope_names`` to every layer-list argument via ``_add_scope_to_key``."""
        for arg_name in ("layers_to_freeze", "loss_layer_names", "native_layers"):
            self._add_scope_to_key(scope_names, arg_name, force=force)


class TrainEncodingSubCommand(FineTuneSubCommand):
    """``post_quantization_optimization`` sub-command for the train_encoding feature.

    Reuses FineTuneSubCommand's argument handling (glob expansion, scoping,
    global export, ``from_kwargs``) but registers a different feature.
    """

    def __init__(self, cmd_str, cmd_line, kwargs):
        # Intentionally skip FineTuneSubCommand.__init__, which hard-codes the
        # finetune feature, and initialize the next class in the MRO directly.
        parent_after_finetune = super(FineTuneSubCommand, self)
        parent_after_finetune.__init__(
            feature=PostQuantizationFeature.train_encoding,
            cmd_str=cmd_str,
            cmd_line=cmd_line,
            kwargs=kwargs,
        )
