from abc import ABC
from typing import Dict, List, Tuple, Union

from pydantic.v1 import Field, root_validator, validator

from hailo_model_optimization.acceleras.model_optimization_config.mo_config_base import (
    BaseConfigBaseModel,
    _value_to_str,
)
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerActivationClippingConfig,
    LayerAdaRoundConfig,
    LayerAddShortcutConfig,
    LayerBiasCorrectionConfig,
    LayerConfigBaseModel,
    LayerConvA16W4Config,
    LayerConvDecompositionConfig,
    LayerDecompositionConfig,
    LayerDefuseConfig,
    LayerEqualizationConfig,
    LayerGlobalAvgpoolReductionConfig,
    LayerLoadQuantConfig,
    LayerMatmulCorrectionConfig,
    LayerMatmulDecompositionConfig,
    LayerMatmulEqualizationConfig,
    LayerNegExponentConfig,
    LayerPrecisionConfig,
    LayerResolutionReductionConfig,
    LayerSplitEWMultByBitSignificanceConfig,
    LayerSplitFusedActivationConfig,
    LayerSwtichConcatWithAddConfig,
    LayerTranslationConfig,
    LayerWeightsClippingConfig,
    LayerZeroStaticChannelsConfig,
)
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    DEFAULT_BATCH_SIZE,
    DEFAULT_DATASET_SIZE,
    DEFAULT_EPOCHS,
    DEFAULT_LEARNING_RATE,
    DEFAULT_ZERO_STATIC_CHANNELS_EPSILON,
    MAX_NUM_REPEATS_ELTWISE,
    AdaRoundMode,
    BiasCorrectionPolicy,
    ComprecisionMetric,
    CompressionTypes,
    DeadChannelsRemovalPolicy,
    DeadLayersRemovalPolicy,
    DistributionStrategy,
    EqualizationMode,
    EqualizationPolicy,
    FeaturePolicy,
    FinetunePolicy,
    InfusibleEWAddType,
    LayerNormDecompositionMode,
    LayerNormMode,
    LossType,
    MetaArchType,
    MOConfigCommand,
    ModelOptimizationCommand,
    MultiOutputMetric,
    OptimizationTarget,
    Optimizer,
    PostQuantizationFeature,
    PreQuantizationFeature,
    QFTWriterMode,
    ResolutionReductionInterpolationMode,
    ScheduleType,
    SensitivitySearch,
    SEOptimizationMethod,
    SoftmaxBiasOptimizationAlgorithm,
    ThreeWayPolicy,
    TiledSqueezeAndExciteMode,
    WarmupStrategy,
)
from hailo_model_optimization.acceleras.utils.logger import default_logger
from hailo_model_optimization.algorithms.lat_utils.lat_utils import AnalysisMode


class FeatureConfigBaseModel(BaseConfigBaseModel, ABC):
    """
    Base class for feature-level optimization configurations.

    A feature config mixes flat feature-wide fields with optional nested per-layer fields (fields whose
    type is a ``LayerConfigBaseModel`` subclass, e.g. ``layers: Dict[str, LayerXConfig]``). Subclasses
    supply ``get_command()`` / ``get_feature()``, which ``to_cmd`` uses to render the config as a command.
    """

    def to_cmd(self, exclude_defaults, layers_to_exclude=None):
        """
        Converts the feature's config to command.

        Layers whose per-layer configs serialize to identical dicts are grouped into a single per-layer
        command. The remaining (non-nested) fields produce one feature-level command, emitted only when
        at least one of them is present.

        Args
            exclude_defaults: Forwarded to ``raw_dict`` for the feature-level fields.
            layers_to_exclude: Optional collection of layer names to leave out of the per-layer commands.

        Returns
            The command as string. May be empty string if the config is empty.

        """
        as_dict = self.raw_dict(exclude_defaults)
        all_commands = []
        for field in self.nested_keys():
            # Nested fields are rendered as per-layer commands below, not as feature kwargs. ``pop`` with a
            # default because the key may be absent from ``as_dict`` (e.g. when defaults are excluded).
            as_dict.pop(field, None)
            # Dicts are unhashable, so groups are kept as parallel lists (linear lookup) instead of a dict.
            group_dicts = []
            group_layers = []
            group_cfgs = []
            for lname, layer_cfg in getattr(self, field).items():
                if layers_to_exclude and lname in layers_to_exclude:
                    continue
                data = layer_cfg.raw_dict()
                if data in group_dicts:
                    group_layers[group_dicts.index(data)].append(lname)
                else:
                    group_dicts.append(data)
                    group_layers.append([lname])
                    group_cfgs.append(layer_cfg)
            # Render each group with a config object that belongs to that group.
            for group_cfg, layers, raw_dict in zip(group_cfgs, group_layers, group_dicts):
                all_commands.append(group_cfg.to_cmd(raw_dict, layers))
        command = self.get_command()
        feature = self.get_feature()
        kwargs = ", ".join(f"{k}={_value_to_str(v)}" for k, v in sorted(as_dict.items()))
        if kwargs:
            all_commands.append(f"{command}({feature}, {kwargs})")
        return "\n".join(all_commands)

    @classmethod
    def nested_keys(cls) -> set:
        """
        Returns the names of fields whose (inner) type is a ``LayerConfigBaseModel`` subclass.
        """
        nested_fields = set()
        for field in cls.__fields__:
            try:
                if issubclass(cls.__fields__[field].type_, LayerConfigBaseModel):
                    nested_fields.add(field)
            except TypeError:
                # ``type_`` is not a class (e.g. a typing construct); it cannot be a layer config.
                pass
        return nested_fields

    @classmethod
    def keys(cls):
        """
        Returns the base class's keys without the nested per-layer fields.
        """
        keys = super().keys()
        return keys - cls.nested_keys()

    @classmethod
    def flat_layers_fields(cls):
        """
        Parser helper method, indicates which fields should be handled as layers list.
        """
        return dict()

    def remove_layer_from_config(self, lname):
        """
        Removes ``lname`` from the ``layers`` mapping (if this config has a non-empty one).

        Both the name as given and the name with its first ``/``-separated scope stripped are removed.
        """
        if hasattr(self, "layers") and self.layers:
            if lname in self.layers:
                self.layers.pop(lname)
            # in case the name given is with scope but isn't in the config
            lname = lname.split("/", 1)[-1]
            if lname in self.layers:
                self.layers.pop(lname)


class GlobalEqualizationConfig(FeatureConfigBaseModel):
    """
    This sub-command configures the global equalization behavior of the pre-quantization
    process. It replaces the old equalize parameter of the
    :func:`quantize() <hailo_sdk_client.runner.client_runner.ClientRunner.quantize>` API.

    Example command:

    .. code-block::

        pre_quantization_optimization(equalization, policy=disabled)

    .. note::

        An in-depth explanation of the equalization algorithm - https://arxiv.org/pdf/1902.01917.pdf

    """

    policy: EqualizationPolicy = Field(
        EqualizationPolicy.enabled,
        description="Enable or disable the equalization algorithm",
    )
    mode: EqualizationMode = Field(
        EqualizationMode.min_based,
        description="Select a type of equalization algorithm",
    )
    two_stage: bool = Field(
        True,
        description="Take into account the equiv-set outputs when calculating factors",
    )
    max_activation_factor: int = Field(16, description="Set a max value for the activation factor")
    skip_sbn_layers: bool = Field(True)
    skip_8b_to_8b: bool = Field(False)
    skip_multi_source: bool = Field(True)
    equalize_inputs: ThreeWayPolicy = Field(
        ThreeWayPolicy.allowed,
        description="Enable or disable the equalization of the inputs",
    )
    equalize_outputs: ThreeWayPolicy = Field(
        ThreeWayPolicy.allowed,
        description="Enable or disable the equalization of the outputs",
    )
    layers: Dict[str, LayerEqualizationConfig] = Field({}, description="Per layer custom configuration")

    @classmethod
    def get_default(cls):
        return cls()

    @classmethod
    def get_command(cls):
        return ModelOptimizationCommand.pre_quantization_optimization.value

    @classmethod
    def get_feature(cls):
        return PreQuantizationFeature.equalization.value

    @classmethod
    def _internal_keys(cls) -> set:
        # Everything except ``policy`` and the per-layer ``layers`` mapping.
        return {
            "mode",
            "two_stage",
            "max_activation_factor",
            "skip_sbn_layers",
            "skip_8b_to_8b",
            "skip_multi_source",
            "equalize_inputs",
            "equalize_outputs",
        }

    @validator("two_stage", always=True)
    def validate_two_stage(cls, v, values, **kwargs):
        # Two-stage factors are only meaningful for the min-based mode; any other (validated) mode forces
        # it off. If ``mode`` failed validation it is missing from ``values`` and ``v`` is kept as-is.
        if "mode" in values and values["mode"] != EqualizationMode.min_based:
            return False
        return v

    @validator("policy")
    def validate_explicit_policy(cls, v, **kwargs):
        # Only explicit on/off policies may be set by the user.
        if v in {EqualizationPolicy.enabled, EqualizationPolicy.disabled}:
            return v
        raise TypeError("value is not a valid enumeration member; permitted: 'enabled', 'disabled'")

    def info_config(self):
        """Logs the effective equalization configuration at verbose level."""
        for message in (
            "Equalization configuration:",
            f"    Default policy: {self.policy.value}",
            f"    Mode: {self.mode.value}",
        ):
            default_logger().verbose(message)


class GlobalLayerNormDecompositionConfig(FeatureConfigBaseModel):
    """
    This sub-command configures how layer-normalization layers are decomposed during pre-quantization.

    Example command:

    .. code-block::

        pre_quantization_optimization(layer_norm_decomposition, mode=nn_core)

    """

    mode: LayerNormMode = Field(LayerNormMode.auto, description="Select a type of decompostion mode")
    equalization: ThreeWayPolicy = Field(ThreeWayPolicy.allowed, description="Do equalization on the layer norm")
    nudging: bool = Field(True, description="do nudging to weights")
    group_nudging: bool = Field(True, description="do group_nudging to weights")
    square_12_bit: bool = Field(True, description="make square of the layer norm 12 bit")
    optimize_ew_mult: bool = Field(True, description="optimize_ew_mult")
    force_hw: OptimizationTarget = Field(None, description="should force hw for reserch")
    token_equalization: FeaturePolicy = Field(
        FeaturePolicy.disabled,
        description="Do online equalization over the tokens",
    )
    add_buffer_layer: FeaturePolicy = Field(
        FeaturePolicy.disabled,
        description="add buffer layer in case of token equalization",
    )
    bit_decomposition_mode: LayerNormDecompositionMode = Field(
        LayerNormDecompositionMode.auto,
        description="For 16bit precision_mode, determine whether to work in standard or decomposed 16bit layers.",
    )
    group_size_split: int = Field(
        512,
        description="Estimate for the summed elements in the post square nudging layer in.",
    )
    force_group_size_split: bool = Field(False, description="force group size split")
    sum_range_factor: int = Field(
        2,
        description="Factor for the max range allowance for the integrated sum in the square.",
    )
    square_reduce_sum_groups: int = Field(
        None,
        description="Explicit square sum group size value. if not provided sum_range_factor heuristic will be used",
    )
    eq_consumer: bool = Field(
        False,
        description="Add normalization layer before the layer norm as equalization consumer",
    )

    @classmethod
    def get_default(cls):
        return cls()

    @classmethod
    def get_command(cls):
        return ModelOptimizationCommand.pre_quantization_optimization.value

    @classmethod
    def get_feature(cls):
        return PreQuantizationFeature.layer_norm_decomposition.value

    @classmethod
    def _internal_keys(cls) -> set:
        # NOTE(review): ``eq_consumer`` is not listed here although it looks like a tuning knob as well —
        # confirm whether it is meant to be user-facing.
        internal = (
            "mode",
            "equalization",
            "square_12_bit",
            "nudging",
            "group_nudging",
            "optimize_ew_mult",
            "force_hw",
            "token_equalization",
            "add_buffer_layer",
            "bit_decomposition_mode",
            "group_size_split",
            "sum_range_factor",
            "force_group_size_split",
            "square_reduce_sum_groups",
        )
        return set(internal)

    @validator("mode")
    def validate_explicit_mode(cls, v, **kwargs):
        # Restrict the user-selectable modes to the three supported ones.
        allowed_modes = {LayerNormMode.auto, LayerNormMode.nn_core, LayerNormMode.ppu}
        if v in allowed_modes:
            return v
        raise TypeError("value is not a valid enumeration member; permitted: 'nn_core', 'ppu', 'auto'")

    @validator("force_hw")
    def validate_force_hw(cls, v, **kwargs):
        # ``None`` means "do not force a target"; otherwise only these targets are accepted.
        allowed_targets = {OptimizationTarget.SAGE, OptimizationTarget.MERCURY, OptimizationTarget.PLUTO, None}
        if v in allowed_targets:
            return v
        raise TypeError("value is not a valid enumeration member; permitted: 'sage', 'mercury', 'pluto'")


class GlobalDeadLayersRemovalConfig(FeatureConfigBaseModel):
    """
    This sub-command allows configuring the dead layers removal

    Example command:

    .. code-block::

        pre_quantization_optimization(dead_layers_removal, policy=disabled)

    """

    policy: DeadLayersRemovalPolicy = Field(
        DeadLayersRemovalPolicy.enabled,
        description="Enable or disable the dead layers removal algorithm",
    )
    threshold: float = Field(
        1e-5,
        description="Threshold to remove layer. The layer will be removed if "
        "all the weights (in their absolute value) in the layer are under this value",
    )
    validate_change: DeadLayersRemovalPolicy = Field(
        DeadLayersRemovalPolicy.enabled,
        description="iF enabled, the algorithm will validate that the removal of the layer "
        "by comparing the output of the network before and after the removal",
    )

    @classmethod
    def get_command(cls):
        return ModelOptimizationCommand.pre_quantization_optimization.value

    @classmethod
    def get_feature(cls):
        return PreQuantizationFeature.dead_layers_removal.value

    @classmethod
    def _internal_keys(cls) -> set:
        dead_layers_removal_internal_keys = {
            "threshold",
        }
        return dead_layers_removal_internal_keys

    def info_config(self):
        """Logs the dead layers removal configuration at verbose level."""
        default_logger().verbose("Dead Layers Removal Configuration:")
        default_logger().verbose(f"    Policy: {self.policy.value}")
        # ``threshold`` is a plain float (not an enum), so it is formatted directly.
        default_logger().verbose(f"    Threshold: {self.threshold}")

    @classmethod
    def get_default(cls):
        return cls()


class GlobalBiasCorrectionConfig(FeatureConfigBaseModel):
    """
    This sub-command allows configuring the global bias correction behavior during the post-quantization process,
    this command replaces the old ibc parameter from
    :func:`quantize() <hailo_sdk_client.runner.client_runner.ClientRunner.quantize>` API

    Example command:

    .. code-block::

        # This will enable the IBC during the post-quantization
        post_quantization_optimization(bias_correction, policy=enabled)

    .. note::

        An in-depth explanation of the IBC algorithm - https://arxiv.org/pdf/1906.03193.pdf

    .. note::

        Bias correction is recommended when the model contains small kernels or depth-wise layers
    """

    # TODO: add batch size, calibset size, and use_cache as internal
    # TODO: add info function
    policy: BiasCorrectionPolicy = Field(
        BiasCorrectionPolicy.disabled,
        description="Enable or disable the bias correction algorithm. When "
        "Optimization Level >= 1, could be enabled by the default policy.",
    )
    fast_ibc: bool = Field(
        True,
        description="Toggle between partial numeric and exact numeric emulations. (partial numeric is faster)",
    )
    batch_size: int = Field(None, gt=0, description="Batch size used during model inference for bias correction")
    calibset_size: int = Field(None, gt=0, description="Data items used for bias correction algorithm")
    cache_compression: FeaturePolicy = Field(
        FeaturePolicy.disabled,
        description="Enable or disable the compression of layer results when cached to disk.",
    )
    layers: Dict[str, LayerBiasCorrectionConfig] = Field(dict(), description="Per layer custom configuration")

    @classmethod
    def get_command(cls):
        return ModelOptimizationCommand.post_quantization_optimization.value

    @classmethod
    def get_feature(cls):
        return PostQuantizationFeature.bias_correction.value

    @classmethod
    def _internal_keys(cls) -> set:
        return {"fast_ibc", "calibset_size", "batch_size"}

    @validator("policy")
    def validate_explicit_policy(cls, v, **kwargs):
        # Only explicit on/off policies may be set by the user.
        valid_policies = {BiasCorrectionPolicy.enabled, BiasCorrectionPolicy.disabled}
        if v not in valid_policies:
            raise TypeError("value is not a valid enumeration member; permitted: 'enabled', 'disabled'")
        return v

    @classmethod
    def get_default(cls):
        return cls()


class CompressionConfig(FeatureConfigBaseModel):
    """
    This command controls layers 4-bit and 16-bit quantization. In 4-bit mode, it reduces some layers' precision mode to a8_w4.
    The values (between 0 and 1 inclusive) represent how much of the total weight memory usage you want to optimize
    to 4bit. When the value is 1, all the weights will be set to 4bit, when 0, the weights won't be modified.
    The 16-bit mode is supported only when setting on the entire network (setting 16-bit value of 1) and without using 4-bit (setting 4-bit value to 0).

    Example command:

    .. code-block::

        # Optimize 30% of the total weights to use 4bit mode
        model_optimization_config(compression_params, auto_4bit_weights_ratio=0.3)

    .. note::

        If you manually set some layers' precision_mode using quantization_param,
        the optimization will take it into account, and won't set any weight back to 8bit

    .. note::

        If you set 16-bit quantization, all layers activations and weights are quantized using 16 bits. In this case, explicit configuration
        of layer bias mode is not allowed.
    """

    auto_4bit_weights_ratio: float = Field(
        0,
        ge=0,
        le=1,
        description="Set a ratio of the model's weights to reduce to 4bit",
    )

    auto_16bit_weights_ratio: float = Field(0, description="Set a ratio of the model's weights to reduce to 16bit")

    @classmethod
    def get_command(cls):
        return ModelOptimizationCommand.model_optimization_config.value

    @classmethod
    def get_feature(cls):
        return MOConfigCommand.compression_params.value

    @classmethod
    def _internal_keys(cls) -> set:
        return set()

    @classmethod
    def get_default(cls):
        return cls()

    @validator("auto_16bit_weights_ratio")
    def validate_16bit_config(cls, v, values):
        """
        Only 0 (off) and 1 (whole network) are accepted, and 1 requires ``auto_4bit_weights_ratio == 0``.
        """
        if v not in [0, 1]:
            raise ValueError("Ratios supported for auto_16bit_weights_ratio are currently 0 and 1")
        if v == 0:
            return v
        # v is 1. ``values`` lacks auto_4bit_weights_ratio when that field failed its own validation;
        # its error is reported separately, so don't compare None against 0 here.
        auto_4bit_weights_ratio = values.get("auto_4bit_weights_ratio")
        if auto_4bit_weights_ratio is not None and auto_4bit_weights_ratio > 0:
            raise ValueError("Parameter auto_16bit_weights_ratio=1 is only supported with auto_4bit_weights_ratio=0")
        return v


class GlobalsMOConfig(FeatureConfigBaseModel):
    """
    Miscellaneous model-wide optimization settings that do not belong to a specific feature.

    Example command:

    .. code-block::

        model_optimization_config(globals, gpu_policy=auto, max_elementwise_feed_repeat=2)

    """

    max_elementwise_feed_repeat: int = Field(MAX_NUM_REPEATS_ELTWISE, ge=1, le=16)
    sort_params: ThreeWayPolicy = Field(
        ThreeWayPolicy.allowed,
        description="Control sorting for weights in layers",
    )
    output_16bit: ThreeWayPolicy = Field(
        ThreeWayPolicy.allowed,
        description="Control 16-bit output in optimization",
    )
    output_encoding_vector: FeaturePolicy = Field(
        FeaturePolicy.disabled,
        description="Control output_encoding_vector in optimization",
    )
    input_encoding_vector: FeaturePolicy = Field(
        FeaturePolicy.disabled,
        description="Control input_encoding_vector in optimization",
    )
    multiproc_policy: ThreeWayPolicy = Field(
        ThreeWayPolicy.allowed,
        description="Select the multiprocessing policy in optimization flow",
    )
    gpu_policy: DistributionStrategy = Field(
        DistributionStrategy.AUTO,
        description="Choose which GPU policy should be used on each step. Set AUTO to optimize performance",
    )
    output_16bit_as_8bit: FeaturePolicy = Field(
        FeaturePolicy.disabled,
        description="Control 16-bit output as 8-bit in optimization",
    )
    deequalize: FeaturePolicy = Field(
        FeaturePolicy.disabled,
        description="Do deequalize of kernels",
    )

    @classmethod
    def get_default(cls):
        return cls()

    @classmethod
    def get_command(cls):
        return ModelOptimizationCommand.model_optimization_config.value

    @classmethod
    def get_feature(cls):
        return MOConfigCommand.global_config.value

    @classmethod
    def _internal_keys(cls) -> set:
        # ``multiproc_policy`` is intentionally absent from this set.
        internal = [
            "gpu_policy",
            "max_elementwise_feed_repeat",
            "sort_params",
            "output_16bit",
            "output_encoding_vector",
            "input_encoding_vector",
            "output_16bit_as_8bit",
            "deequalize",
        ]
        return set(internal)


class CalibrationConfig(FeatureConfigBaseModel):
    """
    Configures the calibration inference that runs on a small dataset during quantization.
    (This replaces the calib_num_batch and batch_size arguments of the
    :func:`quantize() <hailo_sdk_client.runner.client_runner.ClientRunner.quantize>` API)

    Example command:

    .. code-block::

        model_optimization_config(calibration, batch_size=4, calibset_size=128)

    """

    batch_size: int = Field(
        DEFAULT_BATCH_SIZE,
        gt=0,
        description="Batch size used during the calibration inference",
    )
    calibset_size: int = Field(
        64,
        gt=0,
        description="Data items used during the calibration inference",
    )

    @classmethod
    def get_default(cls):
        return cls()

    @classmethod
    def get_command(cls):
        return ModelOptimizationCommand.model_optimization_config.value

    @classmethod
    def get_feature(cls):
        return MOConfigCommand.calibration.value

    @classmethod
    def _internal_keys(cls) -> set:
        # Both calibration fields are user-facing.
        return set()


class AdaRoundConfig(FeatureConfigBaseModel):
    """
    Adaround algorithm optimizes layers' quantization by training the rounding of the kernel layer-by-layer.
    To enable it, use high optimization_level (>=3), or use the explicit command:

    .. code-block::

        post_quantization_optimization(adaround, policy=enabled)

    It is used by the highest optimization level to recover any degradation caused by quantization, and as such,
    it is time consuming and requires strong system in order to run.

    To reduce some of the memory usage of the algorithm, it is recommended to:

    - Ensure dali package is installed

        - For example: `pip install --extra-index-url https://developer.download.nvidia.com/compute/redist nvidia-dali-cuda110 nvidia-dali-tf-plugin-cuda110`
        - DALI is an external package which is being used by AdaRound algorithm to accelerate the running time (see warning raised during the run for more information)
    - Use a lower batch size

        - For example, using the alls command: `post_quantization_optimization(adaround, policy=enabled, batch_size=8)`
        - Lowering the batch size can reduce the RAM memory consumption but will increase the running time (default is 32)
    - Enable/disable cache_compression

        - For example, the alls command: `post_quantization_optimization(adaround, cache_compression=enabled, policy=enabled)` enables cache compression.
        - Enables compression on the disk to reduce disk space usage at the expense of increased running time (default is disabled).
    - Use smaller dataset_size

        - For example, using the alls command: `post_quantization_optimization(adaround, policy=enabled, dataset_size=256)`
        - Using a smaller dataset for Adaround would reduce the memory consumption but might affect the final accuracy (default is 1024)
    - Disable bias training

        - For example, using the alls command: `post_quantization_optimization(adaround, policy=enabled, train_bias=False)`
        - Disabling bias training can help to reduce running time but might affect the final accuracy (default is true)

    - Reduce the number of epochs

        - For example, using the alls command: `post_quantization_optimization(adaround, policy=enabled, epochs=100)`
        - Reducing the number of epochs can help to reduce the running time of the algorithm but might affect the final accuracy (default is 320)
    """

    policy: FeaturePolicy = Field(
        FeaturePolicy.disabled,
        description="Enable or disable the adaround algorithm. When Optimization Level >= 1, "
        "could be enabled by the default policy.",
    )
    learning_rate: float = Field(
        0.001,
        gt=0,
        description="Learning rate used for gradient descent by the optimizer.",
    )
    batch_size: int = Field(32, gt=0, description="batch size of the ada round algorithm")
    dataset_size: int = Field(1024, gt=0, description="Data samples for adaptive round algorithm")
    epochs: int = Field(320, gt=0, description="Number of train epochs")
    warmup: float = Field(0.2, ge=0, le=1, description="Ratio of warmup epochs out of epochs")
    weight: float = Field(
        0.01,
        gt=0,
        description="Regularization weight; higher value emphasizes rounding cost over reconstruction loss (MSE).",
    )
    train_bias: bool = Field(
        True,
        description="Whether to train bias as well or not (will apply bias correction if layer is not trained)",
    )
    # Must stay declared after ``dataset_size``: its validator reads the already-validated dataset_size.
    bias_correction_count: int = Field(64, description="Data count for bias correction")
    mode: AdaRoundMode = Field(AdaRoundMode.train_4bit, description="default train behavior")
    b_range: Tuple[float, float] = Field((20, 2), description="Max, min for temperature decay")
    decay_start: float = Field(
        0,
        ge=0,
        le=1,
        description="Ratio of round train without round regularization decay (b)",
    )
    cache_compression: FeaturePolicy = Field(
        FeaturePolicy.disabled,
        description="Enable or disable the compression of layer results when cached to disk.",
    )
    shuffle: bool = Field(True, description="Shuffle the dataset during layer's train")
    seed: int = Field(None, description="Seed for dataset shuffle")
    eager: bool = Field(False, description="Train model in eager mode")
    log_samples: int = Field(50, description="How many samples from layer training should be saved")
    layers: Dict[str, LayerAdaRoundConfig] = Field({}, description="Per layer custom configuration")

    @classmethod
    def get_default(cls):
        return cls(policy=FeaturePolicy.disabled)

    @classmethod
    def get_command(cls):
        return ModelOptimizationCommand.post_quantization_optimization.value

    @classmethod
    def get_feature(cls):
        return PostQuantizationFeature.adaround.value

    @classmethod
    def _internal_keys(cls) -> set:
        return set(("eager", "shuffle", "seed", "log_samples"))

    @classmethod
    def advanced_keys(cls):
        return set(("b_range", "decay_start"))

    @validator("bias_correction_count", always=True)
    def validate_bias_correction_count(cls, v, values, **kwargs):
        # Skip the check when dataset_size is unavailable (it failed its own validation).
        dataset_size = values.get("dataset_size")
        if dataset_size is not None and v > dataset_size:
            raise ValueError("Data count for bias correction has to be less than or equal to dataset size")
        return v


class BlockRoundTrainingConfig(FeatureConfigBaseModel):
    """
    The BlockRound algorithm performs post-training quantization. BlockRound works in a similar way to AdaRound. The difference
    between the two is that BlockRound splits the model into blocks that are usually bigger than a single layer. The algorithm
    then trains each block separately. Doing so, it has the potential to save computation and running time.

    To enable BlockRound, use the explicit command:

    .. code-block::

        post_quantization_optimization(block_round_training, policy=enabled)

    Note that as with AdaRound, the BlockRound algorithm is a time-consuming algorithm that requires a strong system to run.

    To reduce some of the memory usage of the algorithm, it is recommended to:

    - Ensure dali package is installed

        - For example: `pip install --extra-index-url https://developer.download.nvidia.com/compute/redist nvidia-dali-cuda110 nvidia-dali-tf-plugin-cuda110`
        - DALI is an external package which is being used by BlockRound algorithm to accelerate the running time (see warning raised during the run for more information)
    - Use a lower batch size

        - For example, using the alls command: `post_quantization_optimization(block_round_training, policy=enabled, batch_size=8)`
        - Lowering the batch size can reduce the RAM memory consumption but will increase the running time (default is 32)
    - Enable/disable cache compression

        - For example, the alls command: `post_quantization_optimization(block_round_training, compression_type=zlib, policy=enabled)` enables cache compression.
        - Enables compression on the disk to reduce disk space usage at the expense of increased running time (default is none).
    - Use smaller dataset_size

        - For example, using the alls command: `post_quantization_optimization(block_round_training, policy=enabled, dataset_size=256)`
        - Using a smaller dataset for BlockRound would reduce the memory consumption but might affect the final accuracy (default is 1024)
    - Disable bias training

        - For example, using the alls command: `post_quantization_optimization(block_round_training, policy=enabled, train_bias=disabled)`
        - Disabling bias training can help to reduce running time but might affect the final accuracy (default is enabled)

    - Reduce the number of epochs

        - For example, using the alls command: `post_quantization_optimization(block_round_training, policy=enabled, epochs=100)`
        - Reducing the number of epochs can help to reduce the running time of the algorithm but might affect the final accuracy (default is 320)
    """

    policy: FeaturePolicy = Field(
        FeaturePolicy.disabled,
        description="Enable or disable the blockround algorithm for training.",
    )
    learning_rate: float = Field(0.001, gt=0, description="Learning rate used for gradient descent by the optimizer.")
    batch_size: int = Field(32, gt=0, description="Batch size for infer and training.")
    dataset_size: int = Field(1024, gt=0, description="Number of data samples for infer and training.")
    epochs: int = Field(320, gt=0, description="Number epochs to train for.")
    warmup: float = Field(
        0.2,
        ge=0,
        le=1,
        description="Ratio out of epochs for which temperature (beta) is fixed at maximum value.",
    )
    weight: float = Field(
        0.01,
        gt=0,
        description="Regularization weight; higher value emphasizes rounding cost over reconstruction loss (MSE).",
    )
    train_bias: FeaturePolicy = Field(
        FeaturePolicy.enabled,
        description="Enable or disable training for the bias weights.",
    )
    b_range: Tuple[float, float] = Field(
        (20, 2),
        description="Temperature boundaries (max, min); in training, temperature (beta) decays linearly from the max value to the min value.",
    )
    decay_start: float = Field(0, ge=0, le=1, description="Ratio of round train without round regularization decay (b)")

    run_eagerly: FeaturePolicy = Field(FeaturePolicy.disabled, description="Enable or disable eager mode.")
    log_samples: int = Field(50, description="In training, it is the number of times to sample the block's metrics.")
    use_dali: FeaturePolicy = Field(
        FeaturePolicy.enabled,
        description="Enable or disable the use of NVIDIA DALI for data loading, saving, and augmentation.",
    )
    resolution: float = Field(
        2.0,
        description=(
            "Resolution less than 1 favors large communities and greater than 1 favors smaller communities. "
            "To approximate for AdaRound, use inf."
        ),
    )
    internal_encoding: FeaturePolicy = Field(
        FeaturePolicy.disabled,
        description="Enable or disable internal encodings between ops and layers.",
    )
    compression_type: CompressionTypes = Field(
        CompressionTypes.none,
        description="Compression type of block activations when cached to disk. Options are none, zlib, or gzip.",
    )
    device: str = Field("gpu", description="Choose between a cpu and a gpu (default is gpu).")
    device_id: int = Field(0, ge=0, description="ID of GPU used by the pipeline.")
    num_threads: int = Field(4, gt=0, description="Number of CPU threads used by the DALI pipeline.")

    @classmethod
    def get_default(cls):
        return cls(policy=FeaturePolicy.disabled)

    @classmethod
    def get_command(cls):
        return ModelOptimizationCommand.post_quantization_optimization.value

    @classmethod
    def get_feature(cls):
        return PostQuantizationFeature.block_round_training.value

    @classmethod
    def _internal_keys(cls) -> set:
        # NOTE(review): "eager", "shuffle" and "seed" are not fields of this class (they mirror AdaRoundConfig);
        # the eager-mode field here is ``run_eagerly``. Confirm the intended set before changing it.
        return {"eager", "shuffle", "seed", "log_samples"}

    @classmethod
    def is_internal_feature(cls):
        return True

    @classmethod
    def advanced_keys(cls):
        return {
            "b_range",
            "decay_start",
        }


class FineTuneConfig(FeatureConfigBaseModel):
    """
    This sub-command enables knowledge distillation based fine-tuning of the quantized graph.

    Example commands:

    .. code-block::

        # enable fine-tune with default configuration
        post_quantization_optimization(finetune)

        # enable fine-tune with a larger dataset
        post_quantization_optimization(finetune, dataset_size=4096)
    """

    # only when the command is triggered
    policy: FinetunePolicy = Field(
        description="Enable or disable finetune training. When Optimization Level >= 1, "
        "could be enabled by the default policy.",
    )
    dataset_size: int = Field(
        DEFAULT_DATASET_SIZE,
        gt=0,
        description="Number of images used for training; Exception is thrown if the supplied "
        "calibration set data stream falls short of that.",
    )
    batch_size: int = Field(
        None,
        gt=0,
        description="Uses the calibration batch_size by default. Number of images used together in "
        "each training step; driven by GPU memory constraints (may need to be reduced "
        "to meet them) but also by the algorithmic impact opposite to that of "
        "learning_rate.",
    )
    epochs: int = Field(DEFAULT_EPOCHS, ge=0, description="Epochs of training")
    learning_rate: float = Field(
        None,
        description=f"The base learning rate used for the schedule calculation (e.g., starting"
        f" point for the decay). default value is `{DEFAULT_LEARNING_RATE} / 8 * "
        f"batch_size`. Main parameter to experiment with; start from small values "
        f"for architectures substantially different from well-performing zoo "
        f"examples, to ensure convergence.",
    )
    def_loss_type: LossType = Field(
        LossType.L2REL,
        description="The default loss type to use if ``loss_types`` is not given",
    )
    # The default will be filled in the Configurator
    loss_layer_names: List[str] = Field(
        None,
        description="Names of layers to be used for teacher-student losses. "
        "Names to be given in Hailo HN notation, s.a. *conv20*, "
        "*fc1*, etc. Default: the output nodes of the net (the "
        "part described by the HN)",
    )
    # def_loss_type is used by default; TODO (future): allow custom loss_types
    loss_types: List[LossType] = Field(
        None,
        description="(Same length as *loss_layer_names*) The teacher-student"
        " bi-variate loss function types to apply on the native and"
        " numeric outputs of the respective loss layers specified"
        " by loss_layer_names. For example, ``ce`` (standing for "
        "'cross-entropy') is typically used for the classification"
        " head(s). Default: the ``def_loss_type``",
    )
    loss_factors: List[float] = Field(
        None,
        description="(Same length as *loss_layer_names*) defined bi-variate"
        " functions on native/numeric tensors produced by respective"
        " loss_layer_names, to arrive at the total loss. "
        "Default to 1 for all members.",
    )
    native_layers: List[str] = Field(list(), description="Don't quantize given layers during training")
    native_activations: ThreeWayPolicy = Field(
        ThreeWayPolicy.disabled,
        description="Keep activations native during training.",
    )
    layers_to_freeze: List[str] = Field(
        list(),
        description="Freeze (don’t modify weights&biases for) any layer whose "
        "name includes one of this list as a substring. As such, this "
        "arg can be used to freeze whole layer types/groups (e.g. "
        "pass “conv” to freeze all convolutional).",
    )
    # TODO (future): add option for custom lr scheduler
    lr_schedule_type: ScheduleType = Field(
        ScheduleType.COSINE_RESTARTS,
        description="Functional form of the learning rate decay within “decay "
        "period” - cosine decay to zero (default), exponential smooth "
        "or staircase",
    )
    decay_rate: float = Field(
        0.5,
        description="Decay factor of the learning rate at a beginning of “decay period”, from "
        "one to the next one. In default case of cosine restarts, the factor of the "
        "rate to which learning rate is restarted next time vs. the previous time.",
    )
    decay_epochs: int = Field(
        1,
        ge=0,
        description="Duration of the “decay period” in epochs. In the default case of cosine "
        "restarts, rate decays to zero (with cosine functional form) across this "
        "period, to be then restarted for the next period.",
    )
    warmup_epochs: int = Field(
        1,
        ge=0,
        description="Duration of warmup period, in epochs, applied before the starting the "
        "main schedule (e.g. cosine-restarts).",
    )
    # default is filled in default_warmup_lr
    # TODO: check its really 1/4 the learning rate
    warmup_lr: float = Field(
        None,
        description="Constant learning rate to be applied during the warmup period. Defaults "
        "to 1/4 the base learning rate.",
    )
    meta_arch: MetaArchType = Field(
        None,
        description="Meta arch is required for l2rel_channelwise_weighted loss type (currently used in yolov5m_60p)",
    )
    val_images: int = Field(
        4096,
        ge=0,
        description="Number of held-up/validation images for evaluation between epochs.",
    )
    val_batch_size: int = Field(128, ge=0, description="Batch size for the inter-epoch validation.")

    supervised_proportion: float = Field(
        0.0,
        description="Factor for weighting the supervised loss against the distillation.",
    )

    optimizer: Optimizer = Field(
        Optimizer.adam,
        description="set to 'sgd' to use simple Momentum, otherwise Adam will be used.",
    )
    stop_gradient_at_loss: bool = Field(False, description="Add stop gradient after each loss layer.")
    add_lca_default: bool = Field(
        True,
        description="Add lowest common ancestors layers of all the loss layers pairs "
        "as additional default loss layers",
    )
    train_scales: bool = Field(False, description="train scales")
    train_encoding: bool = Field(False, description="train encoding")
    train_weights: bool = Field(True, description="train weights")
    bias_only: bool = Field(False, description="train only biases (freeze weights).")
    online_quantization_bias_fix: bool = Field(False, description="online_quantization_bias_fix")
    force_pruning: bool = Field(True, description="if true the finetune will force zero weights to stay zeros")
    warmup_strategy: WarmupStrategy = Field(
        WarmupStrategy.CONSTANT,
        description="Warmup (learning rate) strategy for warmup stage of the training",
    )
    wraparound_factor: float = Field(0, ge=0, description="Factor for wraparound loss")
    t_mul: int = Field(1, ge=1, description="factor increases decay epochs after a restart")
    log_debug_data: QFTWriterMode = Field(
        QFTWriterMode.disabled, description="Configure writer mode for debug data collection"
    )
    shuffle_buffer_size: int = Field(
        1, ge=0, description="Buffer size for shuffling the dataset. 0 will use the dataset size"
    )

    @classmethod
    def get_command(cls):
        """Return the command this feature belongs to (post-quantization optimization)."""
        return ModelOptimizationCommand.post_quantization_optimization.value

    @classmethod
    def get_default(cls):
        """Return the default configuration, with fine-tuning disabled."""
        return cls(policy=FinetunePolicy.disabled)

    @classmethod
    def get_feature(cls):
        """Return the feature name of this configuration."""
        return PostQuantizationFeature.finetune.value

    @classmethod
    def _internal_keys(cls) -> set:
        """Return the keys of this feature that are treated as internal."""
        return {
            "meta_arch",
            "train_scales",
            "train_encoding",
            "online_quantization_bias_fix",
            "supervised_proportion",
            "add_lca_default",
            "train_weights",
            "t_mul",
            "log_debug_data",
        }

    @classmethod
    def advanced_keys(cls):
        """Return the keys classified as advanced options of this feature."""
        return {
            "bias_only",
            "layers_to_freeze",
            "lr_schedule_type",
            "decay_rate",
            "decay_epochs",
            "warmup_epochs",
            "warmup_lr",
            "warmup_strategy",
            "optimizer",
            "wraparound_factor",
            "shuffle_buffer_size",
        }

    @classmethod
    def _verify_meta_layer_loss(cls, values):
        """Require ``meta_arch`` when any entry of ``loss_types`` is ``LossType.L2REL_CHW``.

        Raises:
            ValueError: if ``loss_types`` contains ``L2REL_CHW`` and ``meta_arch`` is unset.
        """
        loss_types = values.get("loss_types", None)
        if loss_types is None:
            return
        if LossType.L2REL_CHW not in loss_types:
            return
        meta_arch = values.get("meta_arch", None)
        if meta_arch is None:
            raise ValueError(f"meta_arch is required when loss_types contain '{LossType.L2REL_CHW.value}'")
        return

    @classmethod
    def _verify_meta_def_loss(cls, values):
        """Require ``meta_arch`` when ``def_loss_type`` is ``LossType.L2REL_CHW``.

        Raises:
            ValueError: if ``def_loss_type`` is ``L2REL_CHW`` and ``meta_arch`` is unset.
        """
        def_loss = values.get("def_loss_type", None)
        if def_loss is None:
            return
        if def_loss != LossType.L2REL_CHW:
            return
        meta_arch = values.get("meta_arch", None)
        if meta_arch is None:
            raise ValueError(f"meta_arch is required when def_loss_type is '{def_loss.value}'")
        return

    @root_validator
    def validate_loss(cls, values):
        """Cross-validate the loss-related fields.

        Checks that ``meta_arch`` is present when a channel-wise L2REL loss is requested, and that
        ``loss_types``/``loss_factors`` are only given together with ``loss_layer_names`` and have matching
        lengths.

        Raises:
            ValueError: describing every violated constraint (multiple length errors are joined with "; ").
        """
        cls._verify_meta_def_loss(values)
        cls._verify_meta_layer_loss(values)
        loss_layer_names = values.get("loss_layer_names", None)
        loss_types = values.get("loss_types", None)
        loss_factors = values.get("loss_factors", None)
        if loss_layer_names is None:
            if loss_types is not None or loss_factors is not None:
                raise ValueError("loss_layer_names must be specified when loss_factors or loss_types are specified")
            return values
        loss_layers_count = len(loss_layer_names)
        # Collect all mismatches so the message never starts with a dangling separator.
        errors = []
        if loss_types is not None and loss_layers_count != len(loss_types):
            errors.append("loss_types length must match loss_layer_names length")
        if loss_factors is not None and loss_layers_count != len(loss_factors):
            errors.append("loss_factors length must match loss_layer_names length")
        if errors:
            raise ValueError("; ".join(errors))
        return values

    def info_config(self):
        """Log every entry of ``raw_dict()`` at verbose level."""
        default_logger().verbose("Finetune configuration:")
        raw_dict = self.raw_dict()
        for k, v in raw_dict.items():
            default_logger().verbose(f"    {k}: {v}")

    @classmethod
    def flat_layers_fields(cls):
        """Map each layer-list field to the sibling fields that must stay aligned with it."""
        return {"loss_layer_names": ["loss_types", "loss_factors"], "native_layers": [], "layers_to_freeze": []}


class TrainEncodingConfig(FineTuneConfig):
    # Fine-tune variant registered as the train_encoding feature. Relative to FineTuneConfig it changes only
    # defaults: 8 epochs, no LCA loss layers, encodings trained instead of weights, native activations,
    # gradual warmup, online quantization bias fix on, and a 0.1 wraparound loss factor.
    epochs: int = Field(default=8, ge=0, description="Epochs of training")
    add_lca_default: bool = Field(
        default=False,
        description=(
            "Add lowest common ancestors layers of all the loss layers pairs "
            "as additional default loss layers"
        ),
    )
    train_encoding: bool = Field(default=True, description="train encoding")
    train_weights: bool = Field(default=False, description="train weights")
    native_activations: ThreeWayPolicy = Field(
        default=ThreeWayPolicy.enabled,
        description="Keep activations native during training.",
    )
    warmup_strategy: WarmupStrategy = Field(
        default=WarmupStrategy.GRADUAL,
        description="Warmup (learning rate) strategy for warmup stage of the training",
    )
    online_quantization_bias_fix: bool = Field(default=True, description="online_quantization_bias_fix")
    wraparound_factor: float = Field(default=0.1, ge=0, description="Factor for wraparound loss")

    @classmethod
    def get_feature(cls):
        """Return the feature name of this configuration."""
        feature = PostQuantizationFeature.train_encoding
        return feature.value


class TiledSqueezeAndExciteConfig(FeatureConfigBaseModel):
    """
    This feature can modify the Squeeze and Excite block to run more efficiently on the Hailo chip.
    A more detailed explanation of the TSE algorithm can be found here
    https://arxiv.org/pdf/2107.02145.pdf

    Example commands:

    .. code-block::

        # Apply TSE to the first 3 S&E blocks with tile height of 7
        pre_quantization_optimization(se_optimization, method=tse, mode=sequential, count=3, tile_height=7)

        # Apply TSE to the first 3 S&E blocks with tile height of 9 to the 1st block, 7 to the 2nd and 5 to the 3rd
        pre_quantization_optimization(se_optimization, method=tse, mode=sequential, count=3, tile_height=[9, 7, 5])

        # Apply TSE to S&E blocks that start with avgpool1 and avgpool2 layers, with tile height of 7, 5 respectively
        pre_quantization_optimization(se_optimization, method=tse, mode=custom, layers=[avgpool1, avgpool2], tile_height=[7, 5])

    .. note::

        This operation will modify the structure of the model's graph

    .. note::

        An in-depth explanation of the TSE algorithm - https://arxiv.org/pdf/2107.02145.pdf

    """

    method: SEOptimizationMethod = Field(description="Algorithm for Squeeze and Excite block optimization")
    mode: TiledSqueezeAndExciteMode = Field(description="How to apply the algorithm on the model")
    layers: List[str] = Field(
        default=None,
        description=(
            "Required when mode=custom. Set which SE blocks to optimize based on "
            "the global avgpool of the block"
        ),
    )
    count: int = Field(
        default=None,
        gt=0,
        description="Required when mode=sequential. Set how many SE blocks to optimize",
    )
    tile_height: Union[int, List[int]] = Field(
        default=7,
        gt=0,
        description=(
            "Set tile height for the TSE. When list is given, it should "
            "match the layers count / the count argument. The tile has "
            "to divide the height without residue"
        ),
    )

    @classmethod
    def get_default(cls):
        """Return the default configuration: TSE method with the mode disabled."""
        return cls(method=SEOptimizationMethod.tse, mode=TiledSqueezeAndExciteMode.disabled)

    @classmethod
    def get_command(cls):
        """Return the command this feature belongs to (pre-quantization optimization)."""
        return ModelOptimizationCommand.pre_quantization_optimization.value

    @classmethod
    def get_feature(cls):
        """Return the feature name of this configuration."""
        return PreQuantizationFeature.se_optimization.value

    @classmethod
    def _internal_keys(cls) -> set:
        """This feature has no internal keys."""
        return set()

    @validator("layers", always=True)
    def valid_layers(cls, v, values, **kwargs):
        """Require ``layers`` in custom mode; drop it (with a warning) in sequential mode."""
        mode = values.get("mode")
        if mode == TiledSqueezeAndExciteMode.custom and v is None:
            raise ValueError("layers must be provided when mode is 'custom'")
        if mode == TiledSqueezeAndExciteMode.sequential and v is not None:
            default_logger().warning("layers is ignored when mode is 'sequential'")
            return None
        return v

    @validator("count", always=True)
    def valid_count(cls, v, values, **kwargs):
        """Require ``count`` in sequential mode; drop it (with a warning) in custom mode."""
        mode = values.get("mode")
        if mode == TiledSqueezeAndExciteMode.sequential and v is None:
            raise ValueError("count must be provided when mode is 'sequential'")
        if mode == TiledSqueezeAndExciteMode.custom and v is not None:
            default_logger().warning("count is ignored when mode is 'custom'")
            return None
        return v

    @validator("tile_height")
    def valid_tile_height(cls, v, values, **kwargs):
        """When ``tile_height`` is a list, check its length against ``count`` or ``layers`` for the active mode."""
        if isinstance(v, int):
            # A scalar tile height applies to every block; nothing to cross-check.
            return v
        mode = values.get("mode")
        if mode == TiledSqueezeAndExciteMode.sequential:
            count = values.get("count")
            if count is not None and count != len(v):
                raise ValueError("Tile height must be the same length as value in count")
        elif mode == TiledSqueezeAndExciteMode.custom:
            layers = values.get("layers")
            if layers is not None and len(layers) != len(v):
                raise ValueError("Tile height must be the same length as layers")
        return v


class DeadChannelsRemovalConfig(FeatureConfigBaseModel):
    """
    Dead channels removal is a form of channel pruning: it removes from the model any channel whose weights
    and activation output are both null.
    This might reduce memory consumption and improve inference time.

    Example commands:

    .. code-block::

        # This will enable the algorithm
        pre_quantization_optimization(dead_channels_removal, policy=enabled)

    .. note::

        This operation will modify the structure of the model's graph

    """

    policy: DeadChannelsRemovalPolicy = Field(description="Enable or disable the dead channels removal algorithm")

    @classmethod
    def get_default(cls):
        """Return the default configuration, with the algorithm disabled."""
        off = DeadChannelsRemovalPolicy.disabled
        return cls(policy=off)

    @classmethod
    def get_feature(cls):
        """Return the feature name of this configuration."""
        return PreQuantizationFeature.dead_channels_removal.value

    @classmethod
    def get_command(cls):
        """Return the command this feature belongs to (pre-quantization optimization)."""
        return ModelOptimizationCommand.pre_quantization_optimization.value

    @classmethod
    def _internal_keys(cls) -> set:
        """This feature has no internal keys."""
        return set()


class SmartSoftmaxStatsConfig(FeatureConfigBaseModel):
    """
    Smart softmax stats is an algorithm that efficiently collects the statistics of a softmax block.

    Example commands:

    .. code-block::

        # This will enable the algorithm
        pre_quantization_optimization(smart_softmax_stats, policy=enabled)

    """

    policy: ThreeWayPolicy = Field(default=ThreeWayPolicy.enabled, description="Enable disable or allow the algorithm")

    optimize_bias: ThreeWayPolicy = Field(
        default=ThreeWayPolicy.disabled,
        description="Enable or disable softmax bias optimization",
    )
    optimize_bias_algorithm: SoftmaxBiasOptimizationAlgorithm = Field(
        default=SoftmaxBiasOptimizationAlgorithm.MSE,
        description="Algorithm to optimize softmax bias. default is MSE",
    )
    force_zero_centered: bool = Field(
        default=True,
        description="Force reduce_max output to be zero centered. default is True",
    )
    sample_size: int = Field(
        default=4,
        gt=0,
        description="Number of samples to use for MSE optimization. default is 4",
    )
    dc_channels: int = Field(
        default=16,
        gt=0,
        description="Number of channels per group to split the dc component to",
    )

    @classmethod
    def get_default(cls):
        """Return the default configuration (all field defaults)."""
        return cls()

    @classmethod
    def get_feature(cls):
        """Return the feature name of this configuration."""
        return PreQuantizationFeature.smart_softmax_stats.value

    @classmethod
    def get_command(cls):
        """Return the command this feature belongs to (pre-quantization optimization)."""
        return ModelOptimizationCommand.pre_quantization_optimization.value

    @classmethod
    def _internal_keys(cls) -> set:
        """Return the bias-optimization keys, which are treated as internal."""
        return {"sample_size", "force_zero_centered", "optimize_bias_algorithm", "optimize_bias"}


class ZeroStaticChannelsConfig(FeatureConfigBaseModel):
    """
    Zero static channels zeroes out the weights of channels whose variance is zero, to improve quantization.

    Example commands:

    .. code-block::

        # This will enable the algorithm
        pre_quantization_optimization(zero_static_channels, policy=enabled)

    .. note::

        This operation does not modify the structure of the model's graph

    """

    policy: FeaturePolicy = Field(description="Enable or disable the zero static channels algorithm")
    eps: float = Field(
        default=DEFAULT_ZERO_STATIC_CHANNELS_EPSILON,
        ge=0,
        description="Threshold value to zero channels for the zero static channels algorithm",
    )
    layers: Dict[str, LayerZeroStaticChannelsConfig] = Field(default={}, description="Per layer custom configuration")

    @classmethod
    def get_default(cls):
        """Return the default configuration: enabled, with the default epsilon."""
        defaults = {"policy": FeaturePolicy.enabled, "eps": DEFAULT_ZERO_STATIC_CHANNELS_EPSILON}
        return cls(**defaults)

    @classmethod
    def get_feature(cls):
        """Return the feature name of this configuration."""
        return PreQuantizationFeature.zero_static_channels.value

    @classmethod
    def get_command(cls):
        """Return the command this feature belongs to (pre-quantization optimization)."""
        return ModelOptimizationCommand.pre_quantization_optimization.value

    @classmethod
    def _internal_keys(cls) -> set:
        """This feature has no internal keys."""
        return set()


class SwtichConcatWithAddConfig(FeatureConfigBaseModel):
    """
    When a concat is followed by a conv, this feature replaces the concat and the conv
    with 2 convs and an ew-add between them.

    Example commands:

    .. code-block::

        pre_quantization_optimization(switch_concat_with_add, layers=concat1, policy=enabled)
        pre_quantization_optimization(switch_concat_with_add, layers={concat*}, policy=enabled)


    .. note::

        - Relevant only if there is a concat layer with a single output of conv

    """

    layers: Dict[str, LayerSwtichConcatWithAddConfig] = Field(default={})

    @classmethod
    def get_default(cls):
        """Return the default configuration (no per-layer entries)."""
        return cls()

    @classmethod
    def get_feature(cls):
        """Return the feature name of this configuration."""
        return PreQuantizationFeature.switch_concat_with_add.value

    @classmethod
    def get_command(cls):
        """Return the command this feature belongs to (pre-quantization optimization)."""
        return ModelOptimizationCommand.pre_quantization_optimization.value

    @classmethod
    def _internal_keys(cls) -> set:
        """This feature has no internal keys."""
        return set()


class EWAddFusingConfig(FeatureConfigBaseModel):
    """
    When EW add fusing is enabled, ew add layers will be fused into conv and add layers.
    Layers with incompatible precision modes won't be fused.

    Example commands:

    .. code-block::

        # This will enable the algorithm
        pre_quantization_optimization(ew_add_fusing, policy=enabled)

    .. note::

        This operation modifies the structure of the model's graph

    """

    policy: FeaturePolicy = Field(description="Enable or disable the ew add fusing optimization")
    infusible_ew_add_type: InfusibleEWAddType = Field(
        InfusibleEWAddType.ew_add,
        description="Decide whether to create a conv or a standalone ew add layer when fusing is not possible",
    )

    @classmethod
    def get_command(cls):
        """Return the command this feature belongs to (pre-quantization optimization)."""
        return ModelOptimizationCommand.pre_quantization_optimization.value

    @classmethod
    def get_feature(cls):
        """Return the feature name of this configuration."""
        return PreQuantizationFeature.ew_add_fusing.value

    @classmethod
    def _internal_keys(cls) -> set:
        """This feature has no internal keys."""
        return set()

    @classmethod
    def get_default(cls):
        """Return the default configuration: enabled, with standalone ew add for infusible adds."""
        return cls(policy=FeaturePolicy.enabled, infusible_ew_add_type=InfusibleEWAddType.ew_add)

    @validator("infusible_ew_add_type")
    def valid_infusible_ew_add_type(cls, v, values, **kwargs):
        """Reject ``infusible_ew_add_type=conv`` while the feature policy is disabled.

        In pydantic v1, ``values`` only contains fields that validated successfully, so ``policy`` may be
        absent (omitted or invalid). Use ``.get`` so that case surfaces as pydantic's own ``policy`` error
        instead of a raw KeyError escaping the validator.

        Raises:
            ValueError: if the type is ``conv`` and the policy is ``disabled``.
        """
        policy = values.get("policy")
        if v == InfusibleEWAddType.conv and policy == FeaturePolicy.disabled:
            raise ValueError(
                f"infusible_ew_add_type must be {InfusibleEWAddType.ew_add.value} when {cls.get_feature()}"
                f" policy is {FeaturePolicy.disabled.value}",
            )
        return v


class CheckerConfig(FeatureConfigBaseModel):
    """
    Checker Config will generate information about the quantization process using the layer analysis tool.

    Example commands:

    .. code-block::

        # This will disable the algorithm
        model_optimization_config(checker_cfg, policy=disabled)

    .. note::

        This operation does not modify the structure of the model's graph

    """

    policy: FeaturePolicy = Field(
        FeaturePolicy.enabled,
        description="Enable or disable the checker algorithm during the quantization process.",
    )
    dataset_size: Union[int, None] = Field(16, gt=0, description="Number of images used for profiling.")
    batch_size: int = Field(
        None,
        gt=0,
        description="Uses the calibration batch_size by default. Number of images used together in "
        "each inference step.",
    )
    analyze_mode: AnalysisMode = Field(
        AnalysisMode.simple,
        description="The analysis mode that will be used during the algorithm execution "
        "(simple/advanced). Simple only executes analysis on the fully "
        "quantized net, while advanced also executes layer by layer analysis. "
        "Default is simple.",
    )
    batch_norm_checker: bool = Field(
        True,
        description="Set whether the algorithm should display a batch normalization "
        "warning message when the gathered layer statistics differ from "
        "the expected distribution. Default is True.",
    )

    # Internal key (see _internal_keys).
    custom_infer_config: str = Field(
        None,
        description="Path to the file with the configuration of the model which will modify the infer "
        "flow of the model. The path needs to be relative to the path where the model is being optimized, or "
        "a full path to the file.",
    )

    @classmethod
    def get_command(cls):
        """Return the command this feature belongs to (model optimization config)."""
        return ModelOptimizationCommand.model_optimization_config.value

    @classmethod
    def get_feature(cls):
        """Return the feature name of this configuration."""
        return MOConfigCommand.checker_cfg.value

    @classmethod
    def _internal_keys(cls) -> set:
        """Return the keys of this feature that are treated as internal."""
        return {
            "custom_infer_config",
        }

    @classmethod
    def get_default(cls):
        """Return the default configuration (all field defaults)."""
        return cls()


class MixPrecisionSearchConfig(FeatureConfigBaseModel):
    """
    This algorithm aims to identify the optimal precision
    configuration for a model by utilizing the signal to noise ratio (SNR).
    SNR quantifies the extent to which a signal is corrupted by noise.
    In this context, it aids in determining the trade-off between the
    compression applied on operations and the error (or noise)
    introduced as a result of this compression.
    """

    policy: FeaturePolicy = Field(
        FeaturePolicy.disabled,
        description="Enable or disable the mix precision search algorithm.",
    )
    dataset_size: int = Field(16, gt=0, description="Number of images used for profiling.")
    # NOTE(review): the description says the calibration batch_size is used by default, but this field's
    # default is a fixed 8 -- confirm whether a caller overrides it.
    batch_size: int = Field(
        8,
        gt=0,
        description="Uses the calibration batch_size by default. Number of images used together in each inference step.",
    )
    snr_cap: int = Field(140, gt=0, description="The maximum SNR value to be considered in the search.")
    # Field name keeps its historical spelling so existing configurations continue to parse.
    compresions_markers: List[float] = Field(
        [0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.2],
        description="This will be the compression markers",
    )
    optimizer: SensitivitySearch = Field(SensitivitySearch.LINEAR, description="Linear, Pareto")
    output_regulizer: MultiOutputMetric = Field(
        MultiOutputMetric.HARMONY,
        description="What function will be used to regulate the output",
    )
    comprecision_metric: ComprecisionMetric = Field(ComprecisionMetric.BOPS)

    @classmethod
    def get_command(cls):
        """Return the command this feature belongs to (post-quantization optimization)."""
        return ModelOptimizationCommand.post_quantization_optimization.value

    @classmethod
    def get_feature(cls):
        """Return the feature name of this configuration."""
        return PostQuantizationFeature.mix_precision_search.value

    @classmethod
    def _internal_keys(cls) -> set:
        """This feature has no internal keys."""
        return set()

    @classmethod
    def get_default(cls):
        """Return the default configuration (all field defaults; the search is disabled)."""
        return cls()


class NegExponentConfig(FeatureConfigBaseModel):
    """
    This command controls the negative exponent algorithm.
    """

    layers: Dict[str, LayerNegExponentConfig] = Field(default={}, description="Layers to be affected")

    @classmethod
    def get_default(cls):
        """Return the default configuration (no per-layer entries)."""
        return cls()

    @classmethod
    def get_feature(cls):
        """Return the feature name of this configuration."""
        return MOConfigCommand.negative_exponent.value

    @classmethod
    def get_command(cls):
        """Return the command this feature belongs to (model optimization config)."""
        return ModelOptimizationCommand.model_optimization_config.value

    @classmethod
    def _internal_keys(cls) -> set:
        """This feature has no internal keys."""
        return set()


class ActivationClippingConfig(FeatureConfigBaseModel):
    # Only per-layer settings (mapping of layer key -> LayerActivationClippingConfig); registered under
    # pre_quantization_optimization as the activation_clipping feature.
    layers: Dict[str, LayerActivationClippingConfig] = Field(default={})

    @classmethod
    def get_default(cls):
        """Return the default configuration (no per-layer entries)."""
        return cls()

    @classmethod
    def get_feature(cls):
        """Return the feature name of this configuration."""
        return PreQuantizationFeature.activation_clipping.value

    @classmethod
    def get_command(cls):
        """Return the command this feature belongs to (pre-quantization optimization)."""
        return ModelOptimizationCommand.pre_quantization_optimization.value

    @classmethod
    def _internal_keys(cls) -> set:
        """This feature has no internal keys."""
        return set()


class WeightsClippingConfig(FeatureConfigBaseModel):
    # Only per-layer settings (mapping of layer key -> LayerWeightsClippingConfig); registered under
    # pre_quantization_optimization as the weights_clipping feature.
    layers: Dict[str, LayerWeightsClippingConfig] = Field(default={})

    @classmethod
    def get_default(cls):
        """Return the default configuration (no per-layer entries)."""
        return cls()

    @classmethod
    def get_feature(cls):
        """Return the feature name of this configuration."""
        return PreQuantizationFeature.weights_clipping.value

    @classmethod
    def get_command(cls):
        """Return the command this feature belongs to (pre-quantization optimization)."""
        return ModelOptimizationCommand.pre_quantization_optimization.value

    @classmethod
    def _internal_keys(cls) -> set:
        """This feature has no internal keys."""
        return set()


class GlobalLayerDecompositionConfig(FeatureConfigBaseModel):
    # Only per-layer settings (mapping of layer key -> LayerDecompositionConfig); registered under
    # pre_quantization_optimization as the layer_decomposition feature.
    layers: Dict[str, LayerDecompositionConfig] = Field(dict())

    @classmethod
    def get_command(cls):
        """Return the command this feature belongs to (pre-quantization optimization)."""
        return ModelOptimizationCommand.pre_quantization_optimization.value

    @classmethod
    def get_feature(cls):
        """Return the feature name of this configuration."""
        return PreQuantizationFeature.layer_decomposition.value

    @classmethod
    def _internal_keys(cls) -> set:
        """This feature has no internal keys."""
        # Annotated ``-> set`` like every sibling feature config.
        return set()

    @classmethod
    def get_default(cls):
        """Return the default configuration (no per-layer entries)."""
        return cls()


class MatmulDecompositionConfig(FeatureConfigBaseModel):
    # Only per-layer settings (mapping of layer key -> LayerMatmulDecompositionConfig); registered under
    # pre_quantization_optimization as the matmul_decomposition feature.
    layers: Dict[str, LayerMatmulDecompositionConfig] = Field(dict())

    @classmethod
    def get_command(cls):
        """Return the command this feature belongs to (pre-quantization optimization)."""
        return ModelOptimizationCommand.pre_quantization_optimization.value

    @classmethod
    def get_feature(cls):
        """Return the feature name of this configuration."""
        return PreQuantizationFeature.matmul_decomposition.value

    @classmethod
    def _internal_keys(cls) -> set:
        """This feature has no internal keys."""
        # Annotated ``-> set`` like every sibling feature config.
        return set()

    @classmethod
    def get_default(cls):
        """Return the default configuration (no per-layer entries)."""
        return cls()


class DefuseConfig(FeatureConfigBaseModel):
    # Only per-layer settings (mapping of layer key -> LayerDefuseConfig); registered under
    # pre_quantization_optimization as the defuse feature.
    layers: Dict[str, LayerDefuseConfig] = Field(default={})

    @classmethod
    def get_default(cls):
        """Return the default configuration (no per-layer entries)."""
        return cls()

    @classmethod
    def get_feature(cls):
        """Return the feature name of this configuration."""
        return PreQuantizationFeature.defuse.value

    @classmethod
    def get_command(cls):
        """Return the command this feature belongs to (pre-quantization optimization)."""
        return ModelOptimizationCommand.pre_quantization_optimization.value

    @classmethod
    def _internal_keys(cls) -> set:
        """This feature has no internal keys."""
        return set()


class TranslationConfig(FeatureConfigBaseModel):
    # Only per-layer settings (mapping of layer key -> LayerTranslationConfig); registered under
    # model_optimization_config as the global_config feature.
    layers: Dict[str, LayerTranslationConfig] = Field(default={})

    @classmethod
    def get_default(cls):
        """Return the default configuration (no per-layer entries)."""
        return cls()

    @classmethod
    def get_feature(cls):
        """Return the feature name of this configuration."""
        return MOConfigCommand.global_config.value

    @classmethod
    def get_command(cls):
        """Return the command this feature belongs to (model optimization config)."""
        return ModelOptimizationCommand.model_optimization_config.value

    @classmethod
    def _internal_keys(cls) -> set:
        """This feature has no internal keys."""
        return set()


class PrecisionConfig(FeatureConfigBaseModel):
    # Target platform (internal key) plus per-layer precision settings; registered under
    # model_optimization_config as the precision_config feature.
    target: OptimizationTarget = Field(default=OptimizationTarget.SAGE, description="Indicate target platform")
    layers: Dict[str, LayerPrecisionConfig] = Field(default={})

    @classmethod
    def get_default(cls):
        """Return the default configuration (SAGE target, no per-layer entries)."""
        return cls()

    @classmethod
    def get_feature(cls):
        """Return the feature name of this configuration."""
        return MOConfigCommand.precision_config.value

    @classmethod
    def get_command(cls):
        """Return the command this feature belongs to (model optimization config)."""
        return ModelOptimizationCommand.model_optimization_config.value

    @classmethod
    def _internal_keys(cls) -> set:
        """Return the keys of this feature that are treated as internal."""
        internal = {"target"}
        return internal


class GlobalResolutionReductionConfig(FeatureConfigBaseModel):
    """
    Reduce the model resolution in all input layers in order to optimize the model more efficiently.
    Marginally affects accuracy. Not supported on models that contain Fully-connected, Matmul and Cross-correlation
    layers, or when the resolution is too small.

    Example commands:

    .. code-block::

        # This will enable the algorithm, optimizing over an input shape of [128, 128]
        pre_quantization_optimization(resolution_reduction, shape=[128, 128])

    .. note::

        This operation doesn't modify the structure of the model's graph

    """

    shape: Tuple[int, int] = Field(default=None, description="The shape to reduce the model resolution to.")
    interpolation: ResolutionReductionInterpolationMode = Field(
        default=ResolutionReductionInterpolationMode.bilinear,
        description=(
            "Interpolation (default) requires dataset in the original model size, "
            "disabled required dataset in the reduced resolution."
        ),
    )
    layers: Dict[str, LayerResolutionReductionConfig] = Field(default={}, description="Per layer custom configuration")

    @classmethod
    def get_default(cls):
        """Return the default configuration (no target shape, no per-layer entries)."""
        return cls()

    @classmethod
    def get_feature(cls):
        """Return the feature name of this configuration."""
        return PreQuantizationFeature.resolution_reduction.value

    @classmethod
    def get_command(cls):
        """Return the command this feature belongs to (pre-quantization optimization)."""
        return ModelOptimizationCommand.pre_quantization_optimization.value

    @classmethod
    def _internal_keys(cls) -> set:
        """This feature has no internal keys."""
        return set()


class GlobalAvgpoolReductionConfig(FeatureConfigBaseModel):
    # Pre-quantization feature "global_avgpool_reduction"; configured only per layer.
    layers: Dict[str, LayerGlobalAvgpoolReductionConfig] = Field({})

    @classmethod
    def get_default(cls):
        """Build a configuration with every field left at its default."""
        default_config = cls()
        return default_config

    @classmethod
    def get_command(cls):
        """Return the ``.value`` of the command that owns this feature."""
        owning_command = ModelOptimizationCommand.pre_quantization_optimization
        return owning_command.value

    @classmethod
    def get_feature(cls):
        """Return the ``.value`` of this feature's enum member."""
        feature_member = PreQuantizationFeature.global_avgpool_reduction
        return feature_member.value

    @classmethod
    def _internal_keys(cls) -> set:
        """No keys of this feature are internal."""
        return set()


class AddShortcutConfig(FeatureConfigBaseModel):
    # Pre-quantization feature "add_shortcut_layer"; configured only per layer.
    layers: Dict[str, LayerAddShortcutConfig] = Field({})

    @classmethod
    def get_default(cls):
        """Build a configuration with every field left at its default."""
        default_config = cls()
        return default_config

    @classmethod
    def get_command(cls):
        """Return the ``.value`` of the command that owns this feature."""
        owning_command = ModelOptimizationCommand.pre_quantization_optimization
        return owning_command.value

    @classmethod
    def get_feature(cls):
        """Return the ``.value`` of this feature's enum member."""
        feature_member = PreQuantizationFeature.add_shortcut_layer
        return feature_member.value

    @classmethod
    def _internal_keys(cls) -> set:
        """No keys of this feature are internal."""
        return set()


class MatmulCorrectionConfig(FeatureConfigBaseModel):
    # Pre-quantization feature "matmul_correction"; configured only per layer.
    layers: Dict[str, LayerMatmulCorrectionConfig] = Field({})

    @classmethod
    def get_default(cls):
        """Build a configuration with every field left at its default."""
        default_config = cls()
        return default_config

    @classmethod
    def get_command(cls):
        """Return the ``.value`` of the command that owns this feature."""
        owning_command = ModelOptimizationCommand.pre_quantization_optimization
        return owning_command.value

    @classmethod
    def get_feature(cls):
        """Return the ``.value`` of this feature's enum member."""
        feature_member = PreQuantizationFeature.matmul_correction
        return feature_member.value

    @classmethod
    def _internal_keys(cls) -> set:
        """No keys of this feature are internal."""
        return set()


class MatmulEqualizationConfig(FeatureConfigBaseModel):
    # Pre-quantization feature "matmul_equalization"; configured only per layer.
    # A fresh empty dict is produced per instance by the factory.
    layers: Dict[str, LayerMatmulEqualizationConfig] = Field(default_factory=dict)

    @classmethod
    def get_default(cls):
        """Build a configuration with every field left at its default."""
        default_config = cls()
        return default_config

    @classmethod
    def get_command(cls):
        """Return the ``.value`` of the command that owns this feature."""
        owning_command = ModelOptimizationCommand.pre_quantization_optimization
        return owning_command.value

    @classmethod
    def get_feature(cls):
        """Return the ``.value`` of this feature's enum member."""
        feature_member = PreQuantizationFeature.matmul_equalization
        return feature_member.value

    @classmethod
    def _internal_keys(cls) -> set:
        """No keys of this feature are internal."""
        return set()


class LoadQuantConfig(FeatureConfigBaseModel):
    # Pre-quantization feature "use_prequantized_weights"; configured only per layer.
    layers: Dict[str, LayerLoadQuantConfig] = Field({})

    @classmethod
    def get_default(cls):
        """Build a configuration with every field left at its default."""
        default_config = cls()
        return default_config

    @classmethod
    def get_command(cls):
        """Return the ``.value`` of the command that owns this feature."""
        owning_command = ModelOptimizationCommand.pre_quantization_optimization
        return owning_command.value

    @classmethod
    def get_feature(cls):
        """Return the ``.value`` of this feature's enum member."""
        feature_member = PreQuantizationFeature.use_prequantized_weights
        return feature_member.value

    @classmethod
    def _internal_keys(cls) -> set:
        """No keys of this feature are internal."""
        return set()


class ConvDecompositionConfig(FeatureConfigBaseModel):
    # Pre-quantization feature "conv_decomposition"; configured only per layer.
    layers: Dict[str, LayerConvDecompositionConfig] = Field({})

    @classmethod
    def get_default(cls):
        """Build a configuration with every field left at its default."""
        default_config = cls()
        return default_config

    @classmethod
    def get_command(cls):
        """Return the ``.value`` of the command that owns this feature."""
        owning_command = ModelOptimizationCommand.pre_quantization_optimization
        return owning_command.value

    @classmethod
    def get_feature(cls):
        """Return the ``.value`` of this feature's enum member."""
        feature_member = PreQuantizationFeature.conv_decomposition
        return feature_member.value

    @classmethod
    def _internal_keys(cls) -> set:
        """No keys of this feature are internal."""
        return set()


class SplitEWMultByBitSignificanceConfig(FeatureConfigBaseModel):
    # Pre-quantization feature "split_ew_mult_by_bit_significance"; configured only per layer.
    layers: Dict[str, LayerSplitEWMultByBitSignificanceConfig] = Field(dict())

    @classmethod
    def get_default(cls):
        """Build a configuration with every field left at its default."""
        default_config = cls()
        return default_config

    @classmethod
    def get_command(cls):
        """Return the ``.value`` of the command that owns this feature."""
        owning_command = ModelOptimizationCommand.pre_quantization_optimization
        return owning_command.value

    @classmethod
    def get_feature(cls):
        """Return the ``.value`` of this feature's enum member."""
        feature_member = PreQuantizationFeature.split_ew_mult_by_bit_significance
        return feature_member.value

    @classmethod
    def _internal_keys(cls) -> set:
        """No keys of this feature are internal."""
        return set()


class SplitFusedActivationConfig(FeatureConfigBaseModel):
    # Pre-quantization feature "split_fused_activation"; configured only per layer.
    layers: Dict[str, LayerSplitFusedActivationConfig] = Field(dict())

    @classmethod
    def get_default(cls):
        """Build a configuration with every field left at its default."""
        default_config = cls()
        return default_config

    @classmethod
    def get_command(cls):
        """Return the ``.value`` of the command that owns this feature."""
        owning_command = ModelOptimizationCommand.pre_quantization_optimization
        return owning_command.value

    @classmethod
    def get_feature(cls):
        """Return the ``.value`` of this feature's enum member."""
        feature_member = PreQuantizationFeature.split_fused_activation
        return feature_member.value

    @classmethod
    def _internal_keys(cls) -> set:
        """No keys of this feature are internal."""
        return set()


class QuaRotConfig(FeatureConfigBaseModel):
    # Pre-quantization feature "quarot". It has no per-layer settings, only three global
    # ThreeWayPolicy switches. Declaration order is kept as-is because pydantic field
    # order is observable in dict()/schema output.
    policy: ThreeWayPolicy = Field(
        default=ThreeWayPolicy.disabled,
        description="Enable or disable the quarot algorithm",
    )
    equalize_inputs: ThreeWayPolicy = Field(
        default=ThreeWayPolicy.allowed,
        description="Enable or disable the equalization of the inputs",
    )
    equalize_outputs: ThreeWayPolicy = Field(
        default=ThreeWayPolicy.allowed,
        description="Enable or disable the equalization of the outputs",
    )

    @classmethod
    def get_default(cls):
        """Build a configuration with every field left at its default."""
        default_config = cls()
        return default_config

    @classmethod
    def get_command(cls):
        """Return the ``.value`` of the command that owns this feature."""
        owning_command = ModelOptimizationCommand.pre_quantization_optimization
        return owning_command.value

    @classmethod
    def get_feature(cls):
        """Return the ``.value`` of this feature's enum member."""
        feature_member = PreQuantizationFeature.quarot
        return feature_member.value

    @classmethod
    def _internal_keys(cls) -> set:
        """No keys of this feature are internal."""
        return set()


class ConvA16W4Config(FeatureConfigBaseModel):
    """
    If the convolution layer's kernel size is 1x1, this command decomposes the convolution layer so that it works
    as a8_w4 on a split precision of the input.

    Example commands:

    .. code-block::

        pre_quantization_optimization(conv_a16_w4, layers=conv1, policy=enabled)
        pre_quantization_optimization(conv_a16_w4, layers={conv*}, policy=enabled)

    .. note::

        - Relevant only if the convolution layer's kernel size is 1x1.
        - This overrides the precision mode of that layer set by quantization_param.

    """

    # Per-layer settings, keyed by layer name (the examples above also show glob patterns such as conv*).
    layers: Dict[str, LayerConvA16W4Config] = Field(dict())

    @classmethod
    def get_command(cls):
        """Return the value of ``ModelOptimizationCommand.pre_quantization_optimization``."""
        return ModelOptimizationCommand.pre_quantization_optimization.value

    @classmethod
    def get_feature(cls):
        """Return the value of ``PreQuantizationFeature.conv_a16_w4``."""
        return PreQuantizationFeature.conv_a16_w4.value

    @classmethod
    def _internal_keys(cls) -> set:
        """Return the set of internal keys; this feature has none."""
        return set()

    @classmethod
    def get_default(cls):
        """Return an instance with every field at its declared default."""
        return cls()
