import logging
import os
import tempfile
from copy import deepcopy
from enum import Enum
from pathlib import Path
from typing import List, Optional, Union

import yaml

from hailo_model_optimization.acceleras.hailo_layers.base_hailo_none_nn_core_layer import BaseHailoNonNNCoreLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_postprocess import PostProcessConfig
from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.model_optimization_config.mo_config import ModelOptimizationConfig
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerPrecisionConfig
from hailo_model_optimization.acceleras.native_layers.base_native_layer import BaseNativeLayer
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    DistributionStrategy,
    LayerType,
    OpStates,
    OptimizationTarget,
    PrecisionMode,
    ResolutionReductionStage,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasUnsupportedError
from hailo_model_optimization.acceleras.utils.dataset_util import DatasetContianer, data_to_dataset
from hailo_model_optimization.acceleras.utils.logger import default_logger
from hailo_model_optimization.acceleras.utils.modification_tracker.modification_tracker_utils import (
    AlgoModificationTracker,
)
from hailo_model_optimization.algorithms.ada_round.ada_round_v2 import AdaRound
from hailo_model_optimization.algorithms.add_shortcut_layer.add_shortcut_layer import AddShortcutLayer
from hailo_model_optimization.algorithms.algorithm_base import AlgoResults, AlgorithmBase, AlgorithmStatus
from hailo_model_optimization.algorithms.apu_neg_mantissa_correction import ApuNegMantissaCorrection
from hailo_model_optimization.algorithms.bias_correction.bias_correction_v2 import BiasCorrection
from hailo_model_optimization.algorithms.block_round_training.block_round_training import BlockRoundTraining
from hailo_model_optimization.algorithms.clip_activation_stats.clip_activation_stats import ClipActivationStats
from hailo_model_optimization.algorithms.conv_decomposition.conv_decomposition import ConvDecomposition
from hailo_model_optimization.algorithms.create_encoding.create_io_encoding import CreateIOEncoding
from hailo_model_optimization.algorithms.create_hw_params.create_hw_params import CreateHWParamsWithMatch
from hailo_model_optimization.algorithms.dead_layers_removal import DeadLayersRemoval
from hailo_model_optimization.algorithms.decompose_16bits.create_decompose_16bits import CreateDecompose16Bits
from hailo_model_optimization.algorithms.decompose_16bits.switch_decompose_16bits import SwitchDecompose16Bits
from hailo_model_optimization.algorithms.decompose_channel_wise.decompose_channel_wise import (
    DecomposeChannelWiseQuantization,
)
from hailo_model_optimization.algorithms.decompose_conv_a16_w4.decompose_conv_a16_w4 import DecomposeConvA16W4
from hailo_model_optimization.algorithms.deequalize import Deequalize
from hailo_model_optimization.algorithms.equalization.equalization import Equalization
from hailo_model_optimization.algorithms.ew_add_fusing import EWAddFusing
from hailo_model_optimization.algorithms.finetune.qft import QftRunner
from hailo_model_optimization.algorithms.finetune.qft_encoding import QftEncodingRunner
from hailo_model_optimization.algorithms.fix_zp_comp_encoding.fix_zp_comp_encoding import FixZpCompEncoding
from hailo_model_optimization.algorithms.force_preact_stats.force_preact_stats import ForcePreactStats
from hailo_model_optimization.algorithms.global_avgpool_reduction import GlobalAvgpoolReduction
from hailo_model_optimization.algorithms.hailo_layer_noise_analysis import HailoQuantAnalyzer
from hailo_model_optimization.algorithms.layer_decompose.layer_decompose_algo import LayerDecompose
from hailo_model_optimization.algorithms.matmul_correction.matmul_correction import MatmulCorrection
from hailo_model_optimization.algorithms.matmul_decompose.matmul_decompose_algo import DecomposeMatmul
from hailo_model_optimization.algorithms.matmul_decompose.matmul_decompose_scale_fix_algo import DecomposeMatmulFix
from hailo_model_optimization.algorithms.matmul_equalization.matmul_equalization import MatmulEqualization
from hailo_model_optimization.algorithms.mixed_precision.create_mixed_precision import CreateMixedPrecision
from hailo_model_optimization.algorithms.mixed_precision.mixed_precision import MixedPrecision
from hailo_model_optimization.algorithms.output_16bit_as_8bit.output_16bit_as_8bit import Output16BitAs8Bit
from hailo_model_optimization.algorithms.quant_checker.quant_checker import QuantChecker
from hailo_model_optimization.algorithms.quarot.quarot import QuaRot
from hailo_model_optimization.algorithms.resolution_reduction import ResolutionReduction
from hailo_model_optimization.algorithms.smart_softmax_stats.create_softmax_mask import CreateSoftmaxMask
from hailo_model_optimization.algorithms.smart_softmax_stats.optimize_softmax_bias import OptimizeSoftmaxBias
from hailo_model_optimization.algorithms.smart_softmax_stats.smart_softmax_stats import SmartSoftmaxStats
from hailo_model_optimization.algorithms.split_ew_mult_by_bit_significance.split_ew_mult_by_bit_significance import (
    SplitEWMultByBitSignificance,
)
from hailo_model_optimization.algorithms.split_precision_norm_layer.decompose_norm_layer_algo_v2 import (
    DecomposeLayerNorm,
)
from hailo_model_optimization.algorithms.stats_collection.stats_collection import StatsCollector
from hailo_model_optimization.algorithms.switch_concat_with_add import SwitchConcatWithAdd
from hailo_model_optimization.algorithms.switch_layers.switch_layers_algo import SwitchLayersByPrecision
from hailo_model_optimization.algorithms.use_pre_quant_weights import UsePreQuantWeights
from hailo_model_optimization.algorithms.zero_static_channels.zero_static_channel_algo import ZeroStaticChannelsAlgo
from hailo_model_optimization.flows.utils.flow_memento import (
    AccelerasMemento,
    FlowMemento,
    OptFlowOriginator,
    save_acceleras_model,
)
from hailo_model_optimization.tools.orchestator import flow_control_method
from hailo_model_optimization.tools.subprocess_wrapper import BaseSubprocessFlow, SupProcessPolicies, subprocess_wrapper


class SupportedStops(Enum):
    """Named stop points for the optimization flow.

    Apart from ``NONE``, every value is the name of an ``OptimizationFlow``
    method (``_collect_stats``, ``step1``, ``_deequalize``). Presumably the
    ``run_until`` argument of ``OptimizationFlow.run`` is matched against these
    names by the ``flow_control_method`` decorator — confirm in the orchestrator.
    """

    # Sentinel value; presumably means "do not stop early".
    NONE = "none"
    # Stop point at the statistics-collection method.
    STATS = "_collect_stats"
    # Stop point at the first flow step (which ends with core quantization).
    QUANTIZE = "step1"
    # Stop point at the de-equalization method.
    DEEQUALIZE = "_deequalize"


class OptimizationFlow(BaseSubprocessFlow[FlowMemento]):
    """
    Drive the full model-optimization (quantization) flow of a ``HailoModel``.

    ``run()`` executes three steps, each wrapped by ``subprocess_wrapper``:

    * ``step1``: build the model, ``setup_optimization()``, ``pre_quantization_structural()``,
      snapshot the float data, ``pre_quantization_optimization()`` and ``core_quantization()``.
    * ``step2``: build the float reference model and run ``post_quantization_optimization()``.
    * ``step3``: build the float reference model and run ``finalize_optimization()``.

    Every algorithm's results, statistics and modification metadata are collected through
    ``_finalize_algorithm``.

    Args:
        hn: the network description (HN) used to build the ``HailoModel``. It is replaced
            (``set_hn``) as the model is modified.
        params: float weights, imported into the model with ``import_weights``.
        model_script: the parsed model optimization config.
        data_container: calibration data; converted lazily to a dataset on first use.
        optimization_target: required; converted with ``OptimizationTarget(...)``.
        logger: logger to use; defaults to ``default_logger()``.
        work_dir: directory for intermediate params and the end-of-run marker file;
            ``None`` disables saving.
        nms_config: post-process config forwarded to ``HailoModel``.
        params_statistics: initial statistics dict, updated by every algorithm.
        adapter_name: forwarded to ``HailoModel`` as ``lora_adapter_name``.

    Raises:
        AccelerasUnsupportedError: if ``optimization_target`` is ``None``.
    """

    # Results of every finalized algorithm. Created per instance in __init__: a class-level
    # list default would be shared (and appended to) by every OptimizationFlow instance.
    _flow_results: List[AlgoResults]

    def __init__(
        self,
        hn: dict,
        params: dict,
        model_script: ModelOptimizationConfig,
        data_container: DatasetContianer,
        optimization_target: Optional[Union[OptimizationTarget, str]] = None,
        logger: Optional[logging.Logger] = None,
        work_dir: Optional[str] = None,
        nms_config: Optional[PostProcessConfig] = None,
        params_statistics: Optional[dict] = None,
        adapter_name: Optional[str] = None,
    ):
        if optimization_target is None:
            raise AccelerasUnsupportedError("Optimization target is not defined")

        self._logger = logger or default_logger()
        self._hn = hn
        self._fp_params = params
        self._original_hn = deepcopy(hn)
        self._nms_config = nms_config
        self._data_container = data_container
        self.optimization_target = OptimizationTarget(optimization_target)
        self._work_dir = Path(work_dir) if work_dir is not None else None
        self._params_statistics = dict() if params_statistics is None else params_statistics

        self._quant_params = None
        self._acceleras_params = None
        # Set by _deequalize(); initialized like the other params so the getter never raises.
        self._deequalize_params = None
        self.original_input_shapes = dict()
        self._adapter_name = adapter_name

        # Per-instance list of AlgoResults appended by _finalize_algorithm.
        self._flow_results = []

        # These objects are initialized in a lazy manner, to make sure tf is only initialized on use
        self._model: HailoModel = None
        self._fp_model: HailoModel = None
        self._dataset = None
        self._fp_hn: dict = None

        self._parsed_config = model_script
        self._mo_fp_config = deepcopy(model_script)

        self._modifications_meta_data = AlgoModificationTracker()

        # Functional attributes
        self._step_index = None
        self._model_memento = None

        # External functionality
        self.serializer = OptFlowOriginator(self)

    @property
    def flavors_info(self):
        """Return the flow's flavors info.

        NOTE(review): ``_flavors_info`` is never assigned in this class, so this raises
        ``AttributeError`` unless it is set elsewhere (subclass / serializer) — confirm.
        """
        return self._flavors_info

    @property
    def parsed_config(self):
        """Return a new ``ModelOptimizationConfig`` rebuilt from the internal config (a copy)."""
        return ModelOptimizationConfig(**self._parsed_config.dict())

    @property
    def mo_fp_config(self):
        """Config snapshot matching the float model (refreshed by ``_update_fp_data``)."""
        return self._mo_fp_config

    @mo_fp_config.setter
    def mo_fp_config(self, value):
        self._mo_fp_config = value

    @property
    def model(self) -> HailoModel:
        """The model being optimized.

        Raises:
            ValueError: if the model has not been built yet.
        """
        if self._model is None:
            raise ValueError("Trying to access optimization_flow.model before a model has been initialized")
        return self._model

    @property
    def dataset(self):
        """Dataset built lazily from the data container on first access, then cached."""
        if self._dataset is None:
            self._dataset, _ = data_to_dataset(self._data_container.data, self._data_container.data_type)
        return self._dataset

    @property
    def flow_results(self):
        """List of ``AlgoResults`` collected from every finalized algorithm of this flow."""
        return self._flow_results

    @property
    def params_statistics(self):
        """Collected statistics, or ``None`` when nothing has been collected."""
        if not self._params_statistics:
            return None
        return self._params_statistics

    @params_statistics.setter
    def params_statistics(self, value):
        self._params_statistics = dict() if value is None else value

    @property
    def modifications_meta_data(self) -> AlgoModificationTracker:
        """Tracker aggregating the modification metadata reported by each algorithm."""
        return self._modifications_meta_data

    @property
    def flow_policies(self) -> SupProcessPolicies:
        """Build the subprocess policies (multiproc and GPU) from the config's globals."""
        policies = SupProcessPolicies(
            multiproc_policy=self._parsed_config.globals.multiproc_policy,
            gpu_policy=self._parsed_config.globals.gpu_policy,
        )
        return policies

    def build_model(self, force=False):
        """
        Build or reuse ``self._model`` and attach the current ``dist_info`` to it.

        Resolution order:
        1. A pending ``AccelerasMemento`` is restored and then consumed (deleted).
        2. If no model exists yet, a new ``HailoModel`` is built from ``self._hn`` and loaded
           with the acceleras params when available, otherwise with the float params.
        3. If ``force`` is set, the existing model is round-tripped through the serializer
           (saved to a temporary directory and restored).
        4. Otherwise the existing model is reused.

        Args:
            force: rebuild an existing model. Also forced implicitly when ``dist_info`` has a
                ``tf_strategy`` and its ``call_counter`` is 0.
        """
        # `and` binds tighter than `or`; the parentheses only make that explicit.
        force = force or (self.dist_info.tf_strategy and self.dist_info.call_counter == 0)

        if isinstance(self._model_memento, AccelerasMemento):
            model = self.serializer.model_serializer.restore(self._model_memento)

            # !Important: the memento is consumed here, therefore it is deleted.
            self._model_memento.delete()
            self._model_memento = None

        elif self._model is None:
            model: HailoModel = HailoModel(
                self._hn,
                nms_config=self._nms_config,
                optimization_target=self.optimization_target,
                lora_adapter_name=self._adapter_name,
            )
            if self._acceleras_params is not None:
                model.import_acceleras(self._acceleras_params)
            else:
                model.import_weights(self._fp_params)

            # prepare resolution reduction info and save original input shapes
            reduction_cfg = self._parsed_config.resolution_reduction
            if reduction_cfg.layers or reduction_cfg.shape:
                _, model_input_shapes = model.resolution_reduction_prepare()
                # Keep the shapes recorded by the first build only.
                if self.original_input_shapes == {}:
                    self.original_input_shapes = model_input_shapes

            if OpStates.QUANTIZED in model.supported_states:
                model.set_lossy()
            # Prepend an unknown batch dimension to every input shape.
            shapes = [(None,) + shape for shape in model.get_input_shapes()]
            model.compute_output_shape(shapes)
            model.build(shapes)

        elif force:
            with tempfile.TemporaryDirectory() as temp_dir:
                memento = self.serializer.model_serializer.save(self._model, Path(temp_dir))
                model = self.serializer.model_serializer.restore(memento)
                memento.delete()
        else:
            model = self._model
        self._model = model
        self._model.dist_info = self.dist_info

    def reset_subprocess(self):
        """Hook that (re)builds the model through ``build_model()``."""
        self.build_model()

    def build_fp_model(self):
        """Build the native float reference model and apply resolution reduction to it."""
        self._build_fp_model()
        self._update_resolution_fp_model()

    def _build_fp_model(self):
        """Create ``self._fp_model`` from the float HN and float params, in native mode."""
        model = HailoModel(
            self.get_fp_hn(),
            nms_config=self._nms_config,
            optimization_target=self.optimization_target,
            lora_adapter_name=self._adapter_name,
        )
        model.dist_info = self.dist_info
        model.import_weights(self._fp_params)
        model.set_native()
        self._fp_model = model

    def _update_resolution_fp_model(self):
        """Apply resolution reduction to the float model using a copy of ``mo_fp_config``.

        The algorithm is not passed to ``_finalize_algorithm``, so its results are not recorded.
        """
        algo = ResolutionReduction(
            self._fp_model,
            deepcopy(self.mo_fp_config),
            logging.DEBUG,
            logger=self._logger,
            reduction_stage=ResolutionReductionStage.apply,
        )
        algo.run()

    def get_acceleras_params(self):
        """Return the exported acceleras params (``None`` until quantization data is exported)."""
        return self._acceleras_params

    def set_acceleras_params(self, value):
        self._acceleras_params = value

    def get_deequalize_params(self):
        """Return the de-equalization params (``None`` until ``_deequalize`` has run)."""
        return self._deequalize_params

    def set_deequalize_params(self, value):
        self._deequalize_params = value

    def get_fp_params(self):
        """Return the float params."""
        return self._fp_params

    def set_fp_params(self, value):
        self._fp_params = value

    def get_quant_params(self):
        """Return the exported HW params (``None`` until quantization data is exported)."""
        return self._quant_params

    def set_quant_params(self, value):
        self._quant_params = value

    def get_hn(self):
        """Return a deep copy of the current HN."""
        return deepcopy(self._hn)

    def set_hn(self, value):
        self._hn = value

    def get_fp_hn(self):
        """Return a deep copy of the float HN (``None`` until ``_update_fp_data`` has run)."""
        return deepcopy(self._fp_hn)

    def set_fp_hn(self, val):
        self._fp_hn = val

    def set_model(self, model):
        """
        Set the model directly; a utility to avoid model reconstruction in some cases.
        """
        self._model = model

    def save_state(self, path: Optional[str] = "") -> FlowMemento:
        """Serialize the flow state into ``path``.

        When ``path`` is empty a new directory is created with ``tempfile.mkdtemp()``;
        it is not removed here.
        """
        temp_dir = path if path else tempfile.mkdtemp()
        return self.serializer.save(temp_dir)

    def load_state(self, memento: FlowMemento, tf_safe: bool = False):
        """Restore the flow state from ``memento`` via the serializer."""
        self.serializer.restore(memento, tf_safe=tf_safe)

    @flow_control_method
    def run(self, *, memento=None, run_until=None):
        """Run ``step1``, ``step2`` and ``step3`` in order.

        ``memento`` and ``run_until`` are not used in this body; presumably they are consumed
        by the ``flow_control_method`` decorator (resume / early stop) — confirm there.
        """
        step_funcs = [self.step1, self.step2, self.step3]
        for ind, step_func in enumerate(step_funcs):
            self._step_index = ind
            step_func()

    def _update_fp_data(self):
        """Snapshot the current model as the float reference (weights, HN and config)."""
        self.set_fp_params(self._model.export_weights())
        self.set_fp_hn(self._model.export_hn())
        self.mo_fp_config = deepcopy(self._parsed_config)

    def _update_quantize_data(self, include_shared_weights=True):
        """Export HW params, HN and acceleras params from the current model."""
        self.set_quant_params(self._model.export_hw_params(include_shared_weights=include_shared_weights))
        self.set_hn(self._model.export_hn())
        self.set_acceleras_params(self._model.export_acceleras(include_shared_weights=include_shared_weights))

    @subprocess_wrapper(DistributionStrategy.SINGLE, {DistributionStrategy.DATA_P, DistributionStrategy.MODEL_P})
    def step1(self):
        """Setup, structural changes, pre-quantization optimization and core quantization."""
        self.build_model()
        self.setup_optimization()
        self.pre_quantization_structural()
        # The float reference is captured after structural changes, before any quantization.
        self._update_fp_data()
        self.pre_quantization_optimization()
        self.core_quantization()
        self._update_quantize_data()

    @subprocess_wrapper(DistributionStrategy.DATA_P, {DistributionStrategy.DATA_P})
    def step2(self):
        """Post-quantization optimization against the float reference model."""
        self.build_fp_model()
        self.post_quantization_optimization()
        self._update_quantize_data()

    @subprocess_wrapper(DistributionStrategy.SINGLE, {DistributionStrategy.DATA_P, DistributionStrategy.MODEL_P})
    def step3(self):
        """Finalize the optimization and export the final data without shared weights."""
        self.build_fp_model()
        self.finalize_optimization()
        self._model.set_lossy()
        self._update_quantize_data(include_shared_weights=False)

    def setup_optimization(self):
        """Validate the model and apply precision-related layer switches before optimizing."""
        self._verify_native_layers()
        self._switch_layers_by_precision()
        self._switch_decompose_16bits()
        self._switch_decompose_matmul()
        self._create_mix_precision()
        # in case of specific command for entire net quantization in 16 bit, verify that all layers support it
        if self._parsed_config.compression_params.auto_16bit_weights_ratio == 1.0:
            self._verify_layers_16_bit_support()
        self._verify_postprocess_layers()

    def pre_quantization_structural(self):
        """Run the structural (graph-modifying) algorithms, in order.

        Resolution reduction is applied for this stage and reverted at its end.
        """
        self._use_pre_quant_weights()
        self._create_softmax_mask()
        self._quarot()
        self._layer_decompose()
        self._decompose_conv_a16_w4()
        self._conv_decomposition()
        self._remove_dead_layers()
        self._ew_add_fusing()
        self._switch_concat_with_add()
        self._decompose_channel_wise_quantization()
        self._optimize_softmax_bias()
        self._resolution_reduction()
        self._split_ew_mult_by_bit_significance()
        self._decompose_layer_norm()
        self._apu_neg_mantissa_correction()
        self._matmul_correction()
        self._correct_model()
        self._output_16bit_as_8bit()
        self._remove_resolution_reduction()

    def pre_quantization_optimization(self):
        """Collect statistics and run the pre-quantization optimizations, in order.

        Resolution reduction is applied again here and reverted in ``finalize_optimization``.
        """
        self._resolution_reduction()
        self._collect_stats()
        self._matmul_correction_removal()
        self._global_avgpool_reduction()
        self._create_decompose_16bits()
        self._smart_softmax_stats()
        self._clip_stats_range()
        self._vectorize_scales()
        self._zero_static_channels()
        self._equalize()
        self._deequalize()

    def core_quantization(self):
        """Create encodings and the HW params."""
        self._create_encoding()
        self._force_preact_stats()
        self._fix_zp_comp_encoding()
        self._matmul_equalization()
        self._fix_decompose_matmul()
        self._create_hw_params()

    def post_quantization_optimization(self):
        """Run the post-quantization optimizations (need the float model from ``build_fp_model``)."""
        self._train_encoding()
        self._bias_correction()
        self._adaround()
        self._block_round_training()
        self._finetune()

    def get_fp_model(self):
        """Return the float reference model.

        Raises:
            ValueError: if the float model has not been built yet.
        """
        if self._fp_model is None:
            raise ValueError("Trying to access optimization_flow.get_fp_model() before a fp_model has been initialized")
        return self._fp_model

    def finalize_optimization(self):
        """Compute precisions, analyze noise, revert resolution reduction and run final checks."""
        self._compute_layers_precision()
        self._noise_analysis()
        self._remove_resolution_reduction()
        self.set_hn(self._model.export_hn())
        self._dump_config()
        self._quant_checker()

    def _quant_checker(self):
        """Run ``QuantChecker`` with the collected statistics and record its results."""
        quant_checker = QuantChecker(
            self.model,
            self._parsed_config,
            logging.DEBUG,
            params_statistics=self._params_statistics,
            logger=self._logger,
        )
        quant_checker.run()
        self._finalize_algorithm(quant_checker)

    def _dump_config(self):
        """
        Touch an empty ``acceleras_full_flow`` marker file in the work dir at the end of the run.

        Does nothing when no work dir was given.
        """
        if self._work_dir is not None:
            Path(self._work_dir, "acceleras_full_flow").touch()

    def _fix_zp_comp_encoding(self):
        """Run ``FixZpCompEncoding`` and record its results."""
        fix_zp_comp_encoding = FixZpCompEncoding(self.model, self._parsed_config, logging.INFO, logger=self._logger)
        fix_zp_comp_encoding.run()
        self._finalize_algorithm(fix_zp_comp_encoding)

    def _create_hw_params(self):
        """Run ``CreateHWParamsWithMatch``, record its results and save the params."""
        algo = CreateHWParamsWithMatch(self.model, self._parsed_config, logging.DEBUG, logger=self._logger)
        algo.run()
        self._finalize_algorithm(algo)
        self.save_params(algo.name)

    def _collect_stats(self):
        """Run ``StatsCollector`` on the dataset, record its results and save the params."""
        algo = StatsCollector(
            self.model,
            self._parsed_config,
            logging.INFO,
            self.dataset,
            logger=self._logger,
        )
        algo.run()
        self._finalize_algorithm(algo)
        self.save_params(algo.name)

    def _global_avgpool_reduction(self):
        """Run ``GlobalAvgpoolReduction`` and record its results."""
        algo = GlobalAvgpoolReduction(self.model, self._parsed_config, logging.DEBUG, self.dataset, logger=self._logger)
        algo.run()
        self._finalize_algorithm(algo)

    def _create_encoding(self):
        """Run ``CreateIOEncoding`` and record its results."""
        algo = CreateIOEncoding(self.model, self._parsed_config, logging.DEBUG, self._logger)
        algo.run()
        self._finalize_algorithm(algo)

    def _apu_neg_mantissa_correction(self):
        """Run ``ApuNegMantissaCorrection`` and record its results."""
        algo = ApuNegMantissaCorrection(self.model, self._parsed_config, logging.DEBUG, logger=self._logger)
        algo.run()
        self._finalize_algorithm(algo)

    def _create_softmax_mask(self):
        """Run ``CreateSoftmaxMask`` and record its results."""
        algo = CreateSoftmaxMask(self.model, self._parsed_config, logging.DEBUG, logger=self._logger)
        algo.run()
        self._finalize_algorithm(algo)

    def _optimize_softmax_bias(self):
        """Run ``OptimizeSoftmaxBias`` and record its results."""
        algo = OptimizeSoftmaxBias(self.model, self._parsed_config, logging.DEBUG, self.dataset, logger=self._logger)
        algo.run()
        self._finalize_algorithm(algo)

    def _quarot(self):
        """Run ``QuaRot`` and record its results."""
        algo = QuaRot(self.model, self._parsed_config, logging.DEBUG, logger=self._logger)
        algo.run()
        self._finalize_algorithm(algo)

    def _layer_decompose(self):
        """Run ``LayerDecompose`` and record its results."""
        algo = LayerDecompose(self.model, self._parsed_config, logging.DEBUG, logger=self._logger)
        algo.run()
        self._finalize_algorithm(algo)

    def _decompose_conv_a16_w4(self):
        """Run ``DecomposeConvA16W4`` and record its results."""
        algo = DecomposeConvA16W4(self.model, self._parsed_config, logging.DEBUG, self.dataset, logger=self._logger)
        algo.run()
        self._finalize_algorithm(algo)

    def _remove_dead_layers(self):
        """Run ``DeadLayersRemoval`` and record its results."""
        algo = DeadLayersRemoval(self.model, self._parsed_config, logging.DEBUG, logger=self._logger)
        algo.run()
        self._finalize_algorithm(algo)

    def _switch_concat_with_add(self):
        """Run ``SwitchConcatWithAdd`` and record its results."""
        algo = SwitchConcatWithAdd(self.model, self._parsed_config, logging.DEBUG, logger=self._logger)
        algo.run()
        self._finalize_algorithm(algo)

    def _ew_add_fusing(self):
        """Run ``EWAddFusing`` and record its results."""
        algo = EWAddFusing(self.model, self._parsed_config, logging.DEBUG, logger=self._logger)
        algo.run()
        self._finalize_algorithm(algo)

    def _decompose_channel_wise_quantization(self):
        """Run ``DecomposeChannelWiseQuantization`` and record its results."""
        algo = DecomposeChannelWiseQuantization(self.model, self._parsed_config, logging.DEBUG, logger=self._logger)
        algo.run()
        self._finalize_algorithm(algo)

    def _resolution_reduction(self):
        """Apply resolution reduction (``ResolutionReductionStage.apply``) to the model."""
        algo = ResolutionReduction(
            self.model,
            self._parsed_config,
            logging.DEBUG,
            logger=self._logger,
            reduction_stage=ResolutionReductionStage.apply,
        )
        algo.run()
        self._finalize_algorithm(algo)

    def _switch_layers_by_precision(self):
        """Run ``SwitchLayersByPrecision`` and record its results."""
        algo = SwitchLayersByPrecision(self.model, self._parsed_config, logging.DEBUG, logger=self._logger)
        algo.run()
        self._finalize_algorithm(algo)

    def _switch_decompose_16bits(self):
        """Run ``SwitchDecompose16Bits`` and record its results."""
        algo = SwitchDecompose16Bits(self.model, self._parsed_config, logging.DEBUG, logger=self._logger)
        algo.run()
        self._finalize_algorithm(algo)

    def _create_decompose_16bits(self):
        """Run ``CreateDecompose16Bits`` and record its results."""
        algo = CreateDecompose16Bits(self.model, self._parsed_config, logging.DEBUG, self.dataset, logger=self._logger)
        algo.run()
        self._finalize_algorithm(algo)

    def _smart_softmax_stats(self):
        """Run ``SmartSoftmaxStats`` and record its results."""
        algo = SmartSoftmaxStats(self.model, self._parsed_config, logging.DEBUG, self._logger)
        algo.run()
        self._finalize_algorithm(algo)

    def _compute_layers_precision(self):
        """Run ``MixedPrecision`` and record its results."""
        # TODO: can it be removed?
        algo = MixedPrecision(self.model, self._parsed_config, logging.DEBUG, self._logger)
        algo.run()
        self._finalize_algorithm(algo)

    def _clip_stats_range(self):
        """Run ``ClipActivationStats``, record its results and save the params."""
        algo = ClipActivationStats(self.model, self._parsed_config, logging.DEBUG, self.dataset, self._logger)
        algo.run()
        self._finalize_algorithm(algo)
        self.save_params(algo.name)

    def _vectorize_scales(self):
        """Vectorize the model scales. For flow control this needs to be a leaf node."""
        self.model.vectorize_scales()

    def _force_preact_stats(self):
        """Run ``ForcePreactStats`` and record its results."""
        algo = ForcePreactStats(self.model, self._parsed_config, logging.DEBUG, self._logger)
        algo.run()
        self._finalize_algorithm(algo)

    def _zero_static_channels(self):
        """Run ``ZeroStaticChannelsAlgo`` and record its results."""
        # NOTE(review): unlike most algorithms here, no logger is passed — confirm intended.
        algo = ZeroStaticChannelsAlgo(self.model, self._parsed_config, logging.DEBUG)
        algo.run()
        self._finalize_algorithm(algo)

    def _matmul_correction(self):
        """Run ``MatmulCorrection`` in addition mode and record its results."""
        # NOTE(review): no logger is passed here — confirm intended.
        algo = MatmulCorrection(self.model, self._parsed_config, logging.DEBUG, addition=True)
        algo.run()
        self._finalize_algorithm(algo)

    def _matmul_equalization(self):
        """Run ``MatmulEqualization`` and record its results."""
        algo = MatmulEqualization(self.model, self._parsed_config, logging.INFO, self._logger)
        algo.run()
        self._finalize_algorithm(algo)

    def _correct_model(self):
        """Correct the model and verify HW support. For flow control this should be a leaf."""
        self.model.correct_model()
        self._verify_hw_support()

    def _matmul_correction_removal(self):
        """Run ``MatmulCorrection`` in removal mode (``addition=False``) and record its results."""
        # NOTE(review): no logger is passed here — confirm intended.
        algo = MatmulCorrection(self.model, self._parsed_config, logging.DEBUG, addition=False)
        algo.run()
        self._finalize_algorithm(algo)

    def _use_pre_quant_weights(self):
        """Run ``UsePreQuantWeights`` and record its results."""
        # NOTE(review): `addition=True` mirrors the MatmulCorrection call and no logger is passed;
        # confirm UsePreQuantWeights actually expects this keyword.
        algo = UsePreQuantWeights(self.model, self._parsed_config, logging.DEBUG, addition=True)
        algo.run()
        self._finalize_algorithm(algo)

    def _conv_decomposition(self):
        """Run ``ConvDecomposition`` and record its results."""
        # NOTE(review): no logger is passed here — confirm intended.
        algo = ConvDecomposition(self.model, self._parsed_config, logging.DEBUG, self.dataset)
        algo.run()
        self._finalize_algorithm(algo)

    def _bias_correction(self):
        """Run ``BiasCorrection`` against the float model; save params only on success."""
        model_fp = self.get_fp_model()
        algo = BiasCorrection(
            self.model,
            model_fp,
            self._parsed_config,
            self.dataset,
            work_dir=self._work_dir,
            logger_level=logging.INFO,
            logger=self._logger,
        )
        algo.run()
        self._finalize_algorithm(algo)
        if algo.get_results().status == AlgorithmStatus.SUCCESSFULLY_DONE:
            self.save_params(algo.name)

    def _noise_analysis(self):
        """Run ``HailoQuantAnalyzer`` against the float model and record its results."""
        algo = HailoQuantAnalyzer(
            model=self.model,
            model_config=self._parsed_config,
            unbatched_data_set=self.dataset,
            logger=self._logger,
            native_model=self.get_fp_model(),
        )
        algo.run()
        self._finalize_algorithm(algo)

    def _remove_resolution_reduction(self):
        """Revert resolution reduction, restoring the recorded original input shapes."""
        algo = ResolutionReduction(
            self.model,
            self._parsed_config,
            logging.DEBUG,
            logger=self._logger,
            reduction_stage=ResolutionReductionStage.revert,
            original_input_shapes=self.original_input_shapes,
        )
        algo.run()
        self._finalize_algorithm(algo)

    def _train_encoding(self):
        """Run ``QftEncodingRunner`` against the float model and record its results."""
        model_fp = self.get_fp_model()
        qft_encoding = QftEncodingRunner(
            self.model,
            self._parsed_config,
            logging.INFO,
            model_fp,
            unbatched_train_dataset=self.dataset,
            work_dir=self._work_dir,
            logger=self._logger,
            # Variables whose name contains "normalization" or "avgpool" are passed as frozen.
            var_freeze_cond=lambda s: "normalization" in s or "avgpool" in s,
        )
        qft_encoding.run()
        self._finalize_algorithm(qft_encoding)

    def _adaround(self):
        """Run ``AdaRound`` against the float model and record its results."""
        model_fp = self.get_fp_model()
        algo = AdaRound(
            self.model,
            model_fp,
            self._parsed_config,
            self.dataset,
            work_dir=self._work_dir,
            logger_level=logging.INFO,
            logger=self._logger,
        )
        algo.run()
        self._finalize_algorithm(algo)

    def _block_round_training(self):
        """Run ``BlockRoundTraining`` against the float model and record its results."""
        model_fp = self.get_fp_model()
        algo = BlockRoundTraining(
            self.model,
            model_fp,
            self._parsed_config,
            self.dataset,
            work_dir=self._work_dir,
            logging_level=logging.DEBUG,
            logger=self._logger,
        )
        algo.run()
        self._finalize_algorithm(algo)

    def _finetune(self):
        """Run ``QftRunner`` against the float model; save params only on success."""
        model_fp = self.get_fp_model()
        algo = QftRunner(
            self.model,
            self._parsed_config,
            logging.INFO,
            model_fp,
            unbatched_train_dataset=self.dataset,
            work_dir=self._work_dir,
            logger=self._logger,
            # Variables whose name contains "normalization" or "avgpool" are passed as frozen.
            var_freeze_cond=lambda s: "normalization" in s or "avgpool" in s,
        )
        algo.run()
        self._finalize_algorithm(algo)
        if algo.get_results().status == AlgorithmStatus.SUCCESSFULLY_DONE:
            self.save_params(algo.name)

    def _deequalize(self):
        """Run ``Deequalize``, record its results and keep its ``deequalize_params``."""
        algo = Deequalize(self.model, self._parsed_config, logging.DEBUG, logger=self._logger)
        algo.run()
        self._finalize_algorithm(algo)
        self.set_deequalize_params(algo.deequalize_params)

    def _equalize(self):
        """Run ``Equalization`` and record its results."""
        algo = Equalization(self.model, self._parsed_config, logging.DEBUG, logger=self._logger)
        algo.run()
        self._finalize_algorithm(algo)

    def save_params(self, dir_name):  # TODO remove this !!
        """Save the current model under ``work_dir/dir_name``; only logs when there is no work dir."""
        if self._work_dir is not None:
            dir_path = self._work_dir.joinpath(dir_name)
            save_acceleras_model(self.model, dir_path)
        else:
            self._logger.debug(f"Not saving {dir_name}, no work_dir passed..")

    def _verify_native_layers(self):
        """Raise if the model contains a native (non-quantizable) layer.

        Raises:
            AccelerasUnsupportedError: on the first ``BaseNativeLayer`` found.
        """
        for layer in self.model.layers.values():
            if isinstance(layer, BaseNativeLayer):
                raise AccelerasUnsupportedError(
                    f"Quantization of layer {layer.full_name} is currently not supported in acceleras",
                )

    def _create_mix_precision(self):
        """Run ``CreateMixedPrecision`` and record its results."""
        algo = CreateMixedPrecision(
            model=self.model,
            model_config=self._parsed_config,
            logger_level=logging.INFO,
            logger=self._logger,
        )
        algo.run()
        self._finalize_algorithm(algo)

    def _add_shortcut_layer(self):
        """Run ``AddShortcutLayer`` and record its results (not called by the flow stages here)."""
        algo = AddShortcutLayer(
            model=self.model,
            model_config=self._parsed_config,
            logger_level=logging.INFO,
            logger=self._logger,
        )
        algo.run()
        self._finalize_algorithm(algo)

    def _verify_layers_16_bit_support(self):
        """Raise if any NN-core layer does not support ``PrecisionMode.a16_w16``.

        Raises:
            AccelerasUnsupportedError: on the first unsupported layer.
        """
        for layer in self.model.layers.values():
            # Non-NN-core layers are exempt from this check.
            if not isinstance(layer, BaseHailoNonNNCoreLayer):
                if PrecisionMode.a16_w16 not in layer.SUPPORTED_PRECISION_MODE:
                    raise AccelerasUnsupportedError(
                        f"Quantization with 16 bit precision for layer {layer.full_name} is currently not "
                        "supported in acceleras",
                    )

    def _finalize_algorithm(self, algo: AlgorithmBase):
        """Record an algorithm's results, statistics and modification metadata on the flow."""
        self._flow_results.append(algo.get_results())
        self._params_statistics.update(algo.get_statistics())
        self._modifications_meta_data.update(algo.get_modifications_meta_data())

    def _verify_postprocess_layers(self):
        """Raise if the inputs of the postprocess layers mix 8-bit and 16-bit output precisions.

        Raises:
            AccelerasUnsupportedError: with a hint naming the minority group of layers to change.
        """
        # extracts postprocess's predecessors precision mode
        hn_layers = self._hn["layers"]
        postprocess_layers = [
            layer for layer in hn_layers.keys() if hn_layers[layer]["type"] == LayerType.POSTPROCESS.value
        ]
        layers_to_verify = [
            input_layer
            for postprocess_layer in postprocess_layers
            for input_layer in hn_layers[postprocess_layer]["input"]
        ]
        precision_of_16bit = [
            layer_to_verify
            for layer_to_verify in layers_to_verify
            if self.model._acceleras_layers[layer_to_verify].get_precision_mode()
            in [PrecisionMode.a8_w8_a16, PrecisionMode.a8_w4_a16, PrecisionMode.a16_w16_a16]
        ]
        precision_of_8bit = [
            layer_to_verify
            for layer_to_verify in layers_to_verify
            if self.model._acceleras_layers[layer_to_verify].get_precision_mode()
            in [PrecisionMode.a8_w8_a8, PrecisionMode.a8_w4_a8, PrecisionMode.a16_w16_a8]
        ]

        # NOTE(review): a layer whose mode is in neither list counts toward neither group, so it
        # also triggers this error (possibly with an empty layer list in the message).
        if len(precision_of_16bit) != len(layers_to_verify) and len(precision_of_8bit) != len(layers_to_verify):
            # mixed precision in the inputs of the postprocess layer
            if len(precision_of_16bit) >= len(precision_of_8bit):
                input_layers = ", ".join(precision_of_8bit)
                log = f"Please change the precision of layers {input_layers} from 8bit to 16bit (a16w16a16)"
            else:
                input_layers = ", ".join(precision_of_16bit)
                log = f"Please change the precision of layers {input_layers} from 16bit to 8bit (a8w8a8/ a8w4a8)"
            raise AccelerasUnsupportedError(log)

    def _decompose_layer_norm(self):
        """Run ``DecomposeLayerNorm`` and record its results."""
        algo = DecomposeLayerNorm(self.model, self._parsed_config, logging.INFO, self.dataset, logger=self._logger)
        algo.run()
        self._finalize_algorithm(algo)

    def _switch_decompose_matmul(self):
        """Run ``DecomposeMatmul`` and record its results."""
        algo = DecomposeMatmul(self.model, self._parsed_config, logging.INFO, self.dataset, logger=self._logger)
        algo.run()
        self._finalize_algorithm(algo)

    def _fix_decompose_matmul(self):
        """Run ``DecomposeMatmulFix`` and record its results."""
        algo = DecomposeMatmulFix(self.model, self._parsed_config, logging.INFO, logger=self._logger)
        algo.run()
        self._finalize_algorithm(algo)

    def _split_ew_mult_by_bit_significance(self):
        """Run ``SplitEWMultByBitSignificance`` and record its results."""
        algo = SplitEWMultByBitSignificance(
            self.model,
            self._parsed_config,
            logging.DEBUG,
            self.dataset,
            logger=self._logger,
        )
        algo.run()
        self._finalize_algorithm(algo)

    def _output_16bit_as_8bit(self):
        """Run ``Output16BitAs8Bit`` and record its results."""
        algo = Output16BitAs8Bit(self.model, self._parsed_config, logging.DEBUG, self.dataset, logger=self._logger)
        algo.run()
        self._finalize_algorithm(algo)

    def _verify_hw_support(self):
        """Check every layer's precision config against the optimization target.

        Unsupported layers and their configs are logged as YAML before raising.

        Raises:
            AccelerasUnsupportedError: if any layer is not supported by the HW.
        """
        unsupported_layers = dict()
        for lname, layer in self.model.iterate_layers():
            layer_prec_cfg = layer.get_layer_precision_config()
            is_hw_supported = layer.is_supported_by_hw(
                self.model.optimization_target,
                LayerPrecisionConfig(**layer_prec_cfg),
            )
            if not is_hw_supported:
                unsupported_layers[lname] = layer_prec_cfg
        if unsupported_layers:
            format_unsupported = yaml.dump(unsupported_layers, allow_unicode=True, default_flow_style=False)
            self._logger.error(
                f"Unsupported layers for the target {self.model.optimization_target}: {format_unsupported}"
            )
            raise AccelerasUnsupportedError(
                "Unsupported layers for the provided optimization target. Review the log to see exact layers and configurations"
            )
