from abc import ABC
from typing import Tuple, Union

from pydantic.v1 import Field, root_validator, validator

from hailo_model_optimization.acceleras.model_optimization_config.mo_config_base import (
    BaseConfigBaseModel,
    _value_to_str,
)
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    DEFAULT_NULL_CHANNELS_CUTOFF_FACTOR,
    ActivationClippingMode,
    ActivationFitPolicy,
    BiasCorrectionPolicy,
    BiasMode,
    CalcKernelMode,
    EqualizationPolicy,
    EqualizationTargetPolicy,
    FeaturePolicy,
    IgnoreHwLimitationAssertionPolicy,
    LayerFeaturePolicy,
    MatmulCorrectionType,
    MOConfigCommand,
    ModelOptimizationCommand,
    PostQuantizationFeature,
    PrecisionMode,
    PreQuantizationDefuseType,
    PreQuantizationFeature,
    ResolutionReductionInterpolationMode,
    SplitFusedActivationPolicy,
    ThreeWayPolicy,
    WeightsClippingMode,
)
from hailo_model_optimization.acceleras.utils.logger import default_logger


class LayerConfigBaseModel(BaseConfigBaseModel, ABC):
    @classmethod
    def to_cmd(cls, raw_dict, layers):
        """
        Build the alls command string for a per-layer configuration.

        Args:
            raw_dict: raw dictionary of the config values; each entry becomes a ``key=value`` keyword argument.
            layers: layers the configuration applies to.

        Returns:
            The command string, or an empty string when ``raw_dict`` has no entries.

        """
        command = cls.get_command()
        feature = cls.get_feature()
        # Sort by key so the generated command text is deterministic.
        kwargs_str = ", ".join(f"{key}={_value_to_str(value)}" for key, value in sorted(raw_dict.items()))
        layers_str = _value_to_str(layers)
        if not kwargs_str:
            return ""
        # Commands without a feature take the layers positionally; feature commands name them explicitly.
        if feature is None:
            positional = layers_str
        else:
            positional = f"{feature}, layers={layers_str}"
        return f"{command}({positional}, {kwargs_str})"


class LayerEqualizationConfig(LayerConfigBaseModel):
    """
    This sub-command allows configuring the equalization behavior per layer.
    Allowed policy means the behavior derives from the algorithm config.

    Example commands:

    .. code-block::

        # Disable equalization on conv1 and conv2
        pre_quantization_optimization(equalization, layers=[conv1, conv2], policy=disabled)

        # Disable equalization on all conv layers.
        pre_quantization_optimization(equalization, layers={conv*}, policy=disabled)

    .. note::

        - Not all layers support equalization
        - Layers are related to each other
        - Disabling 1 layer, disables all related layers
        - Enabling 1 layer won't enable the related layers (it has to be done manually)

    """

    policy: EqualizationPolicy = Field(
        EqualizationPolicy.allowed,
        description="Set equalization behavior to given layer. (default is allowed)",
    )
    equalization_target: EqualizationTargetPolicy = Field(
        EqualizationTargetPolicy.default, description="Set equalization target to activation only or kernel only"
    )

    force_transparent: ThreeWayPolicy = Field(
        ThreeWayPolicy.allowed, description="Force the layer to behave as transparent for equalization"
    )

    @classmethod
    def get_command(cls):
        return ModelOptimizationCommand.pre_quantization_optimization.value

    @classmethod
    def get_feature(cls):
        return PreQuantizationFeature.equalization.value

    @classmethod
    def _internal_keys(cls):
        # Use a set literal so the result holds the field name itself.
        # ``set("force_transparent")`` would build a set of its single characters instead.
        return {"force_transparent"}

    @classmethod
    def get_default(cls):
        return cls()


class LayerZeroStaticChannelsConfig(LayerConfigBaseModel):
    """
    This sub-command allows configuring the zero static channels behavior per layer.

    Example commands:

    .. code-block::

        # Disable zero static on conv1 and conv2
        pre_quantization_optimization(zero_static_channels, layers=[conv1, conv2], policy=disabled)

        # Disable zero static on all conv layers.
        pre_quantization_optimization(zero_static_channels, layers={conv*}, policy=disabled)

    .. note::

        - Not all layers support zero static channels
        - If allowed, the layer gets its behavior from the generic algorithm policy
        - If disabled/enabled is set explicitly, the layer behaves according to the given LayerFeaturePolicy
    """

    policy: LayerFeaturePolicy = Field(
        default=LayerFeaturePolicy.allowed,
        description="Set zero_static behavior to given layer. (default is allowed)",
    )

    @classmethod
    def get_default(cls):
        # Parameterless construction yields the default (policy=allowed) config.
        return cls()

    @classmethod
    def get_command(cls):
        return ModelOptimizationCommand.pre_quantization_optimization.value

    @classmethod
    def get_feature(cls):
        return PreQuantizationFeature.zero_static_channels.value

    @classmethod
    def _internal_keys(cls):
        return set()


class LayerSwtichConcatWithAddConfig(LayerConfigBaseModel):
    """
    If a concat layer is followed by a conv, this feature converts the concat and the conv
    into 2 convs with an ew-add between them.

    Example commands:

    .. code-block::

        pre_quantization_optimization(switch_concat_with_add, layers=concat1, policy=enabled)
        pre_quantization_optimization(switch_concat_with_add, layers={concat*}, policy=enabled)


    .. note::

        - Relevant only if there is a concat layer with a single output of conv

    """

    policy: LayerFeaturePolicy = Field(
        default=LayerFeaturePolicy.disabled,
        description="replace concat and conv with 2 convs and add",
    )

    @classmethod
    def get_default(cls):
        # Parameterless construction yields the default (policy=disabled) config.
        return cls()

    @classmethod
    def get_command(cls):
        return ModelOptimizationCommand.pre_quantization_optimization.value

    @classmethod
    def get_feature(cls):
        return PreQuantizationFeature.switch_concat_with_add.value

    @classmethod
    def _internal_keys(cls):
        return set()


class LayerAddShortcutConfig(LayerConfigBaseModel):
    """
    Adds an activation layer between "layer" and "target" and removes the original edge between them.
    The added activation is linear by default.

    before : layer -> target
    after  : layer -> act -> target

    Example commands:

    .. code-block::

        # Adds activation layer (linear) between conv8 and conv10
        pre_quantization_optimization(add_shortcut_layer, layers=conv8, target=conv10)

        # Adds activation layer (linear) from conv3 to conv4 and to conv5
        pre_quantization_optimization(add_shortcut_layer, layers=conv3, target=[conv4, conv5])
    """

    # A single target layer name, or several target layers.
    target: Union[str, Tuple] = Field(None)
    name: str = Field(None, description="Name of added shortcut layer. defaults to concatenation of layer-target")
    activation: str = Field("linear")

    @classmethod
    def get_command(cls):
        return ModelOptimizationCommand.pre_quantization_optimization.value

    @classmethod
    def get_feature(cls):
        return PreQuantizationFeature.add_shortcut_layer.value

    @classmethod
    def _internal_keys(cls):
        return set()

    @classmethod
    def get_default(cls):
        return cls(name=None)


class LayerMatmulCorrectionConfig(LayerConfigBaseModel):
    """
    This sub-command selects the type of correction applied to matmul layers.

    Example commands:

    .. code-block::

        pre_quantization_optimization(matmul_correction, layers=matmul1, correction_type=zp_comp_weights)
        pre_quantization_optimization(matmul_correction, layers=[matmul2,matmul4], correction_type=zp_comp_block)

    """

    correction_type: str = Field(
        default=MatmulCorrectionType.ZP_COMP_WEIGHTS,
        description=(
            "Type of correction to apply. 'zp_comp_weights',"
            "'zp_comp_block' , 'zp_comp_block_2' or 'zp_comp_block_3'"
        ),
    )

    @classmethod
    def get_default(cls):
        return cls()

    @classmethod
    def get_command(cls):
        return ModelOptimizationCommand.pre_quantization_optimization.value

    @classmethod
    def get_feature(cls):
        return PreQuantizationFeature.matmul_correction.value

    @classmethod
    def _internal_keys(cls):
        return set()


class LayerMatmulEqualizationConfig(LayerConfigBaseModel):
    """
    This sub-command configures matmul equalization per layer.

    Example commands:

    .. code-block::

        pre_quantization_optimization(matmul_equalization, layers=matmul1, policy=enabled)
        pre_quantization_optimization(matmul_equalization, layers=[matmul2,matmul4], policy=disabled)

    """

    policy: ThreeWayPolicy = Field(
        default=ThreeWayPolicy.allowed,
        description="Enables matmul Equalization on a given Layer",
    )
    matmul_bias: ThreeWayPolicy = Field(
        default=ThreeWayPolicy.disabled,
        description="Adds offset to the matmult to clean the Dc",
    )

    @classmethod
    def get_default(cls):
        return cls()

    @classmethod
    def get_command(cls):
        return ModelOptimizationCommand.pre_quantization_optimization.value

    @classmethod
    def get_feature(cls):
        return PreQuantizationFeature.matmul_equalization.value

    @classmethod
    def _internal_keys(cls):
        return set()


class LayerLoadQuantConfig(LayerConfigBaseModel):
    """
    This sub-command configures per-layer use of pre-quantized weights.

    Example commands:

    .. code-block::

        pre_quantization_optimization(use_prequantized_weights, layers=[matmul2,matmul4], policy=disabled, bits=4)
    """

    mode: CalcKernelMode = Field(CalcKernelMode.kernel_vals, description="Mode of operation to calc the kernel scale")
    # ``bits`` and ``groups`` carry no annotation, so pydantic v1 infers their type from the default.
    # It also orders un-annotated fields after annotated ones, so adding annotations here would change
    # the model's field order.
    bits = Field(4, description="Number of bits quantized for the weights currently supports : [4, 8]")
    groups = Field(-1, description="Number of quantization groups currently supports :[-1] (number of channels)")
    policy: LayerFeaturePolicy = Field(
        LayerFeaturePolicy.allowed, description="Set pre_weights behavior to given layer. (default is allowed)"
    )

    @classmethod
    def get_command(cls):
        return ModelOptimizationCommand.pre_quantization_optimization.value

    @classmethod
    def get_feature(cls):
        return PreQuantizationFeature.use_prequantized_weights.value

    @classmethod
    def _internal_keys(cls):
        return set()

    @classmethod
    def get_default(cls):
        return cls()


class LayerConvDecompositionConfig(LayerConfigBaseModel):
    """
    This sub-command allows decomposing a single linear/conv layer into multiple layers, to support sub-group
    quantization.

    Example commands:

    .. code-block::

        pre_quantization_optimization(conv_decomposition, layers=[conv1], sub_group_size=128)
    """

    sub_group_size: int = Field(default=128, description="Size of the sub group")
    pm_ew_adds: str = Field(default="a16_w16", description="Precision mode for the ew_adds")
    allow_equlize_block: bool = Field(default=False, description="Allow block equalization")
    sort_channels_by_stats: bool = Field(default=False, description="Sort input channels")

    @classmethod
    def get_default(cls):
        return cls()

    @classmethod
    def get_command(cls):
        return ModelOptimizationCommand.pre_quantization_optimization.value

    @classmethod
    def get_feature(cls):
        return PreQuantizationFeature.conv_decomposition.value

    @classmethod
    def _internal_keys(cls):
        return set()


class LayerDecompositionConfig(LayerConfigBaseModel):
    """
    This sub-command allows toggling layers to decomposition mode, which means 16-bit layers will be implemented
    with 8-bit layers.

    Example commands:

    .. code-block::

        # Disable decomposition for conv1
        pre_quantization_optimization(layer_decomposition, layers=[conv1], policy=disabled)
        # Decompose conv17 and conv18 to increase their precision
        pre_quantization_optimization(layer_decomposition, layers=[conv17, conv18], policy=enabled)
    """

    policy: LayerFeaturePolicy = Field(default=LayerFeaturePolicy.allowed)

    @classmethod
    def get_default(cls):
        return cls()

    @classmethod
    def get_command(cls):
        return ModelOptimizationCommand.pre_quantization_optimization.value

    @classmethod
    def get_feature(cls):
        return PreQuantizationFeature.layer_decomposition.value

    @classmethod
    def _internal_keys(cls):
        return set()


class LayerMatmulDecompositionConfig(LayerConfigBaseModel):
    """
    This sub-command allows toggling Matmul layers to decomposition mode, which means 16-bit layers will
    be implemented with 8-bit layers.

    Example commands:

    .. code-block::

        # Disable decomposition for matmul1
        pre_quantization_optimization(matmul_decomposition, layers=[matmul1], policy=disabled)
        # Decompose matmul1 and matmul2 using the a16_w8 precision mode
        pre_quantization_optimization(matmul_decomposition, layers=[matmul1, matmul2], policy=enabled, precision_mode=a16_w8)
    """

    policy: LayerFeaturePolicy = Field(default=LayerFeaturePolicy.disabled)
    precision_mode: PrecisionMode = Field(
        default=PrecisionMode.a16_w8,
        description=(
            "Matmul can be decompose into different modes, "
            "default is a16_w8, there also can be set to be a16_w16"
        ),
    )

    @classmethod
    def get_default(cls):
        return cls()

    @classmethod
    def get_command(cls):
        return ModelOptimizationCommand.pre_quantization_optimization.value

    @classmethod
    def get_feature(cls):
        return PreQuantizationFeature.matmul_decomposition.value

    @classmethod
    def _internal_keys(cls):
        return set()


class LayerBiasCorrectionConfig(LayerConfigBaseModel):
    """
    This sub-command allows enabling or disabling the Iterative Bias Correction (IBC) algorithm on a per-layer basis.
    Allowed policy means the behavior derives from the algorithm config.

    Example commands:

    .. code-block::

        # This will enable IBC for a specific layer
        post_quantization_optimization(bias_correction, layers=[conv1], policy=enabled)

        # This will disable IBC for conv layers and enable for the other layers
        post_quantization_optimization(bias_correction, policy=enabled)
        post_quantization_optimization(bias_correction, layers={conv*}, policy=disabled)

    """

    policy: BiasCorrectionPolicy = Field(
        default=BiasCorrectionPolicy.allowed,
        description="Set bias correction behavior to given layer. (default is allowed)",
    )

    @classmethod
    def get_default(cls):
        return cls()

    @classmethod
    def get_command(cls):
        # Unlike most layer configs in this module, IBC is a post-quantization feature.
        return ModelOptimizationCommand.post_quantization_optimization.value

    @classmethod
    def get_feature(cls):
        return PostQuantizationFeature.bias_correction.value

    @classmethod
    def _internal_keys(cls):
        return set()


class LayerAdaRoundConfig(LayerConfigBaseModel):
    """
    This sub-command allows toggling layers in the adaround algorithm individually, and overriding
    per-layer training parameters. Parameters left as None are not overridden by this config.

    Example commands:

    .. code-block::

        # Disable AdaRound for conv1
        post_quantization_optimization(adaround, layers=[conv1], policy=disabled)
        # Enable AdaRound for conv17 and conv18
        post_quantization_optimization(adaround, layers=[conv17, conv18], policy=enabled)
    """

    policy: LayerFeaturePolicy = Field(default=LayerFeaturePolicy.allowed)

    epochs: int = Field(default=None, description="Amount of train epochs for a specific layer")
    weight: float = Field(default=None, gt=0, description="Weight of round regularization")
    b_range: Tuple[float, float] = Field(default=None, description="Temperature decay range")
    decay_start: float = Field(
        default=None,
        ge=0,
        le=1,
        description="Ratio of round train without round regularization decay (b)",
    )
    train_bias: bool = Field(default=None, description="Toggle bias training")
    warmup: float = Field(default=None, ge=0, le=1, description="Ratio of warmup epochs out of epochs")
    dataset_size: int = Field(
        default=None, gt=0, description="Data samples count for the train stage of the specified layer"
    )
    batch_size: int = Field(default=None, gt=0, description="Batch size for train / infer of a layer")

    @classmethod
    def get_default(cls):
        return cls()

    @classmethod
    def get_command(cls):
        return ModelOptimizationCommand.post_quantization_optimization.value

    @classmethod
    def get_feature(cls):
        return PostQuantizationFeature.adaround.value

    @classmethod
    def _internal_keys(cls):
        return set()


class LayerActivationClippingConfig(LayerConfigBaseModel):
    """
    By default, the model optimization does not clip layers' activations during quantization.
    This command can be used to change this behavior for selected layers and apply activation clipping when running the
    quantization API.
    This command may be useful in order to decrease quantization related degradation in case of outlier activation
    values.

    * ``disabled`` mode doesn't take clipping values, and disables any activation clipping mode previously set to the layer (This is the default).
    * ``manual`` mode uses the clipping values as given.
    * ``percentile`` mode calculates layer-wise percentiles (clipping values are percentiles 0 to 100).

    The mode name ``percentile_force`` is also accepted and treated as ``percentile``.

    .. note:: Percentiles based activation clipping requires several iterations of statistics
        collection, so quantization might take a longer time to finish.

    Example commands:

    .. code-block::

        pre_quantization_optimization(activation_clipping, layers=[conv1], mode=manual, clipping_values=[0.188, 1.3332])
        pre_quantization_optimization(activation_clipping, layers=[conv1, conv2], mode=percentile, clipping_values=[0.5, 99.5])
        pre_quantization_optimization(activation_clipping, layers={conv*}, mode=disabled)

    """

    mode: ActivationClippingMode = Field(description="Mode of operation, described above")
    clipping_values: Tuple[float, float] = Field(
        None,
        description="Clip value, required when mode is percentile or manual",
    )
    recollect_stats: bool = Field(False, description="Indicates whether stats should be collected after clip")

    @classmethod
    def get_command(cls):
        return ModelOptimizationCommand.pre_quantization_optimization.value

    @classmethod
    def get_feature(cls):
        return PreQuantizationFeature.activation_clipping.value

    @root_validator
    def validate_all_fields(cls, values):
        """Require clipping values for every mode other than ``disabled``."""
        mode = values.get("mode", None)
        if mode is None:
            # mode failed its own validation (or is missing); that error is reported separately.
            return values
        clipping_values = values.get("clipping_values", None)
        if mode is not ActivationClippingMode.disabled and clipping_values is None:
            raise ValueError(f"Clipping values must be defined when mode={mode.value}")
        return values

    @validator("mode", pre=True)
    def replace_percentile_force(cls, v):
        """Map the ``percentile_force`` alias onto ``percentile`` before enum validation."""
        if v == "percentile_force":
            return "percentile"
        return v

    @validator("clipping_values")
    def valid_clipping_values(cls, v, values, config, field):
        """Check that min < max and that percentiles lie in [0, 100]; drop values for ``disabled`` mode."""

        def _is_percentile(x):
            return x in {ActivationClippingMode.percentile}

        if isinstance(v, tuple):
            v = list(v)

        if v is None:
            return v
        mode = values.get("mode", None)
        if mode is None:
            return v
        if not v[0] < v[1]:
            raise ValueError(f"Clipping values {v} min value has to be less than max value")
        if _is_percentile(mode) and not all(0 <= i <= 100 for i in v):
            raise ValueError(f"Clipping values {v} has to be in range [0, 100] when in percentile mode")
        if mode is ActivationClippingMode.disabled and v is not None:
            default_logger().debug(f"Clipping values are ignored when mode={mode.value}")
            return None
        return v

    @classmethod
    def _internal_keys(cls):
        # Use a set literal so the result holds the field name itself.
        # ``set("recollect_stats")`` would build a set of its single characters instead.
        return {"recollect_stats"}

    @classmethod
    def get_default(cls):
        return cls(mode=ActivationClippingMode.disabled)


class LayerWeightsClippingConfig(LayerConfigBaseModel):
    """
    This command allows changing this behavior for selected layers and applying weights clipping when running the
    quantization API.
    This command may be useful in order to decrease quantization related degradation in case of outlier weight values.
    It is only applicable to the layers that have weights.

    * ``disabled`` mode doesn't take clipping values, and disables any weights clipping mode previously set to the layer.
    * ``manual`` mode uses the clipping values as given.
    * ``percentile`` mode calculates layer-wise percentiles (clipping values are percentiles 0 to 100).
    * ``mmse`` mode doesn't take clipping values, and uses `Minimum Mean Square Estimators` to clip the weights of the layer.
    * ``mmse_if4b`` similar to mmse, when the layer uses 4bit weights, and disables clipping when it uses 8-bit weights. (This is the default)

    Example commands:

    .. code-block::

        pre_quantization_optimization(weights_clipping, layers=[conv2], mode=manual, clipping_values=[-0.1, 0.8])
        pre_quantization_optimization(weights_clipping, layers=[conv3], mode=percentile, clipping_values=[1.0, 99.0])
        pre_quantization_optimization(weights_clipping, layers={conv*}, mode=mmse)
        pre_quantization_optimization(weights_clipping, layers=[conv3, conv4], mode=mmse_if4b)
        pre_quantization_optimization(weights_clipping, layers={conv*}, mode=disabled)

    .. note:: The dynamic range of the weights is symmetric even if the clipping values are not
        symmetric.

    """

    mode: WeightsClippingMode = Field(description="Mode of operation, described above")
    clipping_values: Tuple[float, float] = Field(
        default=None,
        description="Clip value, required when mode is percentile or manual",
    )

    @classmethod
    def get_default(cls):
        return cls(mode=WeightsClippingMode.mmse_if4b)

    @classmethod
    def get_command(cls):
        return ModelOptimizationCommand.pre_quantization_optimization.value

    @classmethod
    def get_feature(cls):
        return PreQuantizationFeature.weights_clipping.value

    @classmethod
    def _internal_keys(cls):
        return set()

    @root_validator
    def validate_all_fields(cls, values):
        """Only ``percentile`` and ``manual`` modes require explicit clipping values."""
        clip_mode = values.get("mode", None)
        needs_values = clip_mode in (WeightsClippingMode.percentile, WeightsClippingMode.manual)
        if needs_values and values.get("clipping_values", None) is None:
            raise ValueError(f"Clipping values must be defined when mode={clip_mode.value}")
        return values

    @validator("clipping_values")
    def valid_clipping_values(cls, v, values, config, field):
        """Validate ordering and percentile bounds of the clipping values; drop them for ``disabled`` mode."""
        if v is None:
            return None
        # Tuples are normalized to lists (the original representation used downstream and in messages).
        clip_range = list(v) if isinstance(v, tuple) else v

        clip_mode = values.get("mode", None)
        if clip_mode is None:
            return clip_range

        if not clip_range[0] < clip_range[1]:
            raise ValueError(f"Clipping values {clip_range} min value has to be less than max value")

        out_of_percent_range = any(not 0 <= bound <= 100 for bound in clip_range)
        if clip_mode == WeightsClippingMode.percentile and out_of_percent_range:
            raise ValueError(f"Clipping values {clip_range} has to be in range [0, 100] when in percentile mode")

        if clip_mode is WeightsClippingMode.disabled:
            default_logger().debug(f"Clipping values are ignored when mode={clip_mode.value}")
            return None
        return clip_range


class LayerDefuseConfig(LayerConfigBaseModel):
    """
    This command allows defusing layer according to the defuse type:

    INPUT FEATURES

    Defuse input features for a selected dense or conv layer to a selected number of splits.
    It can also be used to disable defusing of a layer.
    Example commands:

    .. code-block::

        pre_quantization_optimization(defuse, layers=fc1, num_splits=2, defuse_type=INPUT_FEATURES)
        # this will disable the fusing of fc2
        pre_quantization_optimization(defuse, layers=fc2, num_splits=1, defuse_type=INPUT_FEATURES)

    .. note:: num_splits might be overwritten by a larger number due to hw limitations.

    MHA

    Allows defusing multi-head attention block, represented by its first matmul, to a selected number of
    splits.

    Example commands:

    .. code-block::

        pre_quantization_optimization(defuse, layers=matmul1, num_splits=2, defuse_type=MHA)
    """

    num_splits: int = Field(default=None, description="number of splits required")
    defuse_type: PreQuantizationDefuseType = Field(default=None, description="defuse type")

    @classmethod
    def get_default(cls):
        return cls()

    @classmethod
    def get_command(cls):
        return ModelOptimizationCommand.pre_quantization_optimization.value

    @classmethod
    def get_feature(cls):
        return PreQuantizationFeature.defuse.value

    @classmethod
    def _internal_keys(cls):
        return set()

    @root_validator
    def validate_all_fields(cls, values):
        """Either leave both fields unset, or set both with a positive ``num_splits``."""
        splits = values.get("num_splits", None)
        split_type = values.get("defuse_type", None)
        if splits is not None or split_type is not None:
            if splits is None:
                raise ValueError("Number of splits must be defined.")
            if splits < 1:
                raise ValueError("Number of splits must be positive.")
            if split_type is None:
                raise ValueError("Defuse type must be defined.")
        return values


class LayerTranslationConfig(LayerConfigBaseModel):
    """
    Per-layer quantization parameters, emitted through the ``quantization_param`` command
    (which takes no feature argument).
    """

    max_elementwise_feed_repeat: int = Field(None, ge=1, le=16)
    max_bias_feed_repeat: int = Field(None, ge=1, le=32)
    activation_fit: ActivationFitPolicy = Field(ActivationFitPolicy.allowed)
    null_channels_cutoff_factor: float = Field(DEFAULT_NULL_CHANNELS_CUTOFF_FACTOR)
    force_range_in: Tuple[float, float] = Field(None)
    force_range_out: Tuple[float, float] = Field(None)
    force_range_preact: Tuple[float, float] = Field(None)
    force_range_index: int = Field(None, ge=0)
    weak_force_range_out: FeaturePolicy = Field(FeaturePolicy.disabled)
    force_shift: int = Field(None, ge=0, le=32)
    ignore_hw_limitation_assertion: IgnoreHwLimitationAssertionPolicy = Field(IgnoreHwLimitationAssertionPolicy.allowed)
    input_normalization: FeaturePolicy = Field(FeaturePolicy.disabled)
    activation_symmetric_range: LayerFeaturePolicy = Field(LayerFeaturePolicy.allowed)

    @classmethod
    def get_command(cls):
        return ModelOptimizationCommand.quantization_param.value

    @classmethod
    def get_feature(cls):
        return None

    @validator("max_elementwise_feed_repeat", "max_bias_feed_repeat", "force_shift", pre=True)
    def validate_not_float(cls, v, **kwargs):
        """Reject non-integral numbers (e.g. 2.5) instead of letting int coercion silently truncate them."""
        try:
            int_v = int(v)
            float_v = float(v)
        except (ValueError, TypeError, OverflowError):
            # Unconvertible values (including +/-inf, for which int() raises OverflowError)
            # are left to pydantic's own int validation, which reports them as validation errors.
            return v
        if int_v != float_v:
            raise ValueError("value is not a valid integer")
        return v

    @validator("force_range_in", "force_range_out", "force_range_preact")
    def validate_range(cls, v, **kwargs):
        """Require the forced range [min, max] to contain 0; tuples are returned as lists."""
        if v is None:
            return v
        if isinstance(v, tuple):
            v = list(v)
        if v[0] > 0 or v[1] < 0:
            raise ValueError("0 must be in range")
        if v[1] < v[0]:
            # Shouldn't happen, assuming max >= 0 and min <= 0
            raise ValueError("Max value must be greater than min")
        return v

    @classmethod
    def _internal_keys(cls):
        return {
            "force_range_in",
            "force_range_out",
            "force_range_preact",
            "force_range_index",
            "input_normalization",
            "activation_symmetric_range",
        }

    @classmethod
    def get_default(cls):
        return cls()


class LayerPrecisionConfig(LayerConfigBaseModel, validate_assignment=True):
    """
    Per-layer precision parameters, emitted through the ``quantization_param`` command.
    Assignments are validated (``validate_assignment=True``), so values filled in later are checked too.
    """

    quantization_groups: int = Field(default=None)
    quantization_weight_groups: int = Field(default=None)
    precision_mode: PrecisionMode = Field(default=None)
    bias_mode: BiasMode = Field(default=None)
    signed_output: bool = Field(default=None)

    @classmethod
    def get_default(cls):
        return cls()

    @classmethod
    def get_command(cls):
        return ModelOptimizationCommand.quantization_param.value

    @classmethod
    def get_feature(cls):
        return None

    @classmethod
    def _internal_keys(cls):
        return set()

    def fill_default_config(self, layer):
        """
        Fill unset precision fields from the layer's own defaults.

        A ``quantization_groups`` value of -1 is then replaced by the layer's activation channel count.
        """
        # TODO: remove this function once we fully remove defaults from hailo_nn
        # (field name, layer getter) pairs; the getter is looked up only when the field is unset.
        unset_defaults = (
            ("precision_mode", "get_default_precision_mode"),
            ("bias_mode", "get_default_bias_mode"),
            ("quantization_groups", "get_default_quantization_groups"),
        )
        for field_name, getter_name in unset_defaults:
            if getattr(self, field_name) is None:
                setattr(self, field_name, getattr(layer, getter_name)())
        if self.quantization_groups == -1:
            self.quantization_groups = layer.activation_atomic_op.num_of_channels

    @validator("quantization_groups")
    def validate_quantization_groups(cls, v, **kwargs):
        """Allow None, -1, or any positive group count."""
        if v is not None and v < 1 and v != -1:
            raise ValueError("quantization_groups must be positive number or -1.")
        return v


class LayerSplitEWMultByBitSignificanceConfig(LayerConfigBaseModel):
    """
    This command allows splitting element-wise multiplication layers by bit significance to allow higher precision.

    .. code-block::

        pre_quantization_optimization(split_ew_mult_by_bit_significance, layers=ew_mult1, num_splits=2)

    """

    num_splits: int = Field(default=None, description="Number of splits for ew_mult layer")

    @classmethod
    def get_default(cls):
        return cls()

    @classmethod
    def get_command(cls):
        return ModelOptimizationCommand.pre_quantization_optimization.value

    @classmethod
    def get_feature(cls):
        return PreQuantizationFeature.split_ew_mult_by_bit_significance.value

    @classmethod
    def _internal_keys(cls):
        return set()

    @validator("num_splits")
    def validate_num_splits(cls, num_splits, **kwargs):
        """Only 2 or 3 splits are accepted."""
        supported_splits = (2, 3)
        if num_splits in supported_splits:
            return num_splits
        raise ValueError("num_splits must be either 2 or 3.")


class LayerGlobalAvgpoolReductionConfig(LayerConfigBaseModel):
    """
    This command allows reducing the spatial dimensions for global avgpool layers using an additional avgpool layer.
    The kernel size of the added avgpool layer will be [1, h // division_factors[0], w // division_factors[1], 1]

    .. code-block::

        pre_quantization_optimization(global_avgpool_reduction, layers=avgpool1, division_factors=[4, 4])
        # this will disable the reduction of avgpool1
        pre_quantization_optimization(global_avgpool_reduction, layers=avgpool1, division_factors=[1, 1])

    """

    division_factors: Tuple[int, int] = Field(default=None, description="division of the kernel height and width")

    @classmethod
    def get_default(cls):
        return cls()

    @classmethod
    def get_command(cls):
        return ModelOptimizationCommand.pre_quantization_optimization.value

    @classmethod
    def get_feature(cls):
        return PreQuantizationFeature.global_avgpool_reduction.value

    @classmethod
    def _internal_keys(cls):
        return set()

    @validator("division_factors")
    def validate_division_factors(cls, factors, **kwargs):
        """Both division factors must be set and at least 1."""
        if factors is None:
            raise ValueError("Division factors must be defined.")
        has_non_positive = any(factor < 1 for factor in factors)
        if has_non_positive:
            raise ValueError("Division factors must be positive.")
        return factors


class LayerResolutionReductionConfig(LayerConfigBaseModel):
    """
    Sub-command for configuring resolution reduction per input layer; it affects the layer's connected component.
    The resolution is reduced so that optimization runs more efficiently, with a marginal effect on accuracy.
    Not supported when the component contains Fully-connected, Matmul and Cross-correlation layers, or when the
    resolution is too small.

    Example commands:

    .. code-block::

        # This will enable the algorithm for input_layer1 connected component, optimizing over an input shape of [128, 128]
        pre_quantization_optimization(resolution_reduction, layers=input_layer1, shape=[128, 128])

    .. note::

        This operation doesn't modify the structure of the model's graph

    """

    shape: Tuple[int, int] = Field(None, description="The shape to reduce the component resolution to.")
    interpolation: ResolutionReductionInterpolationMode = Field(
        None,
        description="Interpolation (default) requires dataset in the original model size, disabled required dataset in the reduced resolution.",
    )

    @classmethod
    def get_feature(cls):
        # Feature name emitted as the first argument of the pre_quantization_optimization command.
        feature = PreQuantizationFeature.resolution_reduction
        return feature.value

    @classmethod
    def get_command(cls):
        command = ModelOptimizationCommand.pre_quantization_optimization
        return command.value

    @classmethod
    def get_default(cls):
        # Both shape and interpolation keep their declared default of None.
        return cls()

    @classmethod
    def _internal_keys(cls):
        # This config declares no internal keys.
        return set()


class LayerNegExponentConfig(LayerConfigBaseModel):
    """
    During quantization, certain layers may lose bits, which reduces the precision of their output.
    To mitigate this issue, this command can enable the addition of extra layers:

    - With ``rank=1``, a helper layer is introduced that mitigates the bits lost in the quantized output.
      This can decrease the FPS of the network.
    - With ``rank=0``, no layer is introduced and the loss of bits is delegated to the output.

    Example commands:

    .. code-block::

        # This will enable the split of conv3 into two layers to not lose precision by a negative exponent >= 1
        model_optimization_config(negative_exponent, layers=[conv3], split_threshold=1, rank=1)

    .. note::

        This operation does modify the structure of the model's graph

    """

    # Must be strictly positive (gt=0).
    split_threshold: int = Field(2, gt=0, description="Split the layer at the given negative exponent.")
    # Constrained to 0 or 1 by ge=0 / le=1.
    rank: int = Field(1, le=1, ge=0, description="How many new layers should be added to the model")
    auto_clip: LayerFeaturePolicy = Field(LayerFeaturePolicy.disabled, description="Clip the range of the accumulator.")
    auto_remove_offset: LayerFeaturePolicy = Field(
        LayerFeaturePolicy.disabled,
        description="Remove Offsets that are not reach by the range on calibrations.",
    )

    @classmethod
    def get_command(cls):
        # Unlike the pre-quantization sub-commands, this one is emitted under model_optimization_config.
        return ModelOptimizationCommand.model_optimization_config.value

    @classmethod
    def _internal_keys(cls):
        return set()

    @classmethod
    def get_default(cls):
        return cls()

    @classmethod
    def get_feature(cls):
        return MOConfigCommand.negative_exponent.value


class LayerSplitFusedActivationConfig(LayerConfigBaseModel):
    """
    Sub-command for splitting a fused activation out of its main layer.

    Example commands:

    .. code-block::

        # This will split the activation that is fused on conv1 into a conv1 layer with linear activation and a standalone activation layer.
        pre_quantization_optimization(split_fused_activation, layers=conv1, policy=enabled)

        # This will split the activations from all conv layers.
        pre_quantization_optimization(split_fused_activation, layers={conv*}, policy=enabled)
    """

    policy: SplitFusedActivationPolicy = Field(
        SplitFusedActivationPolicy.enabled,
        description="Split fused activation policy",
    )

    @classmethod
    def get_default(cls):
        # The default instance has policy=enabled.
        return cls()

    @classmethod
    def _internal_keys(cls):
        # This config declares no internal keys.
        return set()

    @classmethod
    def get_command(cls):
        command = ModelOptimizationCommand.pre_quantization_optimization
        return command.value

    @classmethod
    def get_feature(cls):
        # Feature name emitted as the first argument of the pre_quantization_optimization command.
        feature = PreQuantizationFeature.split_fused_activation
        return feature.value


class LayerConvA16W4Config(LayerConfigBaseModel):
    """
    If the convolution layer's kernel size is 1x1, this command decomposes the convolution so that it works as a8_w4
    on a split precision of the input.

    Example commands:

    .. code-block::

        pre_quantization_optimization(conv_a16_w4, layers=conv1, policy=enabled)
        pre_quantization_optimization(conv_a16_w4, layers={conv*}, policy=enabled)


    .. note::

        - Relevant only if the convolution layer's kernel size is 1x1.
        - This will overwrite the precision mode of that layer set by quantization_param.

    """

    policy: LayerFeaturePolicy = Field(
        LayerFeaturePolicy.disabled,
        description="enable or disable the a16_w4 decomposition",
    )
    shift_high: LayerFeaturePolicy = Field(
        LayerFeaturePolicy.disabled,
        description="Shift the high by 1 (from 7 bit to 8 bit), so that it would suffer less from MAC shift",
    )

    @classmethod
    def get_default(cls):
        # The default instance has both policy and shift_high disabled.
        return cls()

    @classmethod
    def _internal_keys(cls):
        # This config declares no internal keys.
        return set()

    @classmethod
    def get_command(cls):
        command = ModelOptimizationCommand.pre_quantization_optimization
        return command.value

    @classmethod
    def get_feature(cls):
        # Feature name emitted as the first argument of the pre_quantization_optimization command.
        feature = PreQuantizationFeature.conv_a16_w4
        return feature.value
