from abc import ABC, abstractmethod

from hailo_model_optimization.acceleras.utils.acceleras_definitions import LAYERS_KEY, PreQuantizationFeature
from hailo_sdk_client.sdk_backend.script_parser.commands import SupportedCommands
from hailo_sdk_client.sdk_backend.script_parser.model_optimization_commands import FeatureOptimizationCommand
from hailo_sdk_client.sdk_backend.sdk_backend_exceptions import QuantizationScriptParserException


class PreQuantizationOptimizationCommand(FeatureOptimizationCommand, ABC):
    """Base class for ``pre_quantization_optimization`` model-script commands.

    ``from_tokens`` extracts the feature name and keyword arguments from the parsed
    tokens, then dispatches to the concrete sub-command class registered for that
    ``PreQuantizationFeature``. Each sub-command builds itself via ``from_kwargs``.
    """

    # Per-class flag overridden by sub-commands; its meaning is defined by the
    # consumer (presumably "apply after the fuser" — NOTE(review): confirm with caller).
    POST_FUSER_COMMAND = False

    def __init__(self, feature, cmd_str, cmd_line, kwargs):
        super().__init__(SupportedCommands.PRE_QUANTIZATION_OPTIMIZATION, feature, cmd_str, cmd_line, kwargs)

    @classmethod
    def from_tokens(cls, tokens, cmd_str, cmd_line):
        """Create the concrete sub-command described by ``tokens``.

        Args:
            tokens: Parsed script tokens; ``tokens.function_args`` is passed to ``_build_kwargs``.
            cmd_str: Original command text.
            cmd_line: Line number of the command in the script.

        Returns:
            An instance of the sub-command class registered for the parsed feature.

        Raises:
            QuantizationScriptParserException: If the feature has no registered sub-command.
        """
        feature_value, kwargs = cls._build_kwargs(tokens.function_args)
        feature = PreQuantizationFeature(feature_value)
        # Built per call because the sub-command classes are defined later in this module.
        sub_commands_classes = {
            PreQuantizationFeature.equalization: EqualizationSubCommand,
            PreQuantizationFeature.activation_clipping: ActivationClippingSubCommand,
            PreQuantizationFeature.weights_clipping: WeightsClippingSubCommand,
            PreQuantizationFeature.se_optimization: SEOptimizationSubCommand,
            PreQuantizationFeature.dead_channels_removal: DeadChannelsRemovalSubCommand,
            PreQuantizationFeature.dead_layers_removal: DeadLayersRemovalSubCommand,
            PreQuantizationFeature.zero_static_channels: ZeroStaticChannelsSubCommand,
            PreQuantizationFeature.ew_add_fusing: EWAddFusingSubCommand,
            PreQuantizationFeature.layer_decomposition: LayerDecompositionCommand,
            PreQuantizationFeature.matmul_decomposition: MatmulDecompositionCommand,
            PreQuantizationFeature.smart_softmax_stats: SmartSoftmaxStatsSubCommand,
            PreQuantizationFeature.defuse: DefuseSubCommand,
            PreQuantizationFeature.resolution_reduction: ResolutionReductionSubCommand,
            PreQuantizationFeature.global_avgpool_reduction: GlobalAvgpoolReductionSubCommand,
            PreQuantizationFeature.add_shortcut_layer: AddShortCutLayerSubCommand,
            PreQuantizationFeature.layer_norm_decomposition: LayerNormDecompositionSubCommand,
            PreQuantizationFeature.matmul_correction: MatmulCorrectionSubCommand,
            PreQuantizationFeature.matmul_equalization: MatmulEqualizationSubCommand,
            PreQuantizationFeature.split_ew_mult_by_bit_significance: SplitEWMultByBitSignificanceSubCommand,
            PreQuantizationFeature.split_fused_activation: SplitFusedActivationCommand,
            PreQuantizationFeature.use_prequantized_weights: UsePrequantizedWeightsSubCommand,
            PreQuantizationFeature.switch_concat_with_add: SwitchConcatWithAddCommand,
            PreQuantizationFeature.quarot: QuaRotSubCommand,
            PreQuantizationFeature.conv_a16_w4: ConvA16W4SubCommand,
            PreQuantizationFeature.conv_decomposition: ConvDecompositionSubCommand,
        }
        sub_command_cls = sub_commands_classes.get(feature)
        if sub_command_cls is None:
            # A valid enum member without a registered handler would otherwise surface as a bare KeyError.
            raise QuantizationScriptParserException(
                f"Unsupported pre_quantization_optimization feature '{feature_value}'"
            )
        return sub_command_cls.from_kwargs(kwargs, cmd_str, cmd_line)

    @classmethod
    @abstractmethod
    def from_kwargs(cls, kwargs, cmd_str, cmd_line):
        """Build a sub-command instance from parsed keyword arguments.

        Declared as a classmethod because ``from_tokens`` invokes it on the class,
        and every concrete sub-command implements it that way.
        """

    @classmethod
    def get_defuse_type(cls, tokens):
        """Return ``"<value>_defuse"`` built from the last function argument.

        If the last argument is a dict, its first value is used. The value is lower-cased.
        """
        value = tokens.function_args[-1]
        if isinstance(value, dict):
            value = next(iter(value.values()))
        return f"{value.lower()}_defuse"


class EqualizationSubCommand(PreQuantizationOptimizationCommand):
    """``equalization`` sub-command; the layers field is optional.

    Exports a per-layer config when ``has_layers()`` is true, otherwise a global config.
    """

    def __init__(self, cmd_str, cmd_line, kwargs):
        super().__init__(PreQuantizationFeature.equalization, cmd_str, cmd_line, kwargs)

    @classmethod
    def from_kwargs(cls, kwargs, cmd_str, cmd_line):
        """Construct directly from the parsed keyword arguments."""
        return cls(cmd_str, cmd_line, kwargs)

    def add_scope(self, scope_names, force=False):
        """Apply ``scope_names`` to the layers field."""
        self._add_scope_to_key(scope_names, LAYERS_KEY, force=force)

    def expand_glob(self, layers_scope_from_hn, net_scopes):
        """Expand the layers field via ``_expand_glob_for_key`` when it is present."""
        pattern = self.function_args.get(LAYERS_KEY)
        if pattern is None:
            return
        self.function_args[LAYERS_KEY] = self._expand_glob_for_key(pattern, layers_scope_from_hn, net_scopes)

    def export(self):
        """Return the per-layer export when layers are set, else the global export."""
        if self.has_layers():
            return self._export_per_layer_config(new_format=True)
        return self._export_global_config()


class ActivationClippingSubCommand(PreQuantizationOptimizationCommand):
    """``activation_clipping`` sub-command; requires a layers field and exports per-layer config."""

    def __init__(self, cmd_str, cmd_line, kwargs):
        super().__init__(PreQuantizationFeature.activation_clipping, cmd_str, cmd_line, kwargs)

    @classmethod
    def from_kwargs(cls, kwargs, cmd_str, cmd_line):
        """Construct directly from the parsed keyword arguments."""
        return cls(cmd_str, cmd_line, kwargs)

    def validate_command(self, layers_scope_from_hn):
        """Raise if the layers field was not supplied."""
        if LAYERS_KEY in self.function_args:
            return
        raise QuantizationScriptParserException(f"field required '{LAYERS_KEY}'")

    def expand_glob(self, layers_scope_from_hn, net_scopes):
        """Expand the layers field via ``_expand_glob_for_key``; skipped when it is missing or empty."""
        pattern = self.function_args.get(LAYERS_KEY)
        if not pattern:
            return
        self.function_args[LAYERS_KEY] = self._expand_glob_for_key(pattern, layers_scope_from_hn, net_scopes)

    def add_scope(self, scope_names, force=False):
        """Apply ``scope_names`` to the layers field."""
        self._add_scope_to_key(scope_names, LAYERS_KEY, force=force)

    def export(self):
        """Export the per-layer configuration."""
        return self._export_per_layer_config(new_format=True)


class AddShortCutLayerSubCommand(PreQuantizationOptimizationCommand):
    """``add_shortcut_layer`` sub-command; requires a layers field and exports per-layer config."""

    def __init__(self, cmd_str, cmd_line, kwargs):
        super().__init__(PreQuantizationFeature.add_shortcut_layer, cmd_str, cmd_line, kwargs)

    @classmethod
    def from_kwargs(cls, kwargs, cmd_str, cmd_line):
        """Construct directly from the parsed keyword arguments."""
        return cls(cmd_str, cmd_line, kwargs)

    def validate_command(self, layers_scope_from_hn):
        """Raise if the layers field was not supplied."""
        if LAYERS_KEY in self.function_args:
            return
        raise QuantizationScriptParserException(f"field required '{LAYERS_KEY}'")

    def expand_glob(self, layers_scope_from_hn, net_scopes):
        """Expand the layers field via ``_expand_glob_for_key``; skipped when it is missing or empty."""
        pattern = self.function_args.get(LAYERS_KEY)
        if not pattern:
            return
        self.function_args[LAYERS_KEY] = self._expand_glob_for_key(pattern, layers_scope_from_hn, net_scopes)

    def add_scope(self, scope_names, force=False):
        """Apply ``scope_names`` to the layers field."""
        self._add_scope_to_key(scope_names, LAYERS_KEY, force=force)

    def export(self):
        """Export the per-layer configuration."""
        return self._export_per_layer_config(new_format=True)


class MatmulCorrectionSubCommand(PreQuantizationOptimizationCommand):
    """``matmul_correction`` sub-command; requires a layers field and exports per-layer config."""

    def __init__(self, cmd_str, cmd_line, kwargs):
        super().__init__(PreQuantizationFeature.matmul_correction, cmd_str, cmd_line, kwargs)

    @classmethod
    def from_kwargs(cls, kwargs, cmd_str, cmd_line):
        """Construct directly from the parsed keyword arguments."""
        return cls(cmd_str, cmd_line, kwargs)

    def validate_command(self, layers_scope_from_hn):
        """Raise if the layers field was not supplied."""
        if LAYERS_KEY in self.function_args:
            return
        raise QuantizationScriptParserException(f"field required '{LAYERS_KEY}'")

    def expand_glob(self, layers_scope_from_hn, net_scopes):
        """Expand the layers field via ``_expand_glob_for_key``; skipped when it is missing or empty."""
        pattern = self.function_args.get(LAYERS_KEY)
        if not pattern:
            return
        self.function_args[LAYERS_KEY] = self._expand_glob_for_key(pattern, layers_scope_from_hn, net_scopes)

    def add_scope(self, scope_names, force=False):
        """Apply ``scope_names`` to the layers field."""
        self._add_scope_to_key(scope_names, LAYERS_KEY, force=force)

    def export(self):
        """Export the per-layer configuration."""
        return self._export_per_layer_config(new_format=True)


class MatmulEqualizationSubCommand(PreQuantizationOptimizationCommand):
    """``matmul_equalization`` sub-command; requires a layers field and exports per-layer config."""

    def __init__(self, cmd_str, cmd_line, kwargs):
        super().__init__(PreQuantizationFeature.matmul_equalization, cmd_str, cmd_line, kwargs)

    @classmethod
    def from_kwargs(cls, kwargs, cmd_str, cmd_line):
        """Construct directly from the parsed keyword arguments."""
        return cls(cmd_str, cmd_line, kwargs)

    def validate_command(self, layers_scope_from_hn):
        """Raise if the layers field was not supplied."""
        if LAYERS_KEY in self.function_args:
            return
        raise QuantizationScriptParserException(f"field required '{LAYERS_KEY}'")

    def expand_glob(self, layers_scope_from_hn, net_scopes):
        """Expand the layers field via ``_expand_glob_for_key``; skipped when it is missing or empty."""
        pattern = self.function_args.get(LAYERS_KEY)
        if not pattern:
            return
        self.function_args[LAYERS_KEY] = self._expand_glob_for_key(pattern, layers_scope_from_hn, net_scopes)

    def add_scope(self, scope_names, force=False):
        """Apply ``scope_names`` to the layers field."""
        self._add_scope_to_key(scope_names, LAYERS_KEY, force=force)

    def export(self):
        """Export the per-layer configuration."""
        return self._export_per_layer_config(new_format=True)


class UsePrequantizedWeightsSubCommand(PreQuantizationOptimizationCommand):
    """``use_prequantized_weights`` sub-command; requires a layers field and exports per-layer config."""

    def __init__(self, cmd_str, cmd_line, kwargs):
        super().__init__(PreQuantizationFeature.use_prequantized_weights, cmd_str, cmd_line, kwargs)

    @classmethod
    def from_kwargs(cls, kwargs, cmd_str, cmd_line):
        """Construct directly from the parsed keyword arguments."""
        return cls(cmd_str, cmd_line, kwargs)

    def validate_command(self, layers_scope_from_hn):
        """Raise if the layers field was not supplied."""
        if LAYERS_KEY in self.function_args:
            return
        raise QuantizationScriptParserException(f"field required '{LAYERS_KEY}'")

    def expand_glob(self, layers_scope_from_hn, net_scopes):
        """Expand the layers field via ``_expand_glob_for_key``; skipped when it is missing or empty."""
        pattern = self.function_args.get(LAYERS_KEY)
        if not pattern:
            return
        self.function_args[LAYERS_KEY] = self._expand_glob_for_key(pattern, layers_scope_from_hn, net_scopes)

    def add_scope(self, scope_names, force=False):
        """Apply ``scope_names`` to the layers field."""
        self._add_scope_to_key(scope_names, LAYERS_KEY, force=force)

    def export(self):
        """Export the per-layer configuration."""
        return self._export_per_layer_config(new_format=True)


class ConvDecompositionSubCommand(PreQuantizationOptimizationCommand):
    """``conv_decomposition`` sub-command; requires a layers field and exports per-layer config."""

    def __init__(self, cmd_str, cmd_line, kwargs):
        super().__init__(PreQuantizationFeature.conv_decomposition, cmd_str, cmd_line, kwargs)

    @classmethod
    def from_kwargs(cls, kwargs, cmd_str, cmd_line):
        """Construct directly from the parsed keyword arguments."""
        return cls(cmd_str, cmd_line, kwargs)

    def validate_command(self, layers_scope_from_hn):
        """Raise if the layers field was not supplied."""
        if LAYERS_KEY in self.function_args:
            return
        raise QuantizationScriptParserException(f"field required '{LAYERS_KEY}'")

    def expand_glob(self, layers_scope_from_hn, net_scopes):
        """Expand the layers field via ``_expand_glob_for_key``; skipped when it is missing or empty."""
        pattern = self.function_args.get(LAYERS_KEY)
        if not pattern:
            return
        self.function_args[LAYERS_KEY] = self._expand_glob_for_key(pattern, layers_scope_from_hn, net_scopes)

    def add_scope(self, scope_names, force=False):
        """Apply ``scope_names`` to the layers field."""
        self._add_scope_to_key(scope_names, LAYERS_KEY, force=force)

    def export(self):
        """Export the per-layer configuration."""
        return self._export_per_layer_config(new_format=True)


class ResolutionReductionSubCommand(PreQuantizationOptimizationCommand):
    """``resolution_reduction`` sub-command; the layers field is optional.

    Exports a per-layer config when ``has_layers()`` is true, otherwise a global config.
    """

    def __init__(self, cmd_str, cmd_line, kwargs):
        super().__init__(PreQuantizationFeature.resolution_reduction, cmd_str, cmd_line, kwargs)

    @classmethod
    def from_kwargs(cls, kwargs, cmd_str, cmd_line):
        """Construct directly from the parsed keyword arguments."""
        return cls(cmd_str, cmd_line, kwargs)

    def validate_command(self, layers_scope_from_hn):
        """Accept any arguments; this sub-command performs no extra validation."""

    def expand_glob(self, layers_scope_from_hn, net_scopes):
        """Expand the layers field via ``_expand_glob_for_key`` when it is present."""
        pattern = self.function_args.get(LAYERS_KEY)
        if pattern is None:
            return
        self.function_args[LAYERS_KEY] = self._expand_glob_for_key(pattern, layers_scope_from_hn, net_scopes)

    def add_scope(self, scope_names, force=False):
        """Apply ``scope_names`` to the layers field."""
        self._add_scope_to_key(scope_names, LAYERS_KEY, force=force)

    def export(self):
        """Return the per-layer export when layers are set, else the global export."""
        if self.has_layers():
            return self._export_per_layer_config(new_format=True)
        return self._export_global_config()


class DefuseSubCommand(PreQuantizationOptimizationCommand, ABC):
    """``defuse`` sub-command; requires a layers field and exports per-layer config."""

    # Overrides the base default (False).
    POST_FUSER_COMMAND = True

    def __init__(self, cmd_str, cmd_line, kwargs):
        super().__init__(PreQuantizationFeature.defuse, cmd_str, cmd_line, kwargs)

    @classmethod
    def from_kwargs(cls, kwargs, cmd_str, cmd_line):
        """Construct directly from the parsed keyword arguments."""
        return cls(cmd_str, cmd_line, kwargs)

    def validate_command(self, layers_scope_from_hn):
        """Raise if the layers field was not supplied."""
        if LAYERS_KEY in self.function_args:
            return
        raise QuantizationScriptParserException(f"field required '{LAYERS_KEY}'")

    def expand_glob(self, layers_scope_from_hn, net_scopes):
        """Expand the layers field via ``_expand_glob_for_key``; skipped when it is missing or empty."""
        pattern = self.function_args.get(LAYERS_KEY)
        if not pattern:
            return
        self.function_args[LAYERS_KEY] = self._expand_glob_for_key(pattern, layers_scope_from_hn, net_scopes)

    def add_scope(self, scope_names, force=False):
        """Apply ``scope_names`` to the layers field."""
        self._add_scope_to_key(scope_names, LAYERS_KEY, force=force)

    def export(self):
        """Export the per-layer configuration."""
        return self._export_per_layer_config(new_format=True)


class WeightsClippingSubCommand(PreQuantizationOptimizationCommand):
    """``weights_clipping`` sub-command; requires a layers field and exports per-layer config."""

    def __init__(self, cmd_str, cmd_line, kwargs):
        super().__init__(PreQuantizationFeature.weights_clipping, cmd_str, cmd_line, kwargs)

    @classmethod
    def from_kwargs(cls, kwargs, cmd_str, cmd_line):
        """Construct directly from the parsed keyword arguments."""
        return cls(cmd_str, cmd_line, kwargs)

    def validate_command(self, layers_scope_from_hn):
        """Raise if the layers field was not supplied."""
        if LAYERS_KEY in self.function_args:
            return
        raise QuantizationScriptParserException(f"field required '{LAYERS_KEY}'")

    def expand_glob(self, layers_scope_from_hn, net_scopes):
        """Expand the layers field via ``_expand_glob_for_key``; skipped when it is missing or empty."""
        pattern = self.function_args.get(LAYERS_KEY)
        if not pattern:
            return
        self.function_args[LAYERS_KEY] = self._expand_glob_for_key(pattern, layers_scope_from_hn, net_scopes)

    def add_scope(self, scope_names, force=False):
        """Apply ``scope_names`` to the layers field."""
        self._add_scope_to_key(scope_names, LAYERS_KEY, force=force)

    def export(self):
        """Export the per-layer configuration."""
        return self._export_per_layer_config(new_format=True)


class SEOptimizationSubCommand(PreQuantizationOptimizationCommand):
    """``se_optimization`` sub-command; optional ``"layers"`` field, exported as global config."""

    # Overrides the base default (False).
    POST_FUSER_COMMAND = True

    def __init__(self, cmd_str, cmd_line, kwargs):
        super().__init__(PreQuantizationFeature.se_optimization, cmd_str, cmd_line, kwargs)

    @classmethod
    def from_kwargs(cls, kwargs, cmd_str, cmd_line):
        """Construct directly from the parsed keyword arguments."""
        return cls(cmd_str, cmd_line, kwargs)

    # NOTE(review): unlike sibling sub-commands, this uses the literal "layers" rather than
    # LAYERS_KEY, and stores only the keys of the expansion result as a list — confirm intended.
    def expand_glob(self, layers_scope_from_hn, net_scopes):
        """Replace the ``"layers"`` field with the keys of its glob expansion, when present."""
        pattern = self.function_args.get("layers")
        if pattern is None:
            return
        expanded = self._expand_glob_for_key(pattern, layers_scope_from_hn, net_scopes)
        self.function_args["layers"] = [*expanded.keys()]

    def add_scope(self, scope_names, force=False):
        """Apply ``scope_names`` to the ``"layers"`` field."""
        self._add_scope_to_key(scope_names, "layers", force=force)

    def export(self):
        """Export the global configuration."""
        return self._export_global_config()


class ZeroStaticChannelsSubCommand(PreQuantizationOptimizationCommand):
    """``zero_static_channels`` sub-command; the layers field is optional.

    Exports a per-layer config when ``has_layers()`` is true, otherwise a global config.
    """

    def __init__(self, cmd_str, cmd_line, kwargs):
        super().__init__(PreQuantizationFeature.zero_static_channels, cmd_str, cmd_line, kwargs)

    @classmethod
    def from_kwargs(cls, kwargs, cmd_str, cmd_line):
        """Construct directly from the parsed keyword arguments."""
        return cls(cmd_str, cmd_line, kwargs)

    def add_scope(self, scope_names, force=False):
        """Apply ``scope_names`` to the layers field."""
        self._add_scope_to_key(scope_names, LAYERS_KEY, force=force)

    def expand_glob(self, layers_scope_from_hn, net_scopes):
        """Expand the layers field via ``_expand_glob_for_key`` when it is present."""
        pattern = self.function_args.get(LAYERS_KEY)
        if pattern is None:
            return
        self.function_args[LAYERS_KEY] = self._expand_glob_for_key(pattern, layers_scope_from_hn, net_scopes)

    def export(self):
        """Return the per-layer export when layers are set, else the global export."""
        if self.has_layers():
            return self._export_per_layer_config(new_format=True)
        return self._export_global_config()


class DeadChannelsRemovalSubCommand(PreQuantizationOptimizationCommand):
    """``dead_channels_removal`` sub-command; global-only, so scoping is ignored."""

    # Overrides the base default (False).
    POST_FUSER_COMMAND = True

    def __init__(self, cmd_str, cmd_line, kwargs):
        super().__init__(PreQuantizationFeature.dead_channels_removal, cmd_str, cmd_line, kwargs)

    @classmethod
    def from_kwargs(cls, kwargs, cmd_str, cmd_line):
        """Construct directly from the parsed keyword arguments."""
        return cls(cmd_str, cmd_line, kwargs)

    def add_scope(self, scope_names, force=False):
        """No-op: this sub-command has no layer field to scope."""

    def export(self):
        """Export the global configuration."""
        return self._export_global_config()


class DeadLayersRemovalSubCommand(PreQuantizationOptimizationCommand):
    """``dead_layers_removal`` sub-command; global-only, so scoping is ignored."""

    # Explicitly restates the base default.
    POST_FUSER_COMMAND = False

    def __init__(self, cmd_str, cmd_line, kwargs):
        super().__init__(PreQuantizationFeature.dead_layers_removal, cmd_str, cmd_line, kwargs)

    @classmethod
    def from_kwargs(cls, kwargs, cmd_str, cmd_line):
        """Construct directly from the parsed keyword arguments."""
        return cls(cmd_str, cmd_line, kwargs)

    def add_scope(self, scope_names, force=False):
        """No-op: this sub-command has no layer field to scope."""

    def export(self):
        """Export the global configuration."""
        return self._export_global_config()


class LayerNormDecompositionSubCommand(PreQuantizationOptimizationCommand):
    """``layer_norm_decomposition`` sub-command; global-only, so scoping is ignored."""

    # Explicitly restates the base default.
    POST_FUSER_COMMAND = False

    def __init__(self, cmd_str, cmd_line, kwargs):
        super().__init__(PreQuantizationFeature.layer_norm_decomposition, cmd_str, cmd_line, kwargs)

    @classmethod
    def from_kwargs(cls, kwargs, cmd_str, cmd_line):
        """Construct directly from the parsed keyword arguments."""
        return cls(cmd_str, cmd_line, kwargs)

    def export(self):
        """Export the global configuration."""
        return self._export_global_config()

    def add_scope(self, scope_names, force=False):
        """No-op: this sub-command has no layer field to scope."""


class EWAddFusingSubCommand(PreQuantizationOptimizationCommand):
    """``ew_add_fusing`` sub-command; global-only, so scoping is ignored."""

    # Explicitly restates the base default.
    POST_FUSER_COMMAND = False

    def __init__(self, cmd_str, cmd_line, kwargs):
        super().__init__(PreQuantizationFeature.ew_add_fusing, cmd_str, cmd_line, kwargs)

    @classmethod
    def from_kwargs(cls, kwargs, cmd_str, cmd_line):
        """Construct directly from the parsed keyword arguments."""
        return cls(cmd_str, cmd_line, kwargs)

    def export(self):
        """Export the global configuration."""
        return self._export_global_config()

    def add_scope(self, scope_names, force=False):
        """No-op: this sub-command has no layer field to scope."""


class SmartSoftmaxStatsSubCommand(PreQuantizationOptimizationCommand):
    """``smart_softmax_stats`` sub-command; global-only, so scoping is ignored."""

    # Explicitly restates the base default.
    POST_FUSER_COMMAND = False

    def __init__(self, cmd_str, cmd_line, kwargs):
        super().__init__(PreQuantizationFeature.smart_softmax_stats, cmd_str, cmd_line, kwargs)

    @classmethod
    def from_kwargs(cls, kwargs, cmd_str, cmd_line):
        """Construct directly from the parsed keyword arguments."""
        return cls(cmd_str, cmd_line, kwargs)

    def add_scope(self, scope_names, force=False):
        """No-op: this sub-command has no layer field to scope."""

    def export(self):
        """Export the global configuration."""
        return self._export_global_config()


class LayerDecompositionCommand(PreQuantizationOptimizationCommand):
    """``layer_decomposition`` sub-command; exports per-layer config."""

    def __init__(self, cmd_str, cmd_line, kwargs):
        super().__init__(PreQuantizationFeature.layer_decomposition, cmd_str, cmd_line, kwargs)

    @classmethod
    def from_kwargs(cls, kwargs, cmd_str, cmd_line):
        """Construct directly from the parsed keyword arguments."""
        return cls(cmd_str, cmd_line, kwargs)

    def add_scope(self, scope_names, force=False):
        """Apply ``scope_names`` to the layers field."""
        self._add_scope_to_key(scope_names, LAYERS_KEY, force=force)

    def expand_glob(self, layers_scope_from_hn, net_scopes):
        """Expand the layers field via ``_expand_glob_for_key`` when it is present."""
        pattern = self.function_args.get(LAYERS_KEY)
        if pattern is None:
            return
        self.function_args[LAYERS_KEY] = self._expand_glob_for_key(pattern, layers_scope_from_hn, net_scopes)

    def export(self):
        """Export the per-layer configuration."""
        return self._export_per_layer_config(new_format=True)


class MatmulDecompositionCommand(PreQuantizationOptimizationCommand):
    """``matmul_decomposition`` sub-command; exports per-layer config."""

    def __init__(self, cmd_str, cmd_line, kwargs):
        super().__init__(PreQuantizationFeature.matmul_decomposition, cmd_str, cmd_line, kwargs)

    @classmethod
    def from_kwargs(cls, kwargs, cmd_str, cmd_line):
        """Construct directly from the parsed keyword arguments."""
        return cls(cmd_str, cmd_line, kwargs)

    def add_scope(self, scope_names, force=False):
        """Apply ``scope_names`` to the layers field."""
        self._add_scope_to_key(scope_names, LAYERS_KEY, force=force)

    def expand_glob(self, layers_scope_from_hn, net_scopes):
        """Expand the layers field via ``_expand_glob_for_key`` when it is present."""
        pattern = self.function_args.get(LAYERS_KEY)
        if pattern is None:
            return
        self.function_args[LAYERS_KEY] = self._expand_glob_for_key(pattern, layers_scope_from_hn, net_scopes)

    def export(self):
        """Export the per-layer configuration."""
        return self._export_per_layer_config(new_format=True)


class GlobalAvgpoolReductionSubCommand(PreQuantizationOptimizationCommand):
    """``global_avgpool_reduction`` sub-command; requires a layers field and exports per-layer config."""

    # Explicitly restates the base default.
    POST_FUSER_COMMAND = False

    def __init__(self, cmd_str, cmd_line, kwargs):
        super().__init__(PreQuantizationFeature.global_avgpool_reduction, cmd_str, cmd_line, kwargs)

    @classmethod
    def from_kwargs(cls, kwargs, cmd_str, cmd_line):
        """Construct directly from the parsed keyword arguments."""
        return cls(cmd_str, cmd_line, kwargs)

    def validate_command(self, layers_scope_from_hn):
        """Raise if the layers field was not supplied."""
        if LAYERS_KEY in self.function_args:
            return
        raise QuantizationScriptParserException(f"field required '{LAYERS_KEY}'")

    def expand_glob(self, layers_scope_from_hn, net_scopes):
        """Expand the layers field via ``_expand_glob_for_key``; skipped when it is missing or empty."""
        pattern = self.function_args.get(LAYERS_KEY)
        if not pattern:
            return
        self.function_args[LAYERS_KEY] = self._expand_glob_for_key(pattern, layers_scope_from_hn, net_scopes)

    def add_scope(self, scope_names, force=False):
        """Apply ``scope_names`` to the layers field."""
        self._add_scope_to_key(scope_names, LAYERS_KEY, force=force)

    def export(self):
        """Export the per-layer configuration."""
        return self._export_per_layer_config(new_format=True)


class SplitEWMultByBitSignificanceSubCommand(PreQuantizationOptimizationCommand):
    """``split_ew_mult_by_bit_significance`` sub-command; requires a layers field, exports per-layer config."""

    # Explicitly restates the base default.
    POST_FUSER_COMMAND = False

    def __init__(self, cmd_str, cmd_line, kwargs):
        super().__init__(PreQuantizationFeature.split_ew_mult_by_bit_significance, cmd_str, cmd_line, kwargs)

    @classmethod
    def from_kwargs(cls, kwargs, cmd_str, cmd_line):
        """Construct directly from the parsed keyword arguments."""
        return cls(cmd_str, cmd_line, kwargs)

    def validate_command(self, layers_scope_from_hn):
        """Raise if the layers field was not supplied."""
        if LAYERS_KEY in self.function_args:
            return
        raise QuantizationScriptParserException(f"field required '{LAYERS_KEY}'")

    def expand_glob(self, layers_scope_from_hn, net_scopes):
        """Expand the layers field via ``_expand_glob_for_key``; skipped when it is missing or empty."""
        pattern = self.function_args.get(LAYERS_KEY)
        if not pattern:
            return
        self.function_args[LAYERS_KEY] = self._expand_glob_for_key(pattern, layers_scope_from_hn, net_scopes)

    def add_scope(self, scope_names, force=False):
        """Apply ``scope_names`` to the layers field."""
        self._add_scope_to_key(scope_names, LAYERS_KEY, force=force)

    def export(self):
        """Export the per-layer configuration."""
        return self._export_per_layer_config(new_format=True)


class SplitFusedActivationCommand(PreQuantizationOptimizationCommand):
    """``split_fused_activation`` sub-command; exports per-layer config."""

    # Overrides the base default (False).
    POST_FUSER_COMMAND = True

    def __init__(self, cmd_str, cmd_line, kwargs):
        super().__init__(PreQuantizationFeature.split_fused_activation, cmd_str, cmd_line, kwargs)

    @classmethod
    def from_kwargs(cls, kwargs, cmd_str, cmd_line):
        """Construct directly from the parsed keyword arguments."""
        return cls(cmd_str, cmd_line, kwargs)

    def add_scope(self, scope_names, force=False):
        """Apply ``scope_names`` to the layers field."""
        self._add_scope_to_key(scope_names, LAYERS_KEY, force=force)

    def expand_glob(self, layers_scope_from_hn, net_scopes):
        """Expand the layers field via ``_expand_glob_for_key`` when it is present."""
        pattern = self.function_args.get(LAYERS_KEY)
        if pattern is None:
            return
        self.function_args[LAYERS_KEY] = self._expand_glob_for_key(pattern, layers_scope_from_hn, net_scopes)

    def export(self):
        """Export the per-layer configuration."""
        return self._export_per_layer_config(new_format=True)


class QuaRotSubCommand(PreQuantizationOptimizationCommand):
    """``quarot`` sub-command; global-only, so scoping is ignored."""

    # Explicitly restates the base default.
    POST_FUSER_COMMAND = False

    def __init__(self, cmd_str, cmd_line, kwargs):
        super().__init__(PreQuantizationFeature.quarot, cmd_str, cmd_line, kwargs)

    @classmethod
    def from_kwargs(cls, kwargs, cmd_str, cmd_line):
        """Construct directly from the parsed keyword arguments."""
        return cls(cmd_str, cmd_line, kwargs)

    def add_scope(self, scope_names, force=False):
        """No-op: this sub-command has no layer field to scope."""

    def export(self):
        """Export the global configuration."""
        return self._export_global_config()


class SwitchConcatWithAddCommand(PreQuantizationOptimizationCommand):
    """``switch_concat_with_add`` sub-command; requires a layers field and exports per-layer config."""

    def __init__(self, cmd_str, cmd_line, kwargs):
        super().__init__(PreQuantizationFeature.switch_concat_with_add, cmd_str, cmd_line, kwargs)

    @classmethod
    def from_kwargs(cls, kwargs, cmd_str, cmd_line):
        """Construct directly from the parsed keyword arguments."""
        return cls(cmd_str, cmd_line, kwargs)

    def validate_command(self, layers_scope_from_hn):
        """Raise if the layers field was not supplied."""
        if LAYERS_KEY in self.function_args:
            return
        raise QuantizationScriptParserException(f"field required '{LAYERS_KEY}'")

    def expand_glob(self, layers_scope_from_hn, net_scopes):
        """Expand the layers field via ``_expand_glob_for_key``; skipped when it is missing or empty."""
        pattern = self.function_args.get(LAYERS_KEY)
        if not pattern:
            return
        self.function_args[LAYERS_KEY] = self._expand_glob_for_key(pattern, layers_scope_from_hn, net_scopes)

    def add_scope(self, scope_names, force=False):
        """Apply ``scope_names`` to the layers field."""
        self._add_scope_to_key(scope_names, LAYERS_KEY, force=force)

    def export(self):
        """Export the per-layer configuration."""
        return self._export_per_layer_config(new_format=True)


class ConvA16W4SubCommand(PreQuantizationOptimizationCommand):
    """``conv_a16_w4`` sub-command; requires a layers field and exports per-layer config."""

    # Overrides the base default (False).
    POST_FUSER_COMMAND = True

    def __init__(self, cmd_str, cmd_line, kwargs):
        super().__init__(PreQuantizationFeature.conv_a16_w4, cmd_str, cmd_line, kwargs)

    @classmethod
    def from_kwargs(cls, kwargs, cmd_str, cmd_line):
        """Construct directly from the parsed keyword arguments."""
        return cls(cmd_str, cmd_line, kwargs)

    def validate_command(self, layers_scope_from_hn):
        """Raise if the layers field was not supplied."""
        if LAYERS_KEY in self.function_args:
            return
        raise QuantizationScriptParserException(f"field required '{LAYERS_KEY}'")

    def expand_glob(self, layers_scope_from_hn, net_scopes):
        """Expand the layers field via ``_expand_glob_for_key``; skipped when it is missing or empty."""
        pattern = self.function_args.get(LAYERS_KEY)
        if not pattern:
            return
        self.function_args[LAYERS_KEY] = self._expand_glob_for_key(pattern, layers_scope_from_hn, net_scopes)

    def add_scope(self, scope_names, force=False):
        """Apply ``scope_names`` to the layers field."""
        self._add_scope_to_key(scope_names, LAYERS_KEY, force=force)

    def export(self):
        """Export the per-layer configuration."""
        return self._export_per_layer_config(new_format=True)
