#!/usr/bin/env python
"""High level API for quantization."""

import copy
import json
import logging
import os
import time
from pathlib import Path

import numpy as np
import tensorflow as tf
import yaml

from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.model_optimization_config.mo_config import ModelOptimizationConfig
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerBiasCorrectionConfig
from hailo_model_optimization.acceleras.utils import hn_npz_utils
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    DEFAULT_LEARNING_RATE,
    MAXIMUM_DESCRIPTOR_SIZE_IN_HAILORT,
    MAXIMUM_THROUGHPUT_FOR_16BIT_OUTPUT,
    SUPPORTED_LAYERS_IN_A16_W16,
    AdaRoundMode,
    BiasCorrectionPolicy,
    BiasMode,
    FeaturePolicy,
    FinetunePolicy,
    LayerFeaturePolicy,
    LayerSupportStatus,
    PrecisionMode,
    ThreeWayPolicy,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasImplementationError
from hailo_model_optimization.acceleras.utils.dataset_util import data_to_dataset
from hailo_model_optimization.algorithms.ada_round.ada_round import AdaRound
from hailo_model_optimization.algorithms.bias_correction.bias_correction import BiasCorrection
from hailo_model_optimization.algorithms.mixed_precision.create_mixed_precision import CreateMixedPrecision
from hailo_sdk_client.numeric_translator import quantize_model, set_quantized_params
from hailo_sdk_client.numeric_translator.equalization import ParamsEqualizer
from hailo_sdk_client.numeric_translator.params_sorter import ParamsSorter
from hailo_sdk_client.quantization.graph_wrapper import GraphWrapperAcceleras, GraphWrapperTfModel
from hailo_sdk_client.quantization.tools.activation_clipping_tool import ActivationClipping
from hailo_sdk_client.quantization.tools.iterative_bias_correction import IBC
from hailo_sdk_client.quantization.tools.optimize_kernel_ranges import ClipAwareParamsSorter
from hailo_sdk_client.quantization.tools.quant_aware_fine_tune import acceleras_fine_tune, fine_tune_from_feed
from hailo_sdk_client.quantization.tools.weights_clipping_tool import WeightsClipping
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.hailo_nn.hn_layers import FusedStandaloneEWAddLayer, ResizeLayer
from hailo_sdk_common.hailo_nn.model_optimization.configuration_verifier import (
    apply_quantization_config_to_hn,
    verify_commands,
)
from hailo_sdk_common.logger.logger import default_logger
from hailo_sdk_common.model_params.model_params import ModelParams
from hailo_sdk_common.serialization.numpy_serialization import hailo_np_savez
from hailo_sdk_common.targets.inference_targets import (
    EmulationInferenceTargets,
    SdkFPOptimized,
    SdkNumeric,
    SdkPartialNumeric,
)


class QuantizationException(Exception):
    """Error raised by the quantization flow, e.g. on a missing config section or missing reusable stats."""


class ModelOptimizerConfig:
    """
    This object wraps the dict configuration object
    (:class:`ModelOptimizationConfig <hailo_sdk_common.hailo_nn.model_optimization.configuration.ModelOptimizationConfig>`);
    it simplifies access to some configuration sections and wraps others.
    """

    def __init__(self, hn_model, logger=None):
        """
        Create the wrapper and load the default optimization configuration.

        Args:
            hn_model: HN model the configuration is verified against and applied to.
            logger: Optional logger; ``default_logger()`` is used when it is omitted.

        """
        self._logger = logger or default_logger()
        self._hn_model = hn_model
        self._config = ModelOptimizationConfig()
        # The defaults go through load_config, so they are verified, applied to the layers
        # and finalized the same way as any configuration loaded later.
        self.load_config(self._config.dict())  # Load default config

    @property
    def compression_config(self):
        """Return the ``compression_params`` section of the wrapped configuration."""
        return self._config.compression_params

    @property
    def global_config(self):
        """Return the ``globals`` section of the wrapped configuration."""
        return self._config.globals

    @property
    def calibration_config(self):
        """Return the ``calibration`` section of the wrapped configuration."""
        return self._config.calibration

    @property
    def bias_correction_config(self):
        """Return the ``bias_correction`` section of the wrapped configuration."""
        return self._config.bias_correction

    @property
    def equalization_config(self):
        """Return the ``equalization`` section of the wrapped configuration."""
        return self._config.equalization

    @property
    def adaround(self):
        """Return the ``adaround`` section of the wrapped configuration."""
        return self._config.adaround

    @property
    def finetune_config(self):
        """Return the ``finetune`` section of the wrapped configuration."""
        return self._config.finetune

    @property
    def checker_config(self):
        """Return the ``checker_cfg`` section of the wrapped configuration."""
        return self._config.checker_cfg

    def get_mo_config(self):
        """Return the wrapped configuration object itself (not a copy)."""
        return self._config

    def update_model(self, model, override_config):
        """
        Replace the tracked HN model and re-synchronize it with the configuration.

        Args:
            model: The new HN model.
            override_config: When truthy, the per-layer settings of ``model`` are read into the
                configuration; otherwise the configuration (if there is one) is applied to ``model``.

        """
        self._hn_model = model
        if override_config:
            self.update_after_params_change()
            return
        if self._config is None:
            return
        self.apply_to_layers()

    def load_config(self, config, work_dir=None, calibset_size=None):
        """
        Verify ``config`` against the model, adopt it, finalize it and optionally save it.

        Args:
            config: Configuration to load; it is passed to ``verify_commands`` with the HN model.
            work_dir: Directory for ``quantization_config.yaml``; nothing is written when ``None``.
            calibset_size: Number of available calibration samples, used to cap the configured
                dataset and batch sizes; ``None`` disables the dataset-size cap.

        """
        parsed_commands = verify_commands(self._hn_model, config)
        self._handle_config(parsed_commands)
        self._finalize_config(calibset_size)
        # NOTE: the object dumped here is the one adopted as self._config above, so the saved
        # file reflects the sizes adjusted by _finalize_config.
        self.save_cfg(work_dir, parsed_commands, "quantization_config.yaml")

    @staticmethod
    def save_cfg(work_dir, parsed_commands, path):
        if work_dir is not None:
            yaml_path = os.path.join(work_dir, path)
            with open(yaml_path, "w") as fp:
                yaml.dump(json.loads(parsed_commands.json()), fp)

    def has_bias_correction(self):
        """
        Tell whether bias correction is active according to the configuration.

        Returns:
            True when any per-layer entry equals ``BiasCorrectionPolicy.enabled``, or when the
            default policy is enabled and either there are no per-layer entries or one of them
            equals ``BiasCorrectionPolicy.allowed``.

        Raises:
            QuantizationException: If the bias correction section is missing.

        """
        ibc_cfg = self.bias_correction_config
        if ibc_cfg is None:
            raise QuantizationException("ibc_cfg shouldn't be None")

        layer_policies = ibc_cfg.layers
        # NOTE(review): the entries are compared to the policy enum directly. If ``layers`` is a
        # mapping of layer name -> layer config (as ``adaround.layers`` appears to be elsewhere in
        # this file), this iterates the keys and never matches - confirm against the schema.
        explicitly_enabled = any(entry == BiasCorrectionPolicy.enabled for entry in layer_policies)
        default_can_apply = len(layer_policies) == 0 or any(
            entry == BiasCorrectionPolicy.allowed for entry in layer_policies
        )

        if explicitly_enabled:
            return True
        if not default_can_apply:
            return False
        return ibc_cfg.policy == BiasCorrectionPolicy.enabled

    def has_finetune(self):
        return (self.finetune_config is not None) and (self.finetune_config.policy == FinetunePolicy.enabled)

    def has_adaround(self):
        """
        Tell whether AdaRound ends up enabled for at least one layer of the model.

        A layer with an explicit (non-``allowed``) policy uses that policy. A layer whose policy
        is ``allowed`` follows the global adaround policy, restricted by the adaround mode:
        ``train_all`` applies it to every layer, ``train_4bit`` only to layers whose reduced
        precision mode is ``a8_w4``. Any other mode is treated as disabled for the layer, so the
        decision never depends on an unbound or stale value.

        Returns:
            True if the resolved policy of any layer is ``FeaturePolicy.enabled``, else False.

        """
        adaround_cfg = self._config.adaround
        # These two do not depend on the layer, so they are read once.
        default_policy = adaround_cfg.policy
        adaround_mode = adaround_cfg.mode

        layers = [layer.name for layer in self._hn_model.stable_toposort()]
        config_per_layer = AdaRound.get_config_by_layer(adaround_cfg, layers)
        for layer in layers:
            layer_cfg = config_per_layer.get(layer)
            layer_policy = LayerFeaturePolicy.allowed if layer_cfg is None else layer_cfg.policy

            if layer_policy != LayerFeaturePolicy.allowed:
                policy = FeaturePolicy(layer_policy.value)
            elif adaround_mode == AdaRoundMode.train_all:
                policy = default_policy
            elif adaround_mode == AdaRoundMode.train_4bit:
                hn_layer = self._hn_model.get_layer_by_name(layer)
                is_4bit = hn_layer.precision_config.precision_mode.reduce() == PrecisionMode.a8_w4
                policy = default_policy if is_4bit else FeaturePolicy.disabled
            else:
                # Unrecognized mode: previously `policy` was left unbound here (or kept the
                # value computed for the previous layer).
                policy = FeaturePolicy.disabled

            if policy == FeaturePolicy.enabled:
                return True

        return False

    def get_active_train_algorithms_names(self):
        """Return the display names of the active training algorithms, in a fixed order."""
        checks = (
            ("adaround", self.has_adaround),
            ("bias correction", self.has_bias_correction),
            ("finetune", self.has_finetune),
        )
        return [name for name, is_active in checks if is_active()]

    def _handle_config(self, config):
        """Adopt ``config`` as the current configuration and apply it to the model layers."""
        self._config = config
        self.apply_to_layers()

    def _finalize_config(self, calibset_size):
        """
        Resolve dataset/batch sizes and finetune defaults of the loaded configuration.

        Calibration is resolved first because the later steps fall back to its values.

        Args:
            calibset_size: Number of available calibration samples, or ``None`` if unknown.

        """
        batch_info_updaters = (
            self._update_calibration_batch_info,
            self._update_bias_correction_batch_info,
            self._update_checker_config_batch_info,
        )
        for update_batch_info in batch_info_updaters:
            update_batch_info(calibset_size)
        self._update_finetune_config()

    def _get_valid_batch_info(self, calibset_size, desired_calibset_size, desired_batch_size, feature_name):
        final_calibset_size = desired_calibset_size
        if calibset_size is not None and desired_calibset_size > calibset_size:
            final_calibset_size = calibset_size
            self._logger.warning(
                f"{feature_name}:"
                f"\tDataset didn't have enough data for calibset-size of {desired_calibset_size} "
                f"\tQuantizing using calibration size of {final_calibset_size}",
            )

        final_batch_size = desired_batch_size
        if desired_batch_size > final_calibset_size:
            final_batch_size = final_calibset_size
            self._logger.warning(
                f"{feature_name}:\tBatch size was greater than calibset size, using batch size {final_batch_size}",
            )
        return final_calibset_size, final_batch_size

    def _update_calibration_batch_info(self, calibset_size):
        desired_calibset_size = self.calibration_config.calibset_size
        desired_batch_size = self.calibration_config.batch_size
        calibset_size, batch_size = self._get_valid_batch_info(
            calibset_size,
            desired_calibset_size,
            desired_batch_size,
            "Calibration",
        )
        self.calibration_config.calibset_size = calibset_size
        self.calibration_config.batch_size = batch_size

    def _update_finetune_config(self):
        desired_batch_size = self.finetune_config.batch_size
        if desired_batch_size is None:
            desired_batch_size = self.calibration_config.batch_size
        self.finetune_config.batch_size = desired_batch_size

        desired_learning_rate = self.finetune_config.learning_rate
        if desired_learning_rate is None:
            desired_learning_rate = DEFAULT_LEARNING_RATE / 8 * desired_batch_size
        self.finetune_config.learning_rate = desired_learning_rate

        desired_warmup_lr = self.finetune_config.warmup_lr
        if desired_warmup_lr is None:
            desired_warmup_lr = desired_learning_rate / 4
        self.finetune_config.warmup_lr = desired_warmup_lr

    def _update_bias_correction_batch_info(self, calibset_size):
        desired_calibset_size = self.bias_correction_config.calibset_size
        desired_batch_size = self.bias_correction_config.batch_size
        if desired_batch_size is None:
            desired_batch_size = self.calibration_config.batch_size
        if desired_calibset_size is None:
            desired_calibset_size = self.calibration_config.calibset_size
        calibset_size, batch_size = self._get_valid_batch_info(
            calibset_size,
            desired_calibset_size,
            desired_batch_size,
            "Bias Correction",
        )

        self.bias_correction_config.calibset_size = calibset_size
        self.bias_correction_config.batch_size = batch_size

    def _update_checker_config_batch_info(self, calibset_size):
        desired_calibset_size = self.checker_config.dataset_size
        desired_batch_size = self.checker_config.batch_size
        if desired_batch_size is None:
            desired_batch_size = self.calibration_config.batch_size
        if desired_calibset_size is None:
            desired_calibset_size = self.calibration_config.calibset_size
        calibset_size, batch_size = self._get_valid_batch_info(
            calibset_size,
            desired_calibset_size,
            desired_batch_size,
            "Checker Config",
        )

        self.checker_config.dataset_size = calibset_size
        self.checker_config.batch_size = batch_size

    def apply_to_layers(self):
        """
        Apply the current configuration to the HN model layers.

        The per-layer sections of the configuration are then refreshed from the model.
        """
        apply_quantization_config_to_hn(self._hn_model, self._config)
        self.update_after_params_change()

    def update_after_params_change(self):
        """
        Refresh the per-layer precision and translation sections from the HN model.

        The model's per-layer settings are collected into a temporary configuration object and
        only its ``precision_config`` and ``translation_config`` are copied into the current one.
        """
        per_layer_settings = dict(self._hn_model.get_per_layer_precision_config())
        per_layer_settings.update(self._hn_model.get_per_layer_translation_config())
        model_derived_cfg = ModelOptimizationConfig(**per_layer_settings)
        self._config.precision_config = model_derived_cfg.precision_config
        self._config.translation_config = model_derived_cfg.translation_config


class ModelOptimizer:
    # TODO: Move auto4bit logic and force range logic.
    def __init__(self, hn_model, hw_arch, get_tf_graph_callback, work_dir, logger=None):
        """
        Set up the optimizer state for a single HN model.

        Args:
            hn_model: HN model to optimize.
            hw_arch: Hardware architecture, passed on to the model quantization step.
            get_tf_graph_callback: Callback handed to the graph wrappers to build the TF graph.
            work_dir: Directory for intermediate artifacts; ``None`` disables saving them.
            logger: Optional logger; ``default_logger()`` is used when it is omitted.

        """
        self._logger = logger or default_logger()
        self._work_dir = work_dir
        self._hw_arch = hw_arch
        self._hn_model = hn_model
        # Share the resolved logger with the config wrapper, so that its warnings reach a
        # caller-supplied logger instead of always going to the default one.
        self._config = ModelOptimizerConfig(hn_model, self._logger)
        self._pre_quantization_stats = None
        # None until equalization has been attempted; afterwards it holds the equalizer's result.
        self._was_equalized = None
        self._graph_wrapper = None
        self._get_tf_graph_callback = get_tf_graph_callback
        self._structural_change = False
        self._loaded_config = False
        self._dataset = None
        self._acceleras_bias_correction_attempt = False
        # Maps a stage name ("calibration", "bias_correction") to whether acceleras was used for it.
        self._acceleras_usage = {}
        self._gpu_policy: ThreeWayPolicy = ThreeWayPolicy.allowed  # This code is legacy

    @property
    def hn_model(self):
        """Return the HN model currently being optimized."""
        return self._hn_model

    def get_config(self):
        """Return the configuration object held by the config wrapper."""
        return self._config.get_mo_config()

    def save_config(self):
        """
        Write the final configuration into the work directory.

        It is saved twice: as commands, one per line, in ``mo_cfg.alls`` and as YAML in
        ``post_quantization_config.yaml``. Nothing is written when no work directory is set.
        """
        work_dir = self._work_dir
        if work_dir is None:
            return
        final_cfg = self.get_config()
        commands_path = os.path.join(work_dir, "mo_cfg.alls")
        with open(commands_path, "w") as fp:
            fp.writelines(line + "\n" for line in final_cfg.to_commands())
        self._config.save_cfg(work_dir, final_cfg, "post_quantization_config.yaml")

    def update_model(self, model, override_config=False):
        """
        Replace the HN model here and in the config wrapper.

        Args:
            model: The new HN model.
            override_config: Forwarded to the config wrapper's ``update_model``.

        """
        self._hn_model = model
        self._config.update_model(model, override_config)

    def set_dataset(self, calib_data, data_type):
        """
        Convert the calibration data into a dataset and keep it for later stages.

        Args:
            calib_data: Calibration data; a dict is expected to be keyed by input layer name.
            data_type: Data type descriptor passed on to ``data_to_dataset``.

        Returns:
            The image count reported by ``data_to_dataset``.

        """
        if isinstance(calib_data, dict):
            # Layer names carry the network name as a prefix, so every key is replaced by the
            # name of the layer it resolves to,
            # for example: input_layer1 will be ==>> <network_name>/input_layer1
            resolve_layer = self._hn_model.get_layer_by_name
            calib_data = {resolve_layer(key).name: value for key, value in calib_data.items()}
        self._dataset, image_count = data_to_dataset(calib_data, data_type, self._logger)
        return image_count

    @property
    def work_dir(self):
        """Return the directory used for intermediate artifacts, or ``None`` when saving is disabled."""
        return self._work_dir

    @work_dir.setter
    def work_dir(self, val):
        """Set the directory used for intermediate artifacts; ``None`` disables saving."""
        self._work_dir = val

    @property
    def structural_change(self):
        """Return the structural-change flag; ``sort_params`` sets it after applying sorted params."""
        return self._structural_change

    @property
    def config(self):
        """Return the ``ModelOptimizerConfig`` wrapper used by this optimizer."""
        return self._config

    @structural_change.setter
    def structural_change(self, value):
        """Set the structural-change flag."""
        self._structural_change = value

    def _is_acceleras_ibc_attempt(self):
        if not self._config.bias_correction_config.fast_ibc:
            return False
        else:
            return self._is_acceleras_bias_correction_bias_mode_supported()

    def _init_graph_wrapper(self, hn_data, params_data, model_config, force_tf_model):
        """
        Build the graph wrapper used for stats collection.

        Unless ``force_tf_model`` is set, an acceleras wrapper is tried first and replaced by the
        TF-model wrapper if it reports that it was not built. The outcome is recorded under
        ``"calibration"`` in the acceleras usage map (not when the TF model is forced).

        Args:
            hn_data: HN description passed to the acceleras wrapper.
            params_data: Parameters passed to the acceleras wrapper.
            model_config: Optimization configuration passed to the acceleras wrapper.
            force_tf_model: Skip acceleras and use the TF-model wrapper directly.

        Returns:
            The graph wrapper that was created.

        """
        # TODO for now we dont use acceleras with clipping nor ibc - but in the future we will
        if force_tf_model:  # currently, forced only in bias correction
            graph_wrapper = GraphWrapperTfModel(self._get_tf_graph_callback)
        else:
            graph_wrapper = GraphWrapperAcceleras(hn_data, params_data, model_config, self._get_tf_graph_callback)
            if graph_wrapper.accerleras_build:
                self._gpu_policy = ThreeWayPolicy.allowed
            else:
                graph_wrapper = GraphWrapperTfModel(self._get_tf_graph_callback)
            self._acceleras_usage["calibration"] = isinstance(graph_wrapper, GraphWrapperAcceleras)
        self._gpu_warning()
        return graph_wrapper

    def _hailo_np_savez(self, fname, params):
        if self._work_dir is not None:
            if isinstance(params, ModelParams):
                params = params.params
            hailo_np_savez(os.path.join(self._work_dir, fname), **params)
        else:
            self._logger.debug(f"Not saving {fname}, no work_dir passed..")

    def load_config(self, config, calibset_size=None):
        """
        Load a configuration into the config wrapper and refresh the state derived from it.

        Args:
            config: Configuration to load.
            calibset_size: Number of available calibration samples, forwarded to the wrapper.

        """
        self._config.load_config(config, self._work_dir, calibset_size=calibset_size)
        self._loaded_config = True
        # Evaluated after loading, because the decision reads ``fast_ibc`` from the new config.
        self._acceleras_bias_correction_attempt = self._is_acceleras_ibc_attempt()

    def _gpu_warning(self):
        """
        Warn when the GPU situation is likely to make the run slow.

        Three independent cases are reported: training algorithms are active but no GPU was
        detected (policy ``allowed``), training algorithms are active but the GPU is disabled by
        policy, and the policy is ``enabled`` but no GPU was detected.
        """
        gpu_policy = self._gpu_policy
        no_gpu_detected = not tf.config.list_physical_devices("GPU")
        active_train_algorithms = self._config.get_active_train_algorithms_names()

        if active_train_algorithms:
            if no_gpu_detected and gpu_policy == ThreeWayPolicy.allowed:
                self._logger.warning(
                    f"Running {active_train_algorithms} but no GPU was detected, expect long running time",
                )
            if gpu_policy == ThreeWayPolicy.disabled:
                self._logger.warning(
                    f"Running {active_train_algorithms} but GPU is disabled by policy, expect long running time",
                )

        if no_gpu_detected and gpu_policy == ThreeWayPolicy.enabled:
            self._logger.warning("GPU policy set to enabled but no GPU detected, expect long running time")

    def init_stats_collection(self, hn_data, params_data, force_tf_model=False, **tf_graph_kwargs):
        """
        Create the graph wrapper, attach the calibration dataset and initialize the graph.

        Args:
            hn_data: HN description forwarded to the graph wrapper.
            params_data: Parameters forwarded to the graph wrapper.
            force_tf_model: Use the TF-model wrapper without trying acceleras.
            **tf_graph_kwargs: Extra keyword arguments forwarded to ``init_graph``.

        """
        if not self._loaded_config:
            self._logger.warning("Initializing stats collection without loading config")
        target = SdkFPOptimized(enable_clipping=True)
        model_config = self.get_config()
        self._graph_wrapper = self._init_graph_wrapper(hn_data, params_data, model_config, force_tf_model)
        # Read only after the wrapper is created: _init_graph_wrapper may reset the GPU policy.
        # NOTE(review): the GPU is requested only for policy ``enabled`` here, while
        # _tf_model_bias_correction uses "not disabled" - confirm the difference is intended.
        use_gpu = self._gpu_policy == ThreeWayPolicy.enabled

        batch_size = self._config.calibration_config.batch_size
        calibset_size = self._config.calibration_config.calibset_size
        self._graph_wrapper.set_calib_dataset(self._dataset, self._hn_model, calibset_size, batch_size)
        self._graph_wrapper.init_graph(target, use_gpu, **tf_graph_kwargs)

    def clip_weights(self, params, weights_clipping_cfg):
        """
        Apply weights clipping to the params and to the FP-optimized graph.

        Args:
            params: Model parameters; updated in place with the clipping params.
            weights_clipping_cfg: Weights clipping configuration.

        Returns:
            ``params``: unchanged when there is no clipping to do, otherwise updated.

        """
        fp_target = EmulationInferenceTargets.SDK_FP_OPTIMIZED
        clipper = WeightsClipping(params, self._hn_model, weights_clipping_cfg)
        if not clipper.has_clipping:
            return params

        self._logger.info("Starting weights clipping")

        def graph_has_param(layer_name, param_key, values):
            # Only the param key decides; the other arguments are part of the callback signature.
            return self._graph_wrapper.has_variable(fp_target, param_key)

        clipper.remove_unsupported_layers(graph_has_param)

        clip_params = clipper.weights_clip_params
        self._graph_wrapper.update_graph_params(fp_target, clip_params)
        params.update(clip_params)

        self._hailo_np_savez(f"{self._hn_model.name}_weights_clipping.npz", params)
        return params

    def clip_activations(self, params, clip_cfg):
        """
        Apply activation clipping to the FP-optimized graph.

        Args:
            params: Model parameters.
            clip_cfg: Activation clipping configuration.

        Returns:
            The input ``params`` when there is no clipping, or when there is clipping but no
            percentile clipping; otherwise the params object returned by the clipper after the
            percentile-clipping params were added.

        """
        target_name = EmulationInferenceTargets.SDK_FP_OPTIMIZED

        clipper = ActivationClipping(params, self._hn_model, clip_cfg)
        if not clipper.has_clipping:
            return params
        self._logger.debug("Starting activations clipping")
        self._graph_wrapper.update_graph_params(target_name, clipper.act_clip_params)

        if clipper.has_percentile_clipping:
            # Percentile clipping needs two stats passes: the first result feeds the histogram
            # params, the second (eager, restricted to the percentile layers) feeds the clipping params.
            hailo_export = self._graph_wrapper.get_graph(target_name)
            supported_layers = hailo_export.activations_histograms_layers_names
            clipper.remove_unsupported_layers(supported_layers)
            stats = self._graph_wrapper.collect_stats(target_name)
            # From here on ``params`` is whatever the clipper returns, not the input object.
            params = clipper.add_histograms_params(stats)
            self._graph_wrapper.update_graph_params(target_name, params)
            histograms = self._graph_wrapper.collect_stats(
                target_name,
                layers_to_clip=clipper.percentile_clipping_layers,
                run_eagerly=True,
            )
            params = clipper.add_clipping_params(histograms)
            self._graph_wrapper.update_graph_params(target_name, params)

        self._hailo_np_savez(f"{self._hn_model.name}_activation_clipping.npz", params)
        self._logger.debug("Activations clipping is done")
        return params

    def sort_params(self, params, clip_aware_sort=False, update_graph=True):
        """
        Sort the model parameters, if the sort policy and the model call for it.

        Sorting is skipped once equalization has been attempted and either it was not applied or
        the clip-aware sorter was requested. It is also skipped when the policy is ``disabled``,
        or ``allowed`` while no layer has more than one quantization group.

        Args:
            params: Model parameters; updated in place when sorting produces new params.
            clip_aware_sort (bool, optional): Defaults to False. Use the clip-aware sorter
                instead of the regular one.
            update_graph (bool, optional): Defaults to True. Push the sorted params to the
                graph wrapper as well.

        Returns:
            params: the sorted params, or the untouched input when sorting was skipped.

        """
        equalization_attempted = self._was_equalized is not None
        if equalization_attempted and (clip_aware_sort or not self._was_equalized):
            # skip sorting after equalization block,
            # if either equalization wasn't applied or clip aware sort algorithm is True
            return params

        layers_with_groups = [
            layer
            for layer in self._hn_model.stable_toposort()
            if layer.precision_config.quantization_groups is not None
            and layer.precision_config.quantization_groups > 1
        ]

        sort_policy = self._config.global_config.sort_params
        if sort_policy == ThreeWayPolicy.disabled:
            return params
        if sort_policy == ThreeWayPolicy.allowed and len(layers_with_groups) == 0:
            return params
        # else: enabled or allowed & len(layers_with_groups) > 0

        self._logger.debug("Starting parameters sorting")
        if clip_aware_sort:
            self._logger.debug("Running Clip-Aware Params Sorter before fine-tuning")
            sorter = ClipAwareParamsSorter(self._hn_model)
        else:
            sorter = ParamsSorter(self._hn_model)
        sorted_params = sorter.sort_params(params)
        self._logger.debug("Parameters sorting is done")

        if sorted_params is None:
            return params

        self._structural_change = True
        params.update(sorted_params)
        if update_graph:
            self._graph_wrapper.update_graph_params(EmulationInferenceTargets.SDK_FP_OPTIMIZED, params)
        self._hailo_np_savez(f"{self._hn_model.name}_sorted_params.npz", params)
        return params

    # TODO: move this utility function to the abstract algorithm class
    @staticmethod
    def print_algo_time(start, end, task):
        """
        Log how long ``task`` took, formatted as ``HH:MM:SS.ss``.

        Args:
            start: Start timestamp, in seconds.
            end: End timestamp, in seconds.
            task: Task name used as the message prefix.

        """
        elapsed = end - start
        hours, rem = divmod(elapsed, 3600)
        minutes, seconds = divmod(rem, 60)
        message = f"{task} completion time {int(hours):0>2}:{int(minutes):0>2}:{seconds:05.2f}"
        default_logger().info(message)

    # Pre-process
    def equalize(self, params, executable_model):
        """
        Run equalization on stats collected from the FP-optimized graph.

        When equalization was applied, the graph is updated with the new params and they are
        saved to the work directory.

        Args:
            params: Model parameters.
            executable_model: Object providing the finalized conv layers inference.

        Returns:
            The params returned by the equalizer.

        """
        fp_target = EmulationInferenceTargets.SDK_FP_OPTIMIZED
        equalize_cfg = self._config.equalization_config
        self._logger.info("Starting Equalization")
        inference_results = self._graph_wrapper.collect_stats(fp_target)

        start_time = time.time()
        params, self._was_equalized = self.simple_equalization(
            self._hn_model,
            equalize_cfg,
            params,
            inference_results,
            executable_model,
        )
        end_time = time.time()

        if not self._was_equalized:
            self._logger.info("Equalization skipped")
            return params

        self.print_algo_time(start_time, end_time, "equalization")
        # update graph params
        self._graph_wrapper.update_graph_params(fp_target, params)
        self._hailo_np_savez(f"{self._hn_model.name}_equalized_params.npz", params)
        self._logger.info("Equalization is done")
        return params

    @staticmethod
    def simple_equalization(hn_model, equalize_cfg, params, inference_results, executable_model):
        """
        Equalize the model params using the given inference results.

        Args:
            hn_model: HN model to equalize.
            equalize_cfg: Equalization configuration.
            params: Model parameters.
            inference_results: Collected stats, used to finalize the conv layers inference.
            executable_model: Object providing ``get_finalized_conv_layers_inference`` and
                ``conv_layers_inference``.

        Returns:
            Tuple ``(params, was_equalized)``: the equalizer's params and its ``was_equalized`` flag.

        """
        executable_model.get_finalized_conv_layers_inference(inference_results)
        params_equalizer = ParamsEqualizer(hn_model, equalize_cfg)
        params = params_equalizer.equalize_model(params, executable_model.conv_layers_inference)
        # NOTE(review): updates the returned object with its own ``params`` attribute; presumably
        # a sync between two views of the same data - confirm it is needed.
        params.update(params.params)
        # Debug toggle: hard-coded to False, so the factors dump below never runs as written.
        save_factors = False
        if save_factors:
            np.savez("factors_for_equalization", **params_equalizer._factors)  # TODO -change
        return params, params_equalizer.was_equalized

    def bit_reduction(
        self,
        params,
        executable_model,
        force_results_by_layer=None,
        previous_statistics=None,
        debug_precise_mode=False,
        reuse_stats=False,
    ):
        """
        Quantize the model and return the params updated with the quantization results.

        Args:
            params: Model parameters to translate.
            executable_model: Object providing the finalized conv layers inference.
            force_results_by_layer: Stats to use instead of collecting them from the graph.
            previous_statistics: Earlier quantization statistics (a ``ModelParams`` is unwrapped).
            debug_precise_mode: Forwarded to the model quantization step.
            reuse_stats: Reuse the stats kept from the previous call instead of collecting new ones.

        Returns:
            The translated params.

        Raises:
            QuantizationException: If ``reuse_stats`` is set but no stats were collected before.

        """
        self._logger.info(f"Translating params for {self._hn_model.name}")
        if isinstance(previous_statistics, ModelParams):
            previous_statistics = previous_statistics.params

        if not reuse_stats:
            results_by_layer = self._collect_pre_quantization_stats(force_results_by_layer)
        elif self._pre_quantization_stats is None:
            raise QuantizationException("Trying to reuse stats, but none were previously collected")
        else:
            results_by_layer = self._pre_quantization_stats
        executable_model.get_finalized_conv_layers_inference(results_by_layer)

        statistics = quantize_model.quantize_model(
            self._hw_arch,
            params,
            executable_model.conv_layers_inference,
            self._hn_model,
            previous_statistics,
            debug_precise_mode,
            is_apu_2s_complement=True,
            max_elementwise_feed_repeat=self._config.global_config.max_elementwise_feed_repeat,
        )

        # Update params with quantization results
        translated_params = set_quantized_params.set_quantized_params(params, statistics)

        # Kept so that a later call can pass reuse_stats=True.
        self._pre_quantization_stats = results_by_layer
        return translated_params

    def _collect_pre_quantization_stats(self, force_results_by_layer=None, path=None):
        """
        Get the per-layer stats, apply the forced ranges and save the result.

        Args:
            force_results_by_layer: Stats to use instead of collecting them from the graph.
            path: File name for the saved stats; defaults to ``<model name>_layer_stats.npz``.

        Returns:
            The stats mapping, after the forced ranges were written into it.

        """
        target = EmulationInferenceTargets.SDK_FP_OPTIMIZED
        results_by_layer = force_results_by_layer
        if results_by_layer is None:
            results_by_layer = self._graph_wrapper.collect_stats(target)

        self._force_stats_per_layer(results_by_layer)
        if path is None:
            path = f"{self._hn_model.name}_layer_stats.npz"
        self._hailo_np_savez(path, results_by_layer)
        return results_by_layer

    # TODO: Move this logic to a dedicated class
    def _force_stats_per_layer(self, results_by_layer):
        for layer in self._hn_model.stable_toposort():
            force_range_in = layer.translation_config.force_range_in
            force_range_out = layer.translation_config.force_range_out
            if force_range_in is not None:
                inp_min_val, inp_max_val = force_range_in
                name = layer.name
                inp_min_key = f"{name}/stats_min_inp:0"
                inp_max_key = f"{name}/stats_max_inp:0"
                inp_was_forced_key = f"{name}/limvals_inp_forced:0"
                if inp_min_key not in results_by_layer or inp_max_key not in results_by_layer:
                    self._logger.warning(
                        f"Tried forcing input values on layer {layer.name} without stats, it has no effect",
                    )
                results_by_layer[inp_min_key] = inp_min_val
                results_by_layer[inp_max_key] = inp_max_val
                results_by_layer[inp_was_forced_key] = True
            if force_range_out is not None:
                out_min_val, out_max_val = force_range_out
                name = layer.name
                out_min_key = f"{name}/stats_min_out:0"
                out_max_key = f"{name}/stats_max_out:0"
                out_was_forced_key = f"{name}/limvals_out_forced:0"
                if out_min_key not in results_by_layer or out_max_key not in results_by_layer:
                    self._logger.warning(
                        f"Tried forcing output values on layer {layer.name} without stats, it has no effect",
                    )
                results_by_layer[out_min_key] = out_min_val
                results_by_layer[out_max_key] = out_max_val
                results_by_layer[out_was_forced_key] = True

    def adaround(self, npz, qnpz):
        """
        Run AdaRound on an acceleras model and return the resulting params.

        Args:
            npz: Native model params.
            qnpz: Quantized model params.

        Returns:
            A ``ModelParams`` with the kernels and biases written back from the acceleras model,
            or ``None`` when AdaRound is neither enabled globally nor enabled for any layer.

        Raises:
            AccelerasImplementationError: Logged and re-raised when the acceleras model cannot be
                created or the algorithm fails.

        """
        feature_enabled = self._config.adaround.policy == FeaturePolicy.enabled
        any_layer_enabled = any(
            layer_cfg.policy == LayerFeaturePolicy.enabled for layer_cfg in self._config.adaround.layers.values()
        )
        if not (feature_enabled or any_layer_enabled):
            return None
        # TODO: reuse collection graph
        try:
            # quantize=False: the hw params from qnpz are imported, but set_quantized() is not called.
            acceleras_model = self._initialize_acceleras_model(npz, qnpz, quantize=False)
        except AccelerasImplementationError:
            self._logger.exception("Adaround was cancelled, acceleras model creation failed")
            raise

        # Without a work dir the algorithm gets the relative directory "adaround".
        base_work_dir = "" if self._work_dir is None else self._work_dir
        work_dir = os.path.join(base_work_dir, "adaround")

        # TODO: add per layer config
        # TODO: pass the model config

        try:
            algo = AdaRound(
                acceleras_model,
                self._config.get_mo_config(),
                logging.INFO,
                dataset=self._dataset,
                work_dir=work_dir,
                logger=self._logger,
            )
            algo.run()
        except AccelerasImplementationError:
            self._logger.exception("Adaround was cancelled, encountered acceleras error during inference")
            raise
        # The results are written into a copy, so the caller's ``npz`` is left untouched.
        adaround_npz = hn_npz_utils.NpzWrap(dict(npz).copy())
        # One forward pass on a single sample before the kernels and biases are read back.
        self._eagerize_model(acceleras_model)
        for layer_name, ac_layer in acceleras_model.layers.items():
            adaround_npz.write_kernel(layer_name, ac_layer)
            adaround_npz.write_bias(layer_name, ac_layer)
        adaround_params = ModelParams(adaround_npz.params)

        self._hailo_np_savez(f"{self._hn_model.name}_optimized_adaround", adaround_params)
        return adaround_params

    def _eagerize_model(self, model):
        """Call ``model`` once on a single sample (batch of 1) taken from the stored dataset."""
        # Every dataset element is unpacked into (data, <ignored>); only the data is fed to the model.
        for dt, _ in self._dataset.batch(1).take(1):
            model(dt)

    def _initialize_acceleras_model(self, npz, qnpz=None, quantize=True):
        """
        Build an acceleras model from the HN model, the weights and the current configuration.

        Args:
            npz: Native params wrapper (its ``params`` are imported as weights), or ``None``.
            qnpz: Quantized params wrapper; when given, its hw params are imported as well.
            quantize: Switch the model to quantized mode after importing the hw params.

        Returns:
            The initialized acceleras model.

        """
        if npz is not None:
            npz = npz.params
        if qnpz is not None:
            qnpz = qnpz.params

        hn = self._hn_model.to_hn(self._hn_model.name, json_dump=False)
        model_config = self.get_config()

        acceleras_model = HailoModel(hn)
        acceleras_model.import_weights(npz)
        CreateMixedPrecision(model=acceleras_model, model_config=model_config, logger_level=0).run()  # load config
        if qnpz is None:
            return acceleras_model

        acceleras_model.import_hw_params_from_qnpz(qnpz, force_legacy=True)
        if quantize:
            acceleras_model.set_quantized()
        return acceleras_model

    # Post process
    def bias_correction(self, update_params_layer_bias_callback, npz, qnpz, executable_model):
        """
        Run bias correction, preferring the acceleras implementation when it was requested.

        The acceleras flow is attempted when ``self._acceleras_bias_correction_attempt`` is set;
        if it raises ``AccelerasImplementationError`` the tf_model flow is used instead.
        The flow that was actually used is recorded in ``self._acceleras_usage``.

        Args:
            update_params_layer_bias_callback: callback forwarded to the tf_model flow.
            npz: wrapper exposing the native params through ``.params``, or None.
            qnpz: wrapper exposing the quantized params through ``.params``, or None.
            executable_model: forwarded to the acceleras flow.

        Returns:
            The corrected params, or None when bias correction is not configured.

        """
        use_tf_model = not self._acceleras_bias_correction_attempt
        if not self._config.has_bias_correction():
            self._logger.debug("Skipping bias correction")
            return None

        self._logger.debug("Starting bias correction")
        corrected_params = None
        if not use_tf_model:
            try:
                self._logger.debug("\tUsing acceleras")
                native_params = None if npz is None else npz.params
                quantized_params = None if qnpz is None else qnpz.params
                corrected_params = self._acceleras_bias_correction(native_params, quantized_params, executable_model)
            except AccelerasImplementationError:
                self._logger.debug("Failed bias correction with acceleras, falling back to tf_model")
                use_tf_model = True
        if use_tf_model:
            corrected_params = self._tf_model_bias_correction(update_params_layer_bias_callback)
        self._acceleras_usage["bias_correction"] = not use_tf_model

        self._hailo_np_savez(f"{self._hn_model.name}_translated_ibc_params", corrected_params)
        self._logger.debug("Finished bias correction")

        return corrected_params

    def _tf_model_bias_correction(self, update_params_layer_bias_callback):
        """
        Run iterative bias correction (IBC) using the tf_model graphs.

        The current graph wrapper is saved before the stats collection is re-initialized and is
        always put back when this method exits, including when one of the steps raises, so a
        failure here does not leave ``self._graph_wrapper`` pointing at the temporary wrapper.

        Args:
            update_params_layer_bias_callback: callback forwarded to ``IBC.run``.

        Returns:
            The params returned by ``IBC.run``.

        """
        ibc_cfg = self._config.bias_correction_config
        target = SdkPartialNumeric() if ibc_cfg.fast_ibc else SdkNumeric()
        use_gpu = self._gpu_policy != ThreeWayPolicy.disabled
        main_graph_wrapper = self._graph_wrapper
        try:
            self.init_stats_collection(None, None, True)
            self._graph_wrapper.init_graph(target, use_gpu)
            self._graph_wrapper.init_graph(SdkFPOptimized(enable_clipping=True), use_gpu)

            native_export = self._graph_wrapper.get_graph(EmulationInferenceTargets.SDK_FP_OPTIMIZED)
            numeric_export = self._graph_wrapper.get_graph(target.name)

            ibc = IBC(self._hn_model, native_export, numeric_export, ibc_cfg)
            # TODO: I don't like this callbacks. It calls sdk_backend behind the scenes.
            #  The executable model should be exported / separated from sdk_backend
            native_initializer = self._graph_wrapper.get_initializer(EmulationInferenceTargets.SDK_FP_OPTIMIZED)
            numeric_initializer = self._graph_wrapper.get_initializer(target.name)
            calib_cfg = self._config.calibration_config
            # Only whole batches are counted
            batch_count = calib_cfg.calibset_size // calib_cfg.batch_size
            return ibc.run(native_initializer, numeric_initializer, batch_count, update_params_layer_bias_callback)
        finally:
            # Restore the main wrapper on both the success and the failure path
            self._graph_wrapper = main_graph_wrapper

    def _acceleras_bias_correction(self, npz, qnpz, executable_model):
        """
        Run the acceleras ``BiasCorrection`` algorithm and re-quantize the corrected params.

        Args:
            npz: native params, imported as the model weights.
            qnpz: quantized params, imported as the hw params.
            executable_model: forwarded to ``bit_reduction``.

        Returns:
            The result of ``bit_reduction`` on the bias-corrected params.

        """
        self._logger.debug("Starting bias correction")
        hn_repr = self._hn_model.to_hn(self._hn_model.name, json_dump=False)
        config = self.get_config()

        # TODO: reuse collection graph
        quant_model = HailoModel(hn_repr)
        quant_model.import_config(config)
        quant_model.import_weights(npz)
        quant_model.import_hw_params_from_qnpz(qnpz, force_legacy=True)

        # Without a work dir the algorithm directory is relative to the current directory
        algo_dir = os.path.join("" if self._work_dir is None else self._work_dir, ".bias_correction")
        BiasCorrection(
            quant_model,
            config,
            self._dataset,
            work_dir=algo_dir,
            logger_level=logging.INFO,
            logger=self._logger,
        ).run()

        # Write the corrected biases into a copy, leaving the caller's params untouched
        corrected_npz = hn_npz_utils.NpzWrap(dict(npz))
        corrected_npz.write_all_biases(quant_model)
        # Re-quantize after bias correction to reduce quantization noise from acceleras vs quantize model behavior
        return self.bit_reduction(ModelParams(corrected_npz.params), executable_model)

    def _is_acceleras_bias_correction_bias_mode_supported(self):
        """
        Check that every layer that will go through bias correction has a supported bias mode.

        Only ``BiasMode.double_scale_initialization`` is accepted. A layer is checked when it
        reports ``LayerSupportStatus.supported`` from ``ibc_supported()`` and bias correction is
        not disabled for it (explicitly, or through a disabled global policy combined with an
        ``allowed`` layer policy).

        Returns:
            bool: True when no checked layer has an unsupported bias mode, False otherwise
            (the offending layer names are logged at debug level).

        """
        global_cfg = self._config.bias_correction_config

        def _goes_through_bias_correction(layer):
            if layer.ibc_supported() != LayerSupportStatus.supported:
                return False
            layer_cfg = global_cfg.layers.get(layer.name, LayerBiasCorrectionConfig.get_default())
            if layer_cfg.policy == BiasCorrectionPolicy.disabled:
                return False
            globally_disabled = global_cfg.policy == BiasCorrectionPolicy.disabled
            return not (globally_disabled and layer_cfg.policy == BiasCorrectionPolicy.allowed)

        unsupported_layers = {
            layer.name
            for layer in self._hn_model.stable_toposort()
            if _goes_through_bias_correction(layer)
            and layer.precision_config.bias_mode != BiasMode.double_scale_initialization
        }
        if not unsupported_layers:
            return True

        self._logger.debug(f"The following layers' bias is unsupported by bias correction: {unsupported_layers}")
        return False

    def finetune(self, backend, force_results_by_layer=None, build_tf_model=False):
        """
        Fine-tune the model, save intermediate products, and roll the backend params back.

        The acceleras flow is used unless ``build_tf_model`` is set; if it raises
        ``AccelerasImplementationError`` the tf_model flow is used instead. The flow that was
        actually used is recorded in ``self._acceleras_usage``.

        Args:
            backend: object providing ``get_params_pre_quantization``, ``get_params_translated``
                and ``load_params``; also handed to the fine-tune implementations.
            force_results_by_layer: forwarded to ``_collect_pre_quantization_stats``.
            build_tf_model: when True, skip the acceleras flow and use tf_model directly.

        """
        ft_cfg = self._config.finetune_config
        if ft_cfg.policy == FinetunePolicy.disabled:
            self._logger.debug("Skipping FineTune")
            return

        calib_dataset = self._dataset
        self._logger.debug("Starting FineTune")
        # TODO: remove runner usage here. Params should be changed in pipeline.

        # Snapshot of the params, loaded back into the backend at the end
        original_params = copy.deepcopy(backend.get_params_pre_quantization())
        layer_stats = (
            self._collect_pre_quantization_stats(force_results_by_layer)
            if self._pre_quantization_stats is None
            else self._pre_quantization_stats
        )
        gpu_enabled = self._gpu_policy != ThreeWayPolicy.disabled

        use_tf_model = build_tf_model
        if not use_tf_model:
            try:
                self._logger.debug("Starting finetune with acceleras")
                acceleras_fine_tune(backend, calib_dataset, ft_cfg, self.get_config())
            except AccelerasImplementationError as err:
                self._logger.debug(f"Failed finetune with acceleras, {err}")
                self._logger.debug("Falling back to tf_model")
                use_tf_model = True
        if use_tf_model:
            ft_cfg.info_config()  # the config is logged in the base_algorithm of acceleras
            fine_tune_from_feed(
                backend,
                calib_dataset,
                ft_cfg,
                layer_stats,
                ft_cfg.batch_size,
                work_dir=self._work_dir,
                use_gpu=gpu_enabled,
            )
            if self._work_dir is not None:
                # Leave an empty marker file in the work dir for the tf_model flow
                marker_path = Path(self._work_dir) / f"{self._hn_model.name}_tf_model_infer"
                marker_path.touch()

        self._acceleras_usage["finetune"] = not use_tf_model

        model_name = self._hn_model.name
        # Create some specialized intermediate products:
        # 1. fine-tuned "quantized" params (for quant evaluation)
        self._hailo_np_savez(f"{model_name}_finetuned.q.npz", backend.get_params_translated())
        # 2. fine-tuned "non-quantized" params (for fake-quant evaluation)
        self._hailo_np_savez(f"{model_name}_finetuned.f.npz", backend.get_params_pre_quantization())

        # 3. post-finetune layer statistics, for future analysis of damage control.
        tf.compat.v1.reset_default_graph()
        self._collect_pre_quantization_stats(
            force_results_by_layer,
            path=f"{self._hn_model.name}_layer_stats_post_ft.npz",
        )
        backend.load_params(original_params)
        self._logger.debug("FineTune is done")

    def optimize_layers_to_4_bit(self):
        """
        Run the ``Auto4Bit`` algorithm when a positive auto 4bit weights ratio is configured.

        The achieved ratio and the run time are reported through the logger.
        """
        target_ratio = self._config.compression_config.auto_4bit_weights_ratio
        if not target_ratio > 0:
            return

        self._logger.info("Starting auto 4bit weights")
        auto_4bit = Auto4Bit(self._hn_model, target_ratio, self._logger)
        started_at = time.time()
        auto_4bit.run()
        finished_at = time.time()
        self._logger.info(f"Ratio of weights in 4bit is {auto_4bit.compression_ratio:.2f}")
        self.print_algo_time(started_at, finished_at, "auto4bit")
        self._logger.info("Auto 4bit weights is done")

    @staticmethod
    def _get_model_precision_config_meta(layer):
        """
        Return the 'precision_mode' entry of the layer's precision_config 'meta' information.

        None is returned when there is no 'meta' information or it has no 'precision_mode' entry.
        """
        meta = layer.precision_config.dict()["meta"]
        return None if meta is None else meta.get("precision_mode", None)

    def _find_output_with_minimum_shape(self, hn_model):
        """
        Return the smallest eligible output layer, or None when no output layer is eligible.

        An output layer is eligible when all of the following hold:
          1. Its precision mode is a8_w8
          2. All of its predecessors have ops listed in SUPPORTED_LAYERS_IN_A16_W16
             (for example, PPU layers don't support 16bit output)
          3. Its precision mode wasn't set by the user (from the alls).
             In case of alls configuration it should appear in the `meta` configuration
          4. Each of its predecessors has exactly one successor
          5. Its output is smaller than MAXIMUM_DESCRIPTOR_SIZE_IN_HAILORT
             (descriptors boundary in hailort)
        """
        smallest_layer = None
        smallest_size = float("inf")
        for candidate in hn_model.get_output_layers():
            # Size is the product of the output dims after the first one
            size = np.prod(candidate.output_shape[1:])
            if candidate.precision_config.precision_mode != PrecisionMode.a8_w8:
                continue
            preds = list(hn_model.predecessors(candidate))
            if not all(pred.op.name in SUPPORTED_LAYERS_IN_A16_W16 for pred in preds):
                continue
            if self._get_model_precision_config_meta(candidate) is not None:
                continue
            if not all(len(list(hn_model.successors(pred))) == 1 for pred in preds):
                continue
            if not size < MAXIMUM_DESCRIPTOR_SIZE_IN_HAILORT:
                continue
            if size < smallest_size:
                smallest_size = size
                smallest_layer = candidate
        return smallest_layer

    def _get_output_bytes_per_frame(self, hn_model):
        """
        Get the number of output bytes for each frame (total of all outputs).

        Output layers in one of the a16_w16 precision modes are counted with two bytes per
        element, every other output layer with one.
        """
        modes_16bit = (PrecisionMode.a16_w16, PrecisionMode.a16_w16_a16, PrecisionMode.a16_w16_a8)
        total_bytes = 0
        for out_layer in hn_model.get_output_layers():
            elements = np.prod(out_layer.output_shape[1:])
            is_16bit = out_layer.precision_config.precision_mode in modes_16bit
            total_bytes += 2 * elements if is_16bit else elements
        return total_bytes

    def set_16bit_output(self):
        """
        Gradually assign 16bit activation to output layers, smallest first.

        The process stops when no eligible output layer is left, or when upgrading the next one
        would reach the predefined maximum throughput. Nothing is done when the output_16bit
        policy is disabled.
        """
        if self._config.global_config.output_16bit == ThreeWayPolicy.disabled:
            return

        hn_model = self._hn_model
        bytes_per_frame = self._get_output_bytes_per_frame(hn_model)
        while bytes_per_frame < MAXIMUM_THROUGHPUT_FOR_16BIT_OUTPUT:
            candidate = self._find_output_with_minimum_shape(hn_model)
            if candidate is None:
                break
            # Upgrading to 16bit adds one more byte per output element
            upgraded_total = bytes_per_frame + np.prod(candidate.output_shape[1:])
            if not upgraded_total < MAXIMUM_THROUGHPUT_FOR_16BIT_OUTPUT:
                break
            # no bias for output layer but we set it to avoid a warning
            candidate.precision_config.precision_mode = PrecisionMode.a16_w16
            candidate.precision_config.bias_mode = BiasMode.single_scale_decomposition
            self._logger.info(f"Assigning 16bit activation to output layer {candidate.name}")
            bytes_per_frame = self._get_output_bytes_per_frame(hn_model)
        self._hn_model = hn_model

    def set_16bit_for_entire_network(self):
        """
        Set every layer of the model to a16_w16 precision with a matching bias mode.

        The bias mode is double_scale_initialization for layers that need it (see
        ``_is_bias_double_scale_initialization_needed``) and single_scale_decomposition otherwise.
        """
        for layer in self._hn_model.stable_toposort():
            if self._is_bias_double_scale_initialization_needed(layer):
                chosen_bias_mode = BiasMode.double_scale_initialization
            else:
                chosen_bias_mode = BiasMode.single_scale_decomposition
            layer_cfg = layer.precision_config
            layer_cfg.precision_mode = PrecisionMode.a16_w16
            layer_cfg.bias_mode = chosen_bias_mode

    def verify_16bit_bias_mode(self):
        """
        Re-assign the bias mode of every layer that is in one of the a16_w16 precision modes.

        The bias mode is double_scale_initialization for layers that need it (see
        ``_is_bias_double_scale_initialization_needed``) and single_scale_decomposition otherwise.
        Layers in any other precision mode are left untouched.
        """
        modes_16bit = (PrecisionMode.a16_w16, PrecisionMode.a16_w16_a16, PrecisionMode.a16_w16_a8)
        for layer in self._hn_model.stable_toposort():
            current_mode = layer.precision_config.precision_mode
            if current_mode not in modes_16bit:
                continue
            if self._is_bias_double_scale_initialization_needed(layer):
                chosen_bias_mode = BiasMode.double_scale_initialization
            else:
                chosen_bias_mode = BiasMode.single_scale_decomposition
            layer_cfg = layer.precision_config
            # The precision mode is written back unchanged, alongside the bias mode
            layer_cfg.precision_mode = current_mode
            layer_cfg.bias_mode = chosen_bias_mode

    def _is_bias_double_scale_initialization_needed(self, layer):
        """
        Tell whether the layer requires the double_scale_initialization bias mode.

        This is the case for ``ResizeLayer`` instances whose method is "bilinear" and for
        ``FusedStandaloneEWAddLayer`` instances.
        """
        is_bilinear_resize = isinstance(layer, ResizeLayer) and layer._method.value == "bilinear"
        return bool(is_bilinear_resize or isinstance(layer, FusedStandaloneEWAddLayer))


class Auto4Bit:
    """
    This class is used to automatically assign 4bit weights.

    Layers are visited from the heaviest to the lightest and eligible ones are switched to
    ``PrecisionMode.a8_w4`` until the requested ratio of the total weights is covered.
    """

    def __init__(self, hailo_nn, weights_ratio, logger=None):
        """
        Args:
            hailo_nn: the model graph; it is queried through ``stable_toposort``, ``successors``,
                ``predecessors``, ``get_input_layers`` and ``get_output_layers``.
            weights_ratio: the ratio of the total weights that should end up in 4bit.
            logger: optional logger, ``default_logger()`` is used when it is omitted.

        """
        self._hailo_nn = hailo_nn
        self._weights_ratio = weights_ratio
        self._logger = logger or default_logger()
        self.total_compressed_weight = 0
        self._compressed_layers = []
        self._total_weights = np.sum([layer.weights for layer in self._hailo_nn.stable_toposort()])

    @property
    def compressed_layers(self):
        """The layers that we compressed using the Auto4Bit"""
        return self._compressed_layers

    @property
    def compression_ratio(self):
        """Ratio of weights compressed to 4bit"""
        return self.total_compressed_weight / self._total_weights

    def _assign_4bit_to_layers(self, layers_to_change):
        """Assign 4bit weights to all layers in layers_to_change"""
        for layer in layers_to_change:
            layer.precision_config.precision_mode = PrecisionMode.a8_w4
            self._compressed_layers.append(layer.name)
            self._logger.info(
                f"Assigning 4bit weights to layer {layer.name} with {(layer.weights / 1e3):.2f}k parameters",
            )

    def _search_layers(self, start_layers, stop_cond: callable, reverse=False):
        """
        Search layers that satisfy the stop condition using a BFS search

        The search does not continue past a layer that satisfies the condition, and every layer
        is handled at most once.

        Args:
            start_layers: the entry points of the bfs search
            stop_cond: a boolean function that checks if a layer matches the search
            reverse: boolean, if true the BFS scans the graph backwards

        Return:
            list of layers that satisfied the stop condition, in the order they were reached

        """
        # Function-scope import: deque gives O(1) pops from the left, list.pop(0) is O(n)
        from collections import deque

        bfs_queue = deque(start_layers)
        cond_layers = []
        bfs_handled = set()
        while bfs_queue:
            curr_layer = bfs_queue.popleft()
            if curr_layer in bfs_handled:
                continue
            bfs_handled.add(curr_layer)
            if stop_cond(curr_layer):
                cond_layers.append(curr_layer)
                continue
            if reverse:
                bfs_queue.extend(self._hailo_nn.predecessors(curr_layer))
            else:
                bfs_queue.extend(self._hailo_nn.successors(curr_layer))
        return cond_layers

    def _get_sorted_layers_by_weights(self):
        """Return a list of (layer, weights) tuples ordered by weights, heaviest first"""
        layers_info_weights = [(layer, layer.weights) for layer in self._hailo_nn.stable_toposort()]
        layers_info_weights.sort(key=lambda x: x[1], reverse=True)
        return layers_info_weights

    def _get_layers_to_compress(self, supported_ops, ignored_layers):
        """
        Get list of layers to compress

        Layers are scanned from the heaviest to the lightest. The weights of each layer selected
        for compression, and of each layer whose reduced precision mode is already a8_w4, are
        added to ``total_compressed_weight``; the scan stops as soon as the target is reached.

        Args:
            supported_ops (set): supported operations for compression
            ignored_layers (list): list of layers to ignore

        Returns:
            list: layers to compress

        """
        # Start every selection from zero, so that running the algorithm again on the same
        # instance doesn't add to the previous total and inflate compression_ratio.
        self.total_compressed_weight = 0
        layer_not_considered = []
        layers_to_change = []
        weight_to_remove = self._weights_ratio * self._total_weights

        def _can_be_compressed(layer):
            return (
                layer.op in supported_ops
                and layer not in ignored_layers
                and layer.precision_config.precision_mode == PrecisionMode.a8_w8
            )

        for layer, layer_weights in self._get_sorted_layers_by_weights():
            if _can_be_compressed(layer):
                layers_to_change.append(layer)
                self.total_compressed_weight += layer_weights
            elif layer.precision_config.precision_mode.reduce() == PrecisionMode.a8_w4:
                # Already in 4bit, counts towards the target without being changed
                self.total_compressed_weight += layer_weights
            else:
                layer_not_considered.append(layer)
            if self.total_compressed_weight >= weight_to_remove:
                if layer_not_considered:
                    self._logger.debug(
                        f"The following layers were not considered {[layer.name for layer in layer_not_considered]}",
                    )
                return layers_to_change

        # All the layers were scanned and the requested ratio is still out of reach
        maximal_ratio_reduction = self.total_compressed_weight / self._total_weights
        self._logger.warning(
            f"The maximal 4bit weight ratio that can be achieved is {maximal_ratio_reduction:.2f} "
            f"but was given {self._weights_ratio}.",
        )
        return layers_to_change

    def _get_layers_with_user_config_precision_mode(self):
        """Get all the layers whose precision_config has 'meta' information"""
        return [
            layer
            for layer in self._hailo_nn.stable_toposort()
            if layer.precision_config.dict()["meta"] is not None
        ]

    def run(self):
        """
        the algorithm will greedily compress at least the ratio given of original weights memory.

        This will be done by assigning 4b layers to some layers, starting with the heaviest ones
        (which are assumed to also be most redundant thus less sensitive) and skipping non-conv/dense layers
        and input/output layers (which are expected to be sensitive in many case).
        Layers that were manually set to 4bit or 16bit are skipped as well.
        """
        # supported operations
        supported_ops = {LayerType.conv, LayerType.dense}

        # ignored layers (all the layers configured by the user and the first and last layers of the model)
        def stop_cond(layer):
            return layer.op in supported_ops

        input_layers = self._hailo_nn.get_input_layers()
        first_layers = self._search_layers(input_layers, stop_cond)
        output_layers = self._hailo_nn.get_output_layers()
        last_layers = self._search_layers(output_layers, stop_cond, reverse=True)
        user_config_layers = self._get_layers_with_user_config_precision_mode()
        ignored_layers = first_layers + last_layers + user_config_layers

        layers = self._get_layers_to_compress(supported_ops, ignored_layers)
        self._assign_4bit_to_layers(layers)
