import collections
import json
from typing import Dict, Optional

from pydantic.v1 import BaseModel, Field

from hailo_model_optimization.acceleras.model_optimization_config.mo_config_base import BaseConfigBaseModel
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_model import (
    ActivationClippingConfig,
    AdaRoundConfig,
    AddShortcutConfig,
    BlockRoundTrainingConfig,
    CalibrationConfig,
    CheckerConfig,
    CompressionConfig,
    ConvA16W4Config,
    ConvDecompositionConfig,
    DeadChannelsRemovalConfig,
    DefuseConfig,
    EWAddFusingConfig,
    FineTuneConfig,
    GlobalAvgpoolReductionConfig,
    GlobalBiasCorrectionConfig,
    GlobalDeadLayersRemovalConfig,
    GlobalEqualizationConfig,
    GlobalLayerDecompositionConfig,
    GlobalLayerNormDecompositionConfig,
    GlobalResolutionReductionConfig,
    GlobalsMOConfig,
    LoadQuantConfig,
    MatmulCorrectionConfig,
    MatmulDecompositionConfig,
    MatmulEqualizationConfig,
    MixPrecisionSearchConfig,
    NegExponentConfig,
    PrecisionConfig,
    QuaRotConfig,
    SmartSoftmaxStatsConfig,
    SplitEWMultByBitSignificanceConfig,
    SplitFusedActivationConfig,
    SwtichConcatWithAddConfig,
    TiledSqueezeAndExciteConfig,
    TrainEncodingConfig,
    TranslationConfig,
    WeightsClippingConfig,
    ZeroStaticChannelsConfig,
)
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    MINIMUM_PARAMS_FOR_COMPRESSION,
    RECOMMENDED_CALIBSET_SIZE_FOR_BN_CHECKER,
    RECOMMENDED_DATASET_SIZE,
    AdaRoundMode,
    ModelOptimizationCommand,
    OptimizationTarget,
)


class ModelOptimizationFlavor(BaseConfigBaseModel):
    """
    Configure the model optimization effort by setting compression level and optimization level.
    The flavor's algorithm will behave as default,
    any algorithm-specific configuration will override the flavor's default config

    Default values:
        - compression_level: 1
        - optimization_level: 2 for GPU and 1024 images, 1 for GPU and less than 1024 images, and 0 for CPU only.
        - batch_size: check default of each algorithm (usually 8 or 32)


    .. internal::
        IMPORTANT: the script for this function is used for the hailo_dataflow_compiler pdf guide -- apply judgment when changing.


    Optimization levels: (might change every version)
        - -100 nothing is applied - all default algorithms are switched off
        - 0 - Equalization
        - 1 - Equalization + Iterative bias correction
        - 2 - Equalization + Finetune with 4 epochs & 1024 images
        - 3 - Equalization + Adaround with 320 epochs & 256 images on all layers
        - 4 - Equalization + Adaround with 320 epochs & 1024 images on all layers

    Compression levels: (might change every version)
        - 0 - nothing is applied
        - 1 - auto 4bit is set to 0.2 if network is large enough (20% of the weights)
        - 2 - auto 4bit is set to 0.4 if network is large enough (40% of the weights)
        - 3 - auto 4bit is set to 0.6 if network is large enough (60% of the weights)
        - 4 - auto 4bit is set to 0.8 if network is large enough (80% of the weights)
        - 5 - auto 4bit is set to 1.0 if network is large enough (100% of the weights)

    Example commands:

    .. code-block::

        model_optimization_flavor(optimization_level=4)
        model_optimization_flavor(compression_level=2)
        model_optimization_flavor(optimization_level=2, compression_level=1)
        model_optimization_flavor(optimization_level=2, batch_size=4)
    """

    optimization_level: Optional[int] = Field(
        ge=-100,
        le=4,
        description="Optimization level, higher is better but longer, improves accuracy",
    )
    compression_level: Optional[int] = Field(
        ge=0,
        le=5,
        description="Compression level, higher is better but increases degradation, improves fps and latency",
    )
    batch_size: Optional[int] = Field(
        ge=1,
        description="Batch size for the algorithms (adaround, finetune, calibration)",
    )

    @classmethod
    def get_default(cls):
        # Not supported for the flavor: unset levels are resolved in get_flavor_config instead.
        raise NotImplementedError

    @classmethod
    def get_command(cls):
        return ModelOptimizationCommand.model_optimization_flavor.value

    @classmethod
    def get_feature(cls):
        return None

    @classmethod
    def _internal_keys(cls) -> set:
        """
        More internal config of feature, read BaseConfigBaseModel's internal_keys for more info
        Returns:
            set with internal_keys keys
        """
        return set()

    @staticmethod
    def get_default_optimization(has_gpu, dataset_length, logger):
        """
        Get the default optimization level based on GPU and data availability.

        Starts from level 2. It drops to 1 when ``dataset_length`` is below
        ``RECOMMENDED_DATASET_SIZE``, and to 0 when no GPU is available.

        Args:
            has_gpu: boolean, whether there's a GPU or not
            dataset_length: int, number of samples available for optimization
            logger: logger object or None; when given, it logs the chosen level and
                the reasons for any reduction

        Return:
            int, the default optimization level

        """
        reasons = []
        optimization = 2
        if dataset_length < RECOMMENDED_DATASET_SIZE:
            reasons.append(f"there's less data than the recommended amount ({RECOMMENDED_DATASET_SIZE})")
            optimization = 1
        if not has_gpu:
            reasons.append("there's no available GPU")
            optimization = 0

        # Logging is optional; previously the level-2 path crashed when logger was None.
        if logger is None:
            return optimization

        if optimization == 2:
            logger.info(f"Using default optimization level of {optimization}")
        else:
            msg = (
                f"Reducing optimization level to {optimization} "
                "(the accuracy won't be optimized and compression won't be used) because "
            )
            # We have up to 2 reasons, so "and" is a suitable conjunction.
            logger.warning(msg + ", and ".join(reasons))

        return optimization

    def get_default_compression(self, optimization, parameters_count, logger=None):
        """
        Get the default compression level based on the resolved optimization level and model size.

        Returns 1 when ``optimization > 1`` and the model has more than
        ``MINIMUM_PARAMS_FOR_COMPRESSION`` parameters. Otherwise it returns 0, and warns
        if the user explicitly requested optimization level 0 or 1.

        Args:
            optimization: int, current (resolved) optimization level
            parameters_count: int, number of parameters in the model
            logger: sdk logger object, or None to skip logging

        Return:
            int, the default compression level

        """
        user_optimization_level = self.optimization_level
        size_cond = parameters_count > MINIMUM_PARAMS_FOR_COMPRESSION
        if optimization > 1 and size_cond:
            compression = 1
            if logger is not None:
                logger.info(f"Using default compression level of {compression}")
        else:
            compression = 0
            if (logger is not None) and (user_optimization_level in {0, 1}):
                logger.warning(
                    "Reducing compression level to 0 because requested optimization level equal or less than 1",
                )
        return compression

    def get_flavor_config(self, has_gpu, dataset_length, parameters_count, target, logger=None):
        """
        Resolve the optimization/compression levels and build the combined flavor config.

        Levels left unset by the user (``None``) are filled in with
        ``get_default_optimization`` / ``get_default_compression``. The resolved levels
        are then written back to ``self.optimization_level`` / ``self.compression_level``.

        Args:
            has_gpu: boolean, whether a GPU is available
            dataset_length: int, number of samples available for optimization
            parameters_count: int, number of parameters in the model
            target: OptimizationTarget of the optimization
            logger: logger object, or None to skip logging

        Returns:
            dict mapping config section name to its raw config dict

        Raises:
            RuntimeError: if the optimization and compression configs share a key
            ValueError: if the optimization level is not supported
        """
        if self.optimization_level is not None:
            optimization = self.optimization_level
        else:
            optimization = self.get_default_optimization(has_gpu, dataset_length, logger)

        if self.compression_level is not None:
            compression = self.compression_level
        else:
            compression = self.get_default_compression(optimization, parameters_count, logger)
        optimization_cfg = self.get_optimization_config(optimization, dataset_length, logger)
        compression_cfg = self.get_compression_config(compression, parameters_count, target, logger)
        conflicting_keys = optimization_cfg.keys() & compression_cfg.keys()
        if conflicting_keys:
            raise RuntimeError(
                f"Got conflicting keys in the compression and optimization configurations: {conflicting_keys}",
            )

        config = dict()
        config.update(optimization_cfg)
        config.update(compression_cfg)
        self.optimization_level = optimization
        self.compression_level = compression

        return config

    def get_optimization_config(self, level, dataset_length, logger=None):
        """
        Build the per-algorithm config dict for the given optimization level.

        Equalization, calibration and the checker are always present. Each level adds
        (or disables) algorithms as described in the class docstring. Calibration and
        checker dataset sizes are capped at ``dataset_length``.

        Args:
            level: int, resolved optimization level (-100, 0, 1, 2, 3 or 4)
            dataset_length: int, number of samples available
            logger: logger object, or None to skip logging

        Returns:
            dict mapping config section name to its raw config dict

        Raises:
            ValueError: if ``level`` is not one of the supported values
        """
        if self.batch_size is None:
            common_kwargs = {}
            if level == 0:
                common_kwargs = {"batch_size": 1}
        else:
            common_kwargs = {"batch_size": self.batch_size}
        if level < 1 and logger is not None:
            logger.warning(
                "Running model optimization with zero level of optimization "
                "is not recommended for production use and might lead to suboptimal accuracy results"
            )
        cfg = {
            "equalization": GlobalEqualizationConfig(policy="enabled").raw_dict(),
            "calibration": CalibrationConfig(**common_kwargs).raw_dict(),
            "checker_cfg": CheckerConfig(policy="enabled").raw_dict(),
        }

        cfg["calibration"]["calibset_size"] = int(min(cfg["calibration"]["calibset_size"], dataset_length))
        cfg["checker_cfg"]["dataset_size"] = int(min(cfg["checker_cfg"]["dataset_size"], dataset_length))
        # The batch-norm checker is only kept when the calibration set is large enough.
        cfg["checker_cfg"]["batch_norm_checker"] = (
            cfg["checker_cfg"]["batch_norm_checker"]
            and cfg["calibration"]["calibset_size"] >= RECOMMENDED_CALIBSET_SIZE_FOR_BN_CHECKER
        )

        if level == -100:
            cfg["checker_cfg"]["policy"] = "disabled"
            cfg["equalization"] = GlobalEqualizationConfig(policy="disabled").raw_dict()
            cfg["globals"] = GlobalsMOConfig(output_16bit="disabled").raw_dict()
        elif level == 0:
            cfg["checker_cfg"]["policy"] = "disabled"
        elif level == 1:
            cfg["bias_correction"] = GlobalBiasCorrectionConfig(policy="enabled", **common_kwargs).raw_dict()
        elif level == 2:
            cfg["finetune"] = FineTuneConfig(policy="enabled", epochs=4, dataset_size=1024, **common_kwargs).raw_dict()
        elif level == 3:
            cfg["adaround"] = AdaRoundConfig(
                policy="enabled",
                mode=AdaRoundMode.train_all,
                dataset_size=256,
                **common_kwargs,
            ).raw_dict()
        elif level == 4:
            cfg["adaround"] = AdaRoundConfig(
                policy="enabled",
                mode=AdaRoundMode.train_all,
                dataset_size=1024,
                **common_kwargs,
            ).raw_dict()
        else:
            raise ValueError(f"Got optimization_level {level} which is not supported")
        return cfg

    def get_compression_config(self, level, parameters_count, target, logger=None):
        """
        Build the compression config for the given compression level.

        The 4-bit weights ratio is ``level * 0.2``. It is forced to 0 for the SAGE target
        when the model has at most ``MINIMUM_PARAMS_FOR_COMPRESSION`` parameters; in that
        case a warning is emitted if the user explicitly requested compression.

        Args:
            level: int, resolved compression level (0-5)
            parameters_count: int, number of parameters in the model
            target: OptimizationTarget of the optimization
            logger: logger object, or None to skip logging

        Returns:
            dict with a single "compression_params" raw config dict
        """
        user_req_compression = self.compression_level
        if target == OptimizationTarget.SAGE and (parameters_count <= MINIMUM_PARAMS_FOR_COMPRESSION):
            # logger may be None (get_flavor_config defaults it to None); only warn when available.
            if logger is not None and user_req_compression is not None and user_req_compression > 0:
                logger.warning(
                    f"Reducing compression ratio to 0 because the number of parameters in the network "
                    f"is not large enough ({int(parameters_count / 1e6)}M and need at least {int(MINIMUM_PARAMS_FOR_COMPRESSION / 1e6)}M). "
                    f"Can be enforced using model_optimization_config(compression_params, auto_4bit_weights_ratio={level * 0.2:.3f})",
                )
            compression_ratio = 0
        else:
            compression_ratio = level * 0.2

        cfg = {
            "compression_params": CompressionConfig(auto_4bit_weights_ratio=compression_ratio).raw_dict(),
        }
        return cfg


class BaseMOConfig(BaseModel):
    """
    Common base for the model optimization config models.

    Unknown fields are rejected (``extra = "forbid"``). The free-form ``research``
    field is left out of :meth:`keys`.
    """

    research: dict = Field(default=None)

    class Config:
        extra = "forbid"

    @classmethod
    def keys(cls):
        """Return the set of declared field names, excluding ``research``."""
        return set(cls.__fields__).difference({"research"})


class ModelOptimizationLayerConfig(BaseMOConfig):
    """
    Base for per-layer model optimization configs.

    Every non-``research`` field is expected to be a feature config exposing
    ``raw_dict``, and its type is expected to provide a ``to_cmd`` classmethod.
    """

    def export(self):
        """
        Export every feature sub-config of this layer.

        Returns:
            dict mapping each feature name (see ``keys()``) to that feature config's
            ``raw_dict()``.
        """
        # Previously the raw dicts were computed and discarded (the method returned None).
        return {key: getattr(self, key).raw_dict() for key in self.keys()}

    @classmethod
    def to_commands(cls, layers_cfg, exclude_defaults=False):
        """
        Build per-layer commands, grouping layers that share an identical feature config.

        Args:
            layers_cfg: mapping of layer name to a per-layer config instance of this class.
            exclude_defaults: when True, skip feature configs equal to the class defaults,
                and pass ``exclude_defaults`` through to each feature's ``raw_dict``.

        Returns:
            list with one entry per (feature, distinct config) pair, as produced by the
            feature type's ``to_cmd(dict_cfg, layers)``.
        """
        commands = []
        default_cfg = json.loads(cls().json(exclude_none=True))  # TODO: remove defaults from QuantizationParams
        for feature in cls.keys():
            feature_type = cls.__fields__[feature].type_
            feature_cfg_variants = cls._get_feature_cfg_variants(layers_cfg, feature, exclude_defaults)
            for frozen_cfg, layers in feature_cfg_variants.items():
                dict_cfg = dict(frozen_cfg)
                if exclude_defaults and default_cfg[feature] == dict_cfg:
                    continue
                commands.append(feature_type.to_cmd(dict_cfg, layers))
        return commands

    @classmethod
    def _get_feature_cfg_variants(cls, layers_cfg, feature, exclude_defaults):
        """
        Group layer names by their config for ``feature``.

        Returns:
            dict mapping a frozen config (sorted tuple of ``(key, value)`` items) to the
            list of layer names that share it, in ``layers_cfg`` iteration order.
        """
        configuration = dict()
        for lname, lcfg in layers_cfg.items():
            lcfg_dict = getattr(lcfg, feature).raw_dict(exclude_defaults)
            # Sorted item tuples let equal configs collapse to the same dict key.
            lcfg_frozen = tuple(sorted(lcfg_dict.items()))
            configuration.setdefault(lcfg_frozen, []).append(lname)

        return configuration


class ModelOptimizationConfig(BaseMOConfig):
    """
    Class to describe the model optimization config.

    NOTICE: as long as we support the legacy config and some unusual behaviors,
    part of the finalization of the configurations happens in `quantize.py - ModelOptimizerConfig`.
    There are some workarounds there. The plan is to deprecate them, remove this class,
    and change some behaviors.
    """

    globals: GlobalsMOConfig = Field(GlobalsMOConfig.get_default())
    compression_params: CompressionConfig = Field(CompressionConfig.get_default())
    calibration: CalibrationConfig = Field(CalibrationConfig.get_default())
    equalization: GlobalEqualizationConfig = Field(GlobalEqualizationConfig.get_default())
    dead_layers_removal: GlobalDeadLayersRemovalConfig = Field(GlobalDeadLayersRemovalConfig.get_default())
    bias_correction: GlobalBiasCorrectionConfig = Field(GlobalBiasCorrectionConfig.get_default())
    se_optimization: TiledSqueezeAndExciteConfig = Field(TiledSqueezeAndExciteConfig.get_default())
    adaround: AdaRoundConfig = Field(AdaRoundConfig.get_default())
    block_round_training: BlockRoundTrainingConfig = Field(BlockRoundTrainingConfig.get_default())
    finetune: FineTuneConfig = Field(FineTuneConfig.get_default())
    train_encoding: TrainEncodingConfig = Field(TrainEncodingConfig.get_default())
    layer_norm_decomposition: GlobalLayerNormDecompositionConfig = Field(
        GlobalLayerNormDecompositionConfig.get_default(),
    )
    dead_channels_removal: DeadChannelsRemovalConfig = Field(DeadChannelsRemovalConfig.get_default())
    smart_softmax_stats: SmartSoftmaxStatsConfig = Field(SmartSoftmaxStatsConfig.get_default())
    zero_static_channels: ZeroStaticChannelsConfig = Field(ZeroStaticChannelsConfig.get_default())
    switch_concat_with_add: SwtichConcatWithAddConfig = Field(SwtichConcatWithAddConfig.get_default())
    ew_add_fusing: EWAddFusingConfig = Field(EWAddFusingConfig.get_default())
    checker_cfg: CheckerConfig = Field(CheckerConfig.get_default())
    resolution_reduction: GlobalResolutionReductionConfig = Field(GlobalResolutionReductionConfig.get_default())
    activation_clipping: ActivationClippingConfig = Field(ActivationClippingConfig.get_default())
    weights_clipping: WeightsClippingConfig = Field(WeightsClippingConfig.get_default())
    layer_decomposition: GlobalLayerDecompositionConfig = Field(GlobalLayerDecompositionConfig.get_default())
    defuse: DefuseConfig = Field(DefuseConfig.get_default())
    translation_config: TranslationConfig = Field(TranslationConfig.get_default())
    precision_config: PrecisionConfig = Field(PrecisionConfig.get_default())
    mix_precision_search: MixPrecisionSearchConfig = Field(MixPrecisionSearchConfig.get_default())
    global_avgpool_reduction: GlobalAvgpoolReductionConfig = Field(GlobalAvgpoolReductionConfig.get_default())
    layers: Dict[str, ModelOptimizationLayerConfig] = Field(dict())
    add_shortcut_layer: AddShortcutConfig = Field(AddShortcutConfig.get_default())
    negative_exponent: NegExponentConfig = Field(NegExponentConfig.get_default())
    split_ew_mult_by_bit_significance: SplitEWMultByBitSignificanceConfig = Field(
        SplitEWMultByBitSignificanceConfig.get_default(),
    )
    split_fused_activation: SplitFusedActivationConfig = Field(SplitFusedActivationConfig.get_default())
    use_prequantized_weights: LoadQuantConfig = Field(LoadQuantConfig.get_default())
    conv_decomposition: ConvDecompositionConfig = Field(ConvDecompositionConfig.get_default())
    quarot: QuaRotConfig = Field(QuaRotConfig.get_default())
    conv_a16_w4: ConvA16W4Config = Field(ConvA16W4Config.get_default())

    # MatMul
    matmul_correction: MatmulCorrectionConfig = Field(MatmulCorrectionConfig.get_default())
    matmul_equalization: MatmulEqualizationConfig = Field(MatmulEqualizationConfig.get_default())
    # The default must match the annotated type: pydantic v1 does not validate defaults,
    # so a GlobalLayerDecompositionConfig default would leak through as the wrong class.
    matmul_decomposition: MatmulDecompositionConfig = Field(MatmulDecompositionConfig.get_default())

    def to_commands(self, exclude_defaults=False, layers_to_exclude=None):
        """
        Convert this config into a list of commands.

        Global feature configs are converted with their own ``to_cmd``. Per-layer configs
        in ``layers`` are converted with ``ModelOptimizationLayerConfig.to_commands``.

        Args:
            exclude_defaults: when True, skip feature configs equal to a default-constructed
                ``ModelOptimizationConfig`` and let each feature drop its default values.
            layers_to_exclude: forwarded to each global feature's ``to_cmd``.

        Returns:
            list of non-empty commands, with the global features first, then the per-layer ones.
        """
        default_cfg = self.__class__()
        commands = []
        for key in self.keys():
            if key == "layers":
                continue
            feature_config = getattr(self, key)
            if exclude_defaults and getattr(default_cfg, key) == feature_config:
                continue
            cmd = feature_config.to_cmd(exclude_defaults, layers_to_exclude)
            commands.append(cmd)
        per_layer_commands = ModelOptimizationLayerConfig.to_commands(self.layers, exclude_defaults)
        commands.extend(per_layer_commands)
        commands = [cmd for cmd in commands if cmd]  # remove empty entries
        return commands

    def remove_layer_from_all_configs(self, lname):
        """
        Remove layer ``lname`` from every global feature config.

        Delegates to each feature config's ``remove_layer_from_config``. The per-layer
        ``layers`` dict itself is left untouched.
        """
        for key in self.keys():
            if key == "layers":
                continue
            getattr(self, key).remove_layer_from_config(lname)


def update_nested(d, u):
    """
    Recursively merge mapping ``u`` into mapping ``d`` in place and return ``d``.

    - Mapping values are merged recursively. If ``d`` holds no mapping under that key
      (missing, ``None``, or some other value), a fresh dict is used as the target.
    - List values are appended to the existing list in ``d`` (a fresh list if the key is
      missing or not a list); nested lists are copied via ``update_nested_list``.
    - Any other value overwrites ``d[k]``.

    Args:
        d: destination mapping, mutated in place.
        u: source mapping whose values take precedence.

    Returns:
        ``d`` after the merge.
    """
    for k, v in u.items():
        if isinstance(v, collections.abc.Mapping):
            dest = d.get(k)
            # A None/scalar destination used to crash with AttributeError on `.get`.
            if not isinstance(dest, collections.abc.Mapping):
                dest = {}
            d[k] = update_nested(dest, v)
        elif isinstance(v, list):
            dest = d.get(k)
            # Same fallback as the mapping branch: only extend an actual list.
            if not isinstance(dest, list):
                dest = []
            d[k] = update_nested_list(dest, v)
        else:
            d[k] = v
    return d


def update_nested_list(converted, orig):
    """
    Append the elements of ``orig`` to ``converted`` in place and return ``converted``.

    Nested lists are copied recursively into fresh lists; every other element is
    appended by reference.
    """
    converted.extend(
        update_nested_list([], item) if isinstance(item, list) else item
        for item in orig
    )
    return converted
