#!/usr/bin/env python

"""Full Fine Tune algorithm implementation."""

import copy
import json
import logging
import os
from collections import namedtuple

import numpy as np
import tensorflow as tf

import hailo_sdk_client.quantization.tools.optimize_kernel_ranges as okr
from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.model_optimization_config.mo_config import ModelOptimizationConfig
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_model import FineTuneConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    DEFAULT_CLIP_FACTOR,
    DEFAULT_CLIP_METHOD,
    DEFAULT_CLIP_PERCENTILE,
    DEFAULT_DATASET_SIZE,
    DEFAULT_EPOCHS,
    DEFAULT_LEARNING_RATE,
    FinetunePolicy,
    LossType,
    MetaArchType,
    Optimizer,
    PrecisionMode,
    PreFTClippingMethod,
    ScheduleType,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasImplementationError
from hailo_model_optimization.acceleras.utils.dataset_util import rebuild_dataset_v2
from hailo_model_optimization.acceleras.utils.hn_npz_utils import NpzWrap, QNpzWrap
from hailo_model_optimization.algorithms.finetune.qft import QftRunner
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.logger.logger import default_logger
from hailo_sdk_common.model_params.model_params import ModelParams
from hailo_sdk_common.serialization.numpy_serialization import hailo_np_savez
from hailo_sdk_common.targets.inference_targets import ParamsKinds, SdkFineTune, SdkFPOptimized

# Layer types grouped under the "quant weights" label (usages are outside this chunk).
# NOTE(review): LayerType.normalization was commented out of this list in the original source;
# confirm whether its exclusion is deliberate.
LAYER_TYPES_QUANT_WEIGHTS = [LayerType.conv, LayerType.dw, LayerType.dense, LayerType.batch_norm]


class FTException(Exception):
    """Error raised by the fine-tune flow, e.g. for invalid configuration or loss specifications."""


# Both fields are optional; the defaults give a single decay epoch with the cubic (1-x**3) decay used
# in the Alpha Blending paper. Using ``defaults=`` (instead of patching ``__new__.__defaults__``)
# also keeps ``AlphaBlendConfig._field_defaults`` accurate.
AlphaBlendConfig = namedtuple("AlphaBlendConfig", ["a_decay_epochs", "a_decay_power"], defaults=(1, 3))
AlphaBlendConfig.__doc__ = "Alpha Blend configuration."
AlphaBlendConfig.a_decay_epochs.__doc__ = "Alpha Blend decay epochs."
AlphaBlendConfig.a_decay_power.__doc__ = "Alpha Blend decay power."


class Schedule:
    """
    Base class for scheduling definitions. Instances of this (and subclasses) are used to
    configure time axis behavior of learning rate, and possibly other parameters, such as the loss
    components importance factors.

    Args:
        base_value: The nominal value of the schedule (e.g. the base learning rate).
        decay_rate: Decay factor; exact usage depends on the subclass.
        decay_images: Decay period in images; exact usage depends on the subclass.
        warmup_images: Number of images during which ``warmup_value`` is used. ``None`` means 0.
        warmup_value: Value used during the warmup. ``None`` means 0.0.
    """

    def __init__(self, base_value, decay_rate=1, decay_images=1e3, warmup_images=None, warmup_value=None):
        self.base_value = base_value
        self.decay_rate = decay_rate
        self.decay_images = decay_images
        self.warmup_images = warmup_images or 0
        # Float default: the subclasses return this from one ``tf.cond`` branch while the other branch
        # yields a float decayed value, and ``tf.cond`` requires both branches to share a dtype.
        self.warmup_value = warmup_value or 0.0

    def get_val_by_step(self, global_step_t, batch_size=None):
        """
        Return the schedule value at the given step.

        Callers pass ``(global_step, batch_size)``; subclasses override this. ``batch_size`` is
        optional here so the base signature stays compatible with older single-argument calls.
        The base implementation always returns 0.
        """
        return 0


class ConstSchedule(Schedule):
    """Time-independent schedule: every step yields ``base_value``."""

    def __init__(self, base_value):
        # Only ``base_value`` is accepted on purpose: passing decay/warmup arguments raises a
        # TypeError instead of being silently ignored.
        super().__init__(base_value=base_value)

    def get_val_by_step(self, global_step_t, batch_size):
        """Return ``base_value``; both arguments exist only for interface compatibility."""
        del global_step_t, batch_size  # unused
        return self.base_value


class ExpSchedule(Schedule):
    """
    Staircase exponential-decay schedule preceded by an optional constant warmup.

    While ``global_step * batch_size`` is below ``warmup_images`` the value is ``warmup_value``.
    After that, ``tf.compat.v1.train.exponential_decay`` is applied to ``base_value``. The value is
    multiplied by ``decay_rate`` once every ``decay_images`` images (staircase). The step counter is
    shifted so the decay starts when the warmup ends.

    The constructor is inherited unchanged from :class:`Schedule`.
    """

    def get_val_by_step(self, global_step_t, batch_size):
        def _warmup_branch():
            return self.warmup_value

        def _decay_branch():
            steps_since_warmup = global_step_t - int(self.warmup_images / batch_size)
            steps_per_decay = int(self.decay_images / batch_size)
            return tf.compat.v1.train.exponential_decay(
                self.base_value,
                global_step=steps_since_warmup,
                decay_steps=steps_per_decay,
                decay_rate=self.decay_rate,
                staircase=True,
            )

        in_warmup = global_step_t * batch_size < self.warmup_images
        return tf.cond(pred=in_warmup, true_fn=_warmup_branch, false_fn=_decay_branch)


class CosineSchedule(Schedule):
    """
    Cosine decay with restarts, preceded by an optional constant warmup.

    While ``global_step * batch_size`` is below ``warmup_images`` the value is ``warmup_value``.
    After that, ``tf.compat.v1.train.cosine_decay_restarts`` is applied to ``base_value``:
    - one cycle lasts ``decay_images`` images;
    - ``t_mul=1``, so every cycle has the same length;
    - ``decay_rate`` is the per-restart multiplier (``m_mul``).

    The constructor is inherited unchanged from :class:`Schedule`.
    """

    def get_val_by_step(self, global_step_t, batch_size):
        def _during_warmup():
            return self.warmup_value

        def _after_warmup():
            steps_since_warmup = global_step_t - int(self.warmup_images / batch_size)
            steps_per_cycle = int(self.decay_images / batch_size)
            return tf.compat.v1.train.cosine_decay_restarts(
                self.base_value,
                global_step=steps_since_warmup,
                first_decay_steps=steps_per_cycle,
                t_mul=1.0,
                m_mul=self.decay_rate,
            )

        still_warming_up = global_step_t * batch_size < self.warmup_images
        return tf.cond(pred=still_warming_up, true_fn=_during_warmup, false_fn=_after_warmup)


class FineTuneConfigurator:
    """
    Use this class to control the settings of the Quantization Aware Fine-Tuning run. Desired
    configs are passed to the constructor.
    """

    def __init__(self, network_args=None, logger=None):
        """
        Initialize the object.

        Args:
            network_args (dict, optional): Args to set. If None, default values are used. The list
                of supported dictionary keys follows, including the default value/behavior in case
                the key is nonexistent (possibly omitted if trivial, e.g., False, None, [], etc.).

                * ``bias_only`` (bool): If True, train biases only, yielding the baseline BFT
                  procedure from https://arxiv.org/abs/1906.03193, giving very similar results to
                  IBC.

                    Default: **False**.
                * ``layers_to_freeze`` (list of str): Don't train kernels and biases for these
                  layers; names to be given in Hailo HN notation, e.g., *conv4*, *dw3*, etc.

                    Default: **[ ]**
                * ``dataset_size`` (int): Training subset size (images per epoch). Should be smaller
                  than total images of the train data source, otherwise an exception will be thrown.

                    Default: :data:`DEFAULT_DATASET_SIZE`
                * ``batch_size`` (int): Training batch size.

                    Default: **None**
                * ``epochs`` (float): How many times to feed the train subset.

                    Default: :data:`DEFAULT_EPOCHS`
                * ``lr_schedule`` (subclass of :class:`Schedule`): If given, used as is for the
                  learning rate schedule, rendering all LR-related args below "don't-care".

                    Default: built according to ``lr_schedule_type``.
                * ``lr_schedule_type`` (:class:`ScheduleType` enum or str): Specifies which of the
                  built-in :class:`Schedule` subclasses is invoked to build a learning rate
                  scheduling object. The parameters for this purpose are taken from the
                  additional args below, generally used in similar ways across schedules. E.g., if
                  :attr:`ScheduleType.EXPONENTIAL` is passed here, an :class:`ExpSchedule`
                  instance will be built for ``lr_schedule``.

                    Default: :attr:`ScheduleType.COSINE_RESTARTS`
                * ``decay_epochs`` (float): Specifies the number of epochs for a "step" of learning
                  rate decay. Translated to units of images and passed to :class:`Schedule`
                  subclasses. Exact usage depends on the scheduler; see the docstring of the
                  relevant class.

                    Default: **1**
                * ``learning_rate`` (float): The base learning rate of any schedule.

                    Default: :data:`DEFAULT_LEARNING_RATE`
                * ``decay_rate`` (float): Factor for the learning rate decay per decay period. Exact
                  usage depends on the scheduler.

                    Default: **0.5**
                * ``warmup_epochs`` (float): Time to spend with a constant (usually small)
                  rate before starting the (usually starting high and decaying) schedule.

                    Default: **1**
                * ``warmup_lr`` (float): LR to use during the warmup.

                    Default: ``learning_rate`` / 4
                * ``optimizer`` (:class:`Optimizer` enum or str): Type of optimizer, e.g., SGD,
                  Momentum.

                    Default: :attr:`Optimizer.adam`
                * ``clip_method`` (:class:`PreFTClippingMethod` enum or str): The method to use
                  for pre-training clipping of 4-bit kernels.

                    Default: :data:`DEFAULT_CLIP_METHOD`
                * ``clip_factor`` (float): Clipping factor to use if clip_method is
                  :attr:`PreFTClippingMethod.SET_FACTOR`.

                    Default: :data:`DEFAULT_CLIP_FACTOR`.
                * ``clip_percentile`` (float): Clipping percentile to use if clip_method is
                  :attr:`PreFTClippingMethod.SET_PERCENTILE`.

                    Default: :data:`DEFAULT_CLIP_PERCENTILE`.
                * ``def_loss_type`` (:class:`LossType` enum or str): The default loss type to use
                  if ``loss_types`` is not given.

                    Default: :attr:`LossType.L2REL`
                * ``loss_layer_names`` (list of str): Names of layers to be used for teacher-student
                  losses. Names to be given in Hailo HN notation, e.g., *conv20*, *fc1*, etc.

                    Default: the output nodes of the net (the part described by the HN)

                * ``loss_types`` (list or tuple of :class:`LossType`, str, or callable, of the same
                  length as *loss_layer_names*): The teacher-student bivariate loss function types
                  to apply on the native and numeric outputs of the respective loss layers specified
                  by the above list. For example, "ce" (standing for "cross-entropy") for the
                  classification head(s). If a callable is passed, it will be used when calculating
                  the teacher-student loss given the two tensors of native and quantized nets at the
                  given layer. This is a handy generic way to extend the functionality to any
                  bivariate loss.

                    Default: the ``def_loss_type`` arg (or its default, "l2rel") for each of the
                    layers.
                * ``loss_factors`` (list of float, of the same length as *loss_layer_names*):
                  Weighting factors of the above when summing the total loss.

                    Default: **1.0** for each one of the layers.
                * ``loss_decay_schedules`` (list of :class:`Schedule`, of the same length as the
                  loss factors): Per-layer schedules whose value multiplies each weighted loss.

                    Default: no scheduling of the loss components.
                * ``post_batch_callback`` (callable): To be called after each batch, to monitor
                  progress or trigger actions. Functionality normally added by subclassing
                  :class:`BasicFinetuneMonitor`, and passing process_batch_results of the instance.

                    Default: ``process_batch_results`` of an internally created
                    :class:`BasicFinetuneMonitor`.
                * ``native_layers`` (list of str): Layers to avoid quantizing, to optimize for
                  inference with those layers in double precision. Handy for a workaround,
                  especially for sensitive layers (e.g. output layers), while the others remain in
                  need of QFT.

                    Default: **[ ]**

                * ``meta_arch`` (:class:`MetaArchType`): Meta architecture of the model.
                  Required for certain loss_types.
                * ``policy`` (:class:`FinetunePolicy`): Fine-tune policy.

                    Default: :attr:`FinetunePolicy.enabled`
            logger (optional): Logger to use. If None, ``default_logger()`` is used.

        Raises:
            FTException: On an invalid ``clip_method``, ``lr_schedule_type``, ``loss_types`` entry
                or ``alpha_blend_cfg``.
        """
        """ The following keys are also supported by network_args, but they are EXPERIMENTAL:

                * ``should_quantize_activations`` (bool) - if False, don't fake-quantize activations
                    (only weights and biases). Useful for certain debugging situations.

                    Default: **True**
                * ``alpha_blend_cfg`` (:class:`AlphaBlendConfig` namedtuple): If
                  used, use AlphaBlending (https://arxiv.org/pdf/1903.01061.pdf), instead of or
                  before the standard STE-based (Straight-through Estimator aka "fake-quant")
                  training. Create that with two (optional) args, #epochs for the alpha-decay (if
                  total epochs is bigger, STE will be used after alpha decay is through), and the
                  power for the coefficient decay (i.e. 1 for linear, 3 for cubic as in the paper).
                * ``relaxed_weight_quant`` (bool): **EXPERIMENTAL**. If True, drop the STE approach
                  for 4-bit layers in favor of a gradual ("relaxed") quantization approach,
                  inspired by https://arxiv.org/pdf/2004.10568.pdf .
                * ``bias_lr_factor`` (float): Default 3; other values are rejected by to_config().
                * ``do_auto_loss_factors`` (bool): Rejected by to_config() when True.
                * ``pre_tuned_params_npz`` (str or dict accepted by load_params):
                  Use these params as a starting point instead of the runner's base full-precision
                  param set. Useful for "additive" fine-tune runs with different configs (otherwise
                  no reason), first and foremost the <layers_to_freeze> config - enabling a "divide
                  and conquer" strategy, in which a few more layers are tuned (or at all quantized,
                  in RQ) on every fine-tune "round". You'll need to do a few more things externally:

                  1. Take the runner's post-tuning 'full-precision copy' after fine-tune into a
                     variable or file, so that to pass as this (i.e. pre_tuned_params_npz) argument
                     in subsequent additive fine-tune runs.
                  2. Push the 'results_by_layer' saved at the first run into the
                     "force_results_by_layer" arg of run_quantization() in subsequent runs, to
                     ensure limvals consistency across all runs.
        """

        self._logger = logger or default_logger()
        network_args = network_args or {}

        self.bias_only = network_args.get("bias_only", False)
        self.layers_to_freeze = network_args.get("layers_to_freeze", [])
        self.dataset_size = network_args.get("dataset_size", DEFAULT_DATASET_SIZE)
        self.batch_size = network_args.get("batch_size")
        self.epochs = network_args.get("epochs", DEFAULT_EPOCHS)

        # Parameters for pre-tune clipping procedure creation
        self.clip_factor = network_args.get("clip_factor", None)
        self.clip_percentile = network_args.get("clip_percentile", None)
        self.clip_method = network_args.get("clip_method", DEFAULT_CLIP_METHOD)
        if isinstance(self.clip_method, str):
            self.clip_method = PreFTClippingMethod(self.clip_method)
        if not isinstance(self.clip_method, PreFTClippingMethod):
            raise FTException("Only CLIPPING_METHODS-enum members or strings are supported as clip_method arg")
        # Fill in the parameter of the selected clipping method when it was not given explicitly
        if self.clip_method == PreFTClippingMethod.SET_FACTOR and self.clip_factor is None:
            self.clip_factor = DEFAULT_CLIP_FACTOR
        elif self.clip_method == PreFTClippingMethod.SET_PERCENTILE and self.clip_percentile is None:
            self.clip_percentile = DEFAULT_CLIP_PERCENTILE

        # Parameters for learning rate schedule creation
        self.learning_rate = network_args.get("learning_rate", DEFAULT_LEARNING_RATE)
        self.decay_epochs = network_args.get("decay_epochs", 1)
        self.decay_rate = network_args.get("decay_rate", 0.5)
        self.warmup_epochs = network_args.get("warmup_epochs", 1)
        self.warmup_lr = network_args.get("warmup_lr", self.learning_rate / 4)
        self.lr_schedule_type = network_args.get("lr_schedule_type", ScheduleType.COSINE_RESTARTS)
        self.lr_schedule = network_args.get("lr_schedule", None)

        # User can pass a complete schedule object,
        #   or it will be created according to lr_schedule_type and the other params above
        #   (None is not a Schedule, so it also takes this branch)
        if not isinstance(self.lr_schedule, Schedule):
            if isinstance(self.lr_schedule_type, str):
                self.lr_schedule_type = ScheduleType(self.lr_schedule_type)
            elif not isinstance(self.lr_schedule_type, ScheduleType):
                raise FTException("Only SCHEDULE_TYPES-enum members or strings are supported as lr_schedule_type arg")
            self._create_learning_rate_schedule()

        self.optimizer = Optimizer(network_args.get("optimizer", Optimizer.adam))
        self.def_loss_type = LossType(network_args.get("def_loss_type", LossType.L2REL))
        self.bias_lr_factor = network_args.get("bias_lr_factor", 3)

        self.loss_layer_names = copy.deepcopy(network_args.get("loss_layer_names", None))
        self.loss_types = network_args.get("loss_types", None)
        if isinstance(self.loss_types, (list, tuple)):
            # Copy to avoid side-effects on the caller's list; tuples are normalized to a list too,
            # so their string entries get converted to LossType like list entries do.
            self.loss_types = list(self.loss_types)
            for lind, loss_type in enumerate(self.loss_types):
                if callable(loss_type):
                    # Loss names are later built from ``<loss type>.value``; give callables one.
                    loss_type.__dict__["value"] = "callable"
                elif isinstance(loss_type, str):
                    try:
                        self.loss_types[lind] = LossType(loss_type.lower())
                    except ValueError as err:
                        raise FTException(f"Unsupported loss function {loss_type}") from err
                elif not isinstance(loss_type, LossType):
                    raise FTException(
                        "Only LOSS_TYPES-enum members or strings are supported as loss_types arg elements",
                    )

        self.do_auto_loss_factors = network_args.get("do_auto_loss_factors", False)
        self.loss_factors = copy.deepcopy(network_args.get("loss_factors", None))
        self.loss_decay_schedules = copy.deepcopy(network_args.get("loss_decay_schedules", None))

        # Support "alpha-blending" - Annealed linear combination of quantized and original kernel
        # https://arxiv.org/pdf/1903.01061.pdf . ARM, 2019
        # "Learning low-precision neural networks without Straight-Through Estimator (STE)"
        self.alpha_blend_cfg = network_args.get("alpha_blend_cfg", None)
        # NOTE(review): compares the string form of the types instead of using isinstance;
        # presumably to tolerate re-imported definitions - confirm before changing.
        if self.alpha_blend_cfg is not None and not str(type(self.alpha_blend_cfg)) == str(AlphaBlendConfig):
            raise FTException("alpha_blend_cfg should be a AlphaBlendConfig named tuple ")

        self.basic_monitor = BasicFinetuneMonitor(logger)
        self.post_batch_callback = network_args.get("post_batch_callback", self.basic_monitor.process_batch_results)

        self.pre_tuned_params_npz = network_args.get("pre_tuned_params_npz")

        self.relaxed_weight_quant = network_args.get("relaxed_weight_quant", False)

        self.should_quantize_activations = network_args.get("should_quantize_activations", True)
        self.native_layers = network_args.get("native_layers", [])

        meta_arch = network_args.get("meta_arch", None)
        self.meta_arch = meta_arch if meta_arch is None else MetaArchType(meta_arch)
        self.policy = network_args.get("policy", FinetunePolicy.enabled)

        self._logger.debug("Constructed a FineTuneConfigurator")

    def to_config(self):
        """
        Export the supported settings as a :class:`FineTuneConfig`.

        Legacy settings with no ``FineTuneConfig`` counterpart must keep their default values.
        They are checked in a fixed order, and the first offending one raises
        :class:`AccelerasImplementationError`. After that, every public attribute (name not
        starting with ``_``) whose name is a field of ``FineTuneConfig`` is passed to its
        constructor.
        """
        # (attribute name, predicate flagging an unsupported value, error message) - evaluated lazily, in order.
        legacy_settings = (
            ("bias_lr_factor", lambda value: value != 3, "Finetune - bias_lr_factor is not supported"),
            ("do_auto_loss_factors", lambda value: value, "Finetune - do_auto_loss_factors is not supported"),
            ("alpha_blend_cfg", lambda value: value is not None, "Finetune - alpha_blend_cfg is not supported"),
            (
                "pre_tuned_params_npz",
                lambda value: value is not None,
                "Finetune - pre_tuned_params_npz is not supported",
            ),
            ("relaxed_weight_quant", lambda value: value, "Finetune - relaxed_weight_quant is not supported"),
            (
                "should_quantize_activations",
                lambda value: not value,
                "Finetune - should_quantize_activations == False is not supported",
            ),
        )
        for attr_name, is_unsupported, message in legacy_settings:
            if is_unsupported(getattr(self, attr_name)):
                raise AccelerasImplementationError(message)

        export_dict = {
            key: value
            for key, value in self.__dict__.items()
            if not key.startswith("_") and key in FineTuneConfig.__fields__
        }
        return FineTuneConfig(**export_dict)

    def _get_auto_loss_factors(self, layers_to_4bit, factor=3):
        """
        Return one loss weight per entry of ``self.loss_layer_names``.

        Layers found in ``layers_to_4bit`` get ``factor``; every other layer gets 1.
        """
        weights = []
        for layer_name in self.loss_layer_names:
            if layer_name in layers_to_4bit:
                weights.append(factor)
            else:
                weights.append(1)
        return weights

    def _create_learning_rate_schedule(self):
        """
        Build ``self.lr_schedule`` from ``self.lr_schedule_type`` and the LR-related attributes.

        Decaying schedules receive their decay period and warmup length converted from epochs to
        images (``epochs * self.dataset_size``). Override in a subclass to make your own schedule
        type usable via the simple API.

        Raises:
            FTException: If ``self.lr_schedule_type`` is not a supported schedule type. (An explicit
                raise rather than ``assert``, so the check survives ``python -O``.)
        """
        if self.lr_schedule_type == ScheduleType.COSINE_RESTARTS:
            schedule_cls = CosineSchedule
        elif self.lr_schedule_type == ScheduleType.EXPONENTIAL:
            schedule_cls = ExpSchedule
        elif self.lr_schedule_type == ScheduleType.CONSTANT:
            self.lr_schedule = ConstSchedule(self.learning_rate)
            return
        else:
            raise FTException(f"Unsupported lr_schedule_type {self.lr_schedule_type}")

        self.lr_schedule = schedule_cls(
            self.learning_rate,
            decay_rate=self.decay_rate,
            decay_images=int(self.decay_epochs * self.dataset_size),
            warmup_images=int(self.warmup_epochs * self.dataset_size),
            warmup_value=self.warmup_lr,
        )

    def _get_losses_and_wtotal(
        self,
        sdk_export_ft,
        sdk_export_native,
        global_step,
        batch_size,
        layernames=None,
        losstypes=None,
        factors=None,
        schedules=None,
    ):
        """
        Get a vector of teacher-student losses and their weighted total.

        Args:
            sdk_export_ft: Fine-tune (numeric) export. Its ``all_layers`` are searched, or its
                ``ft_train_output_tensors`` are used when no layer names are configured.
            sdk_export_native: Native export, read the same way.
            global_step: Forwarded to each loss schedule's ``get_val_by_step``.
            batch_size: Forwarded to each loss schedule's ``get_val_by_step``.
            layernames: Loss layer names; defaults to ``self.loss_layer_names``.
            losstypes: Loss type per layer. Defaults to ``self.loss_types``, then to
                ``self.def_loss_type`` for every layer.
            factors: Weight per layer; defaults to ``self.loss_factors``, then to 1 for every layer.
            schedules: Optional :class:`Schedule` per layer, multiplying its weighted loss;
                defaults to ``self.loss_decay_schedules``.

        Returns:
            tuple: ``(total_loss, losses_d)``. ``losses_d`` maps a key built from the native
            tensor name and the loss type's ``value`` to that layer's unweighted loss.

        Raises:
            FTException: If no tensor matches a requested loss layer, if the loss spec lists
                differ in length, or if ``schedules`` is invalid.
        """
        layernames = layernames or self.loss_layer_names
        losstypes = losstypes or self.loss_types
        factors = factors or self.loss_factors
        schedules = schedules or self.loss_decay_schedules

        def first_tensor_per_layer(all_layers):
            # One tensor per requested layer; several matches are tolerated (the first one wins),
            # but no match at all is a configuration error.
            chosen = []
            for lname in layernames:
                candidates = self._get_tensors_by_layer_name(all_layers, lname)
                if not candidates:
                    raise FTException(f"No tensor found for loss layer {lname}")
                if len(candidates) > 1:
                    self._logger.debug(f"Found multiple tensors for layer {lname}")
                chosen.append(candidates[0])
            return chosen

        if not layernames:
            nat_tensors = sdk_export_native.ft_train_output_tensors
            ft_tensors = sdk_export_ft.ft_train_output_tensors
        else:
            nat_tensors = first_tensor_per_layer(sdk_export_native.all_layers)
            ft_tensors = first_tensor_per_layer(sdk_export_ft.all_layers)

        self._logger.debug("Using the following tensors for the teacher-student losses: " + str(nat_tensors))

        losstypes = losstypes or [self.def_loss_type] * len(nat_tensors)
        factors = factors or [1] * len(nat_tensors)

        if not (len(ft_tensors) == len(losstypes) == len(factors) == len(nat_tensors)):
            raise FTException("Loss spec lists are not of same length")

        losses = [
            self._get_single_ts_loss(nat_tsr, ft_tsr, lt)
            for nat_tsr, ft_tsr, lt in zip(nat_tensors, ft_tensors, losstypes)
        ]

        if not schedules:
            total_loss = sum(factor * loss for factor, loss in zip(factors, losses))
        else:
            if not (len(schedules) == len(factors) and all(isinstance(sch, Schedule) for sch in schedules)):
                raise FTException("The 'schedules' list is invalid")
            total_loss = sum(
                factor * loss * schedule.get_val_by_step(global_step, batch_size)
                for factor, loss, schedule in zip(factors, losses, schedules)
            )

        # Key: native tensor name without its first scope component, plus the loss type value.
        losses_d = {
            "_".join(nat_tsr.name.split("/")[1:] + [lt.value]): loss
            for nat_tsr, lt, loss in zip(nat_tensors, losstypes, losses)
        }

        return total_loss, losses_d

    @staticmethod
    def _get_tensors_by_layer_name(tensors, layer_name):
        wanted_tensors = []
        scope_name, base_name = layer_name.split("/")
        for t in tensors:
            tensor_scope, tensor_layer, _ = t.name.split("/", 2)
            if tensor_scope.startswith(scope_name) and base_name == tensor_layer:
                wanted_tensors.append(t)
        return wanted_tensors

    def _get_single_ts_loss(self, native_res, numeric_res, loss_type):
        """
        Calculate teacher-student loss for a specific layer. Override in subclass to use your own
        loss type.

        Args:
            native_res: Output of the native net at the loss layer.
            numeric_res: Output of the fine-tuned (numeric) net at the same layer.
            loss_type: A :class:`LossType` member, or a callable ``f(native_res, numeric_res)``
                returning the loss.

        Returns:
            The loss tensor for this layer.

        Raises:
            FTException: If ``loss_type`` is neither callable nor a supported :class:`LossType`.
                (Previously an ``assert 0``, which ``python -O`` strips, silently returning None.)
        """
        if callable(loss_type):  # enabling custom user-supplied KD/TS loss functions...
            return loss_type(native_res, numeric_res)

        if loss_type == LossType.CROSS_ENTROPY:
            # Assumes this is the pre-softmax layer: the native output is turned into soft labels,
            # and both sides are scaled by a fixed distillation temperature.
            temperature = 1.5
            ce_loss = tf.nn.softmax_cross_entropy_with_logits(
                labels=tf.nn.softmax(native_res / temperature),
                logits=numeric_res / temperature,
            )
            return tf.reduce_mean(input_tensor=ce_loss)

        if loss_type == LossType.L2:
            # Root of the mean squared difference.
            diffnorm = tf.square(native_res - numeric_res)
            return tf.sqrt(tf.reduce_mean(input_tensor=diffnorm))

        if loss_type == LossType.L2REL:
            # Root of the mean squared difference, relative to the native output's mean square.
            diffnorm = tf.reduce_mean(input_tensor=tf.square(native_res - numeric_res))
            native_norm = tf.reduce_mean(input_tensor=tf.square(native_res))
            return tf.sqrt(diffnorm / native_norm)

        if loss_type == LossType.L2REL_CHW:
            self.out_channel_factors = self._get_out_channel_factors()
            return self.l2rel_chw_ts(native_res, numeric_res)

        if loss_type == LossType.COSINE:
            return self.cosine_ts(native_res, numeric_res)

        raise FTException(f"Unsupported loss type {loss_type}")

    def _get_nodes_to_train(self, graph, param_names, sdk_export_native, sdk_export_ft):
        """
        Get nodes to train using fine tune.

        Args:
            graph: Not used by this method; kept for interface compatibility.
            param_names: Names of the parameters that may be trained; other variables are skipped.
            sdk_export_native: Native export. Its ``biases`` / ``kernels`` provide the base
                variable names.
            sdk_export_ft: Fine-tune export. Its ``biases_delta`` / ``kernels_delta`` are aligned
                element-wise with the native lists.

        Returns:
            dict: Keys are the names of the base variables (weights and biases) to modify. Values
            are the auxiliary delta variables (Tensorflow tensors) to train. Kernels are omitted
            when ``bias_only`` is set, and layers in ``layers_to_freeze`` are skipped.

        Raises:
            AssertionError: If the native and fine-tune exports are not aligned. Raised explicitly
                rather than via ``assert``, so the check also runs under ``python -O``.
        """
        native_biases, delta_biases = sdk_export_native.biases, sdk_export_ft.biases_delta
        native_kernels, delta_kernels = sdk_export_native.kernels, sdk_export_ft.kernels_delta

        if len(native_biases) != len(delta_biases) or len(native_kernels) != len(delta_kernels):
            raise AssertionError("Native and fine tune deltas export lengths are not aligned.")

        def _same_layer(native_var, delta_var):
            # The delta's layer component may sit one scope deeper than the native one. Slicing
            # (rather than indexing [2]) avoids an IndexError for short delta names.
            return native_var.name.split("/")[1] in delta_var.name.split("/")[1:3]

        if not all(map(_same_layer, native_biases, delta_biases)):
            raise AssertionError("Biases in native export and biases_delta in fine tune export are not aligned.")
        if not all(map(_same_layer, native_kernels, delta_kernels)):
            raise AssertionError("Kernels in native export and kernels_delta in fine tune export are not aligned.")

        param_name_to_deltavar_d = {bias.name: delta for bias, delta in zip(native_biases, delta_biases)}
        if not self.bias_only:
            param_name_to_deltavar_d.update(
                (kernel.name, delta) for kernel, delta in zip(native_kernels, delta_kernels)
            )

        allowed_names = set(param_names)  # built once: O(1) membership per variable
        return {
            name: delta
            for name, delta in param_name_to_deltavar_d.items()
            if name in allowed_names and name.split("/")[1] not in self.layers_to_freeze
        }

    def _get_aux_losses(self, sdk_export_ft, global_step_f, batches_per_epoch, hn_layer_names_4bit):
        """
        Create "auxiliary" / regularization losses, e.g. (current use) penalizing distance from
        grid, for the gradual (aka "relaxed") quantization approach as a STE alternative. Currently
        limited to 4-bit layers, only weights (not biases), and the rounding part of quantization
        (clipping is done forcibly).

        Args:
            sdk_export_ft: Fine-tune export. ``ft_kern_frac_part_tensors[lname]`` gives the
                fractional-part tensor of layer ``lname``'s kernel.
            global_step_f: Current training step (presumably a float tensor - TODO confirm).
            batches_per_epoch: Number of batches per epoch.
            hn_layer_names_4bit: Names of the 4-bit layers to regularize.

        Returns:
            tuple: ``(tot_reg_loss, reg_losses)``. ``tot_reg_loss`` is the summed regularization
            loss (0 when nothing is regularized). ``reg_losses`` is a dict of
            ``log10_freg_mean_<layer>`` monitoring values.
        """
        reg_losses = {}
        tot_reg_loss = 0
        if not self.relaxed_weight_quant:
            return tot_reg_loss, reg_losses

        layer_names_4bit = list(hn_layer_names_4bit)
        if not layer_names_4bit:
            return tot_reg_loss, reg_losses

        # The warmup is possibly used for alpha-smoothing of the clipping, we delay the RQ till warmup is over
        steps_after_ini_delay = global_step_f - self.warmup_epochs * batches_per_epoch
        # For the last 2 'cycles' (each of <decay_epochs> duration) we perform 'accelerated rounding':
        #  a very strongly growing off-grid penalty. This makes sure all weights are on-grid before
        #  tuning is through, so biases are tuned to that and overall train-test consistency is kept.
        epochs_till_accelerated_rounding = (
            self.epochs - self.warmup_epochs - 2 * self.decay_epochs
        )  # e.g. 22 - 2 - 2*2 = 16
        steps_post_acc_round = steps_after_ini_delay - batches_per_epoch * epochs_till_accelerated_rounding
        accelerated_rounding_lambda = 0.1 * steps_post_acc_round
        in_accelerated_rounding = steps_post_acc_round > 0

        # --
        # Here we diverge from https://arxiv.org/pdf/2004.10568.pdf. Instead of a heuristically chosen
        #   increase-schedule of the 'rounding-drive-strength' (denoted 'lambda' in that paper), which we
        #   interpret as 'open-loop', we take a 'closed-loop' approach: the regularization loss itself is
        #   driven to follow a pre-determined decay-schedule.
        #
        # The planned decay is currently parabolic, over the duration between warmup and accelerated
        #   rounding. It does not depend on the layer, so it is built once for all layers.
        penalty_decay_ratio = steps_after_ini_delay / (batches_per_epoch * epochs_till_accelerated_rounding)
        planned_freg_decay_lin = 1 - tf.nn.relu(penalty_decay_ratio)
        planned_freg_decay = tf.square(tf.nn.relu(planned_freg_decay_lin))
        # Strength and tolerance constants of the "control loop" are empirically chosen so that f_reg
        # decay closely follows the planned decay, without too much oscillation (even with cosine restarts).
        soft_guidance_tolerance = 0.01
        soft_guidance_strength = 5
        # --

        for lname in layer_names_4bit:
            # Get the 'fractional parts' tensor from export. Note that it has the same shape as the kernel.
            q_mod_cl_frac = sdk_export_ft.ft_kern_frac_part_tensors[lname]

            # Penalty for the fractional part between 0 and 1 - a parabola peaking at 0.5.
            #   This corresponds to beta=2 in https://arxiv.org/pdf/2004.10568.pdf
            f_reg_unscaled_array = 1 - tf.square(2 * q_mod_cl_frac - 1)

            # Aggregation (note the sqrt(size) as a compromise between "sum" and "mean"):
            f_reg_unscaled_mean = tf.reduce_mean(input_tensor=f_reg_unscaled_array)
            f_reg_aggregated = f_reg_unscaled_mean * tf.sqrt(
                tf.cast(tf.size(input=f_reg_unscaled_array), tf.float32),
            )

            # How much the f_reg decay lags behind the planned decay - the 'feedback signal':
            f_reg_diff_from_planned = f_reg_unscaled_mean - planned_freg_decay
            # Soft constraint enforcement with sigmoid: the penalty rises from 0 to a high value when
            # the 'lag' is positive, but not abruptly.
            soft_guidance_lambda = soft_guidance_strength * tf.nn.sigmoid(
                f_reg_diff_from_planned / soft_guidance_tolerance,
            )

            # Switch from 'soft guidance' to 'accelerated rounding'. Both branch functions are
            # invoked by tf.cond right here, so capturing the loop variable is safe.
            freg_lambda = tf.cond(
                pred=in_accelerated_rounding,
                true_fn=lambda: accelerated_rounding_lambda,
                false_fn=lambda: soft_guidance_lambda,
            )
            f_reg = freg_lambda * f_reg_aggregated

            tot_reg_loss += f_reg
            reg_losses.update({f"log10_freg_mean_{lname}": tf.math.log(f_reg_unscaled_mean) / np.log(10)})

        return tot_reg_loss, reg_losses

    def _build_optimizer(self, var_list, global_step, batch_size, total_loss):
        """
        Finish the training graph by putting the configured optimizer on top of ``total_loss``.

        The learning rate comes from ``self.lr_schedule`` evaluated at ``global_step``. Gradients of
        variables whose name contains "bias" are scaled by ``self.bias_lr_factor``; variables that
        receive no gradient are left out of the update.

        Returns:
            (optimizer_op, opt): the apply-gradients op (which also increments ``global_step``) and
            the optimizer instance itself.

        Raises:
            FTException: if ``self.optimizer`` is not one of the supported optimizers.
        """
        lr = self.lr_schedule.get_val_by_step(global_step, batch_size)
        train = tf.compat.v1.train
        builders = (
            # note: large epsilon improves numerical stability and test repeatability..
            (Optimizer.adam, lambda: train.AdamOptimizer(learning_rate=lr, epsilon=0.01)),
            (Optimizer.sgd, lambda: train.GradientDescentOptimizer(learning_rate=lr)),
            (Optimizer.momentum, lambda: train.MomentumOptimizer(learning_rate=lr, momentum=0.9)),
            (Optimizer.rmsprop, lambda: train.RMSPropOptimizer(learning_rate=lr)),
        )
        for kind, build in builders:
            if self.optimizer == kind:
                opt = build()
                break
        else:
            raise FTException("Unsupported optimizer")

        scaled_grads_and_vars = []
        for grad, var in opt.compute_gradients(total_loss, var_list=list(var_list)):
            if grad is None:
                continue  # variable does not influence the loss
            factor = self.bias_lr_factor if "bias" in var.name else 1
            scaled_grads_and_vars.append((grad * factor, var))

        optimizer_op = opt.apply_gradients(scaled_grads_and_vars, global_step=global_step)
        return optimizer_op, opt

    def clip_kernel_pre_ft(self, kernel, lname, bits, num_groups=1, verbose=True):
        """
        Clip ``kernel`` before fine-tuning, using this configurator's clipping settings.

        This is a thin delegation to ``okr.clip_kernel_pre_ft``. It is kept as a method so that
        subclasses of FineTuneConfigurator can override the clipping behavior.
        """
        clip_settings = {
            "clip_method": self.clip_method,
            "clip_factor": self.clip_factor,
            "clip_percentile": self.clip_percentile,
        }
        return okr.clip_kernel_pre_ft(kernel, lname, bits, verbose=verbose, num_groups=num_groups, **clip_settings)

    def _analyze_compression(
        self,
        params_orig,
        hn_model,
        hn_layers_quant_weights,
        hn_layer_names_4bit,
        clip_ranges=None,
    ):
        """Print some info on how memory we saved by reducing some layers to 4-bit."""
        mem8 = 1e-20
        mem4 = 0
        mem_zeros = 0
        weight_numbers = {}
        try:
            for lind, layer in enumerate(hn_layers_quant_weights):
                lname = layer.name
                is_4b = lname in hn_layer_names_4bit
                lparam_key = layer.name
                k = params_orig[lparam_key].kernel
                if len(k.shape) < 4:
                    continue
                mem = k.shape[0] * k.shape[1] * k.shape[2] * k.shape[3]
                weight_numbers[lname] = mem
                mem8 += mem
                mem4 += mem / 2 if is_4b else mem
                clip_range = (
                    clip_ranges[lind]
                    if clip_ranges is not None
                    else self.clip_kernel_pre_ft(k, lname, verbose=False, bits=4)
                )
                zero_weights = np.sum(np.abs(k) < clip_range / 7 / 2) if is_4b else 0
                mem_zeros += zero_weights
            self._logger.debug(f"Reduced weights memory to {100 * mem4 / mem8}%% of original!")
            self._logger.debug(f"Total proportion of zeros: {100 * mem_zeros / mem8}%")
            return weight_numbers
        except Exception:
            pass

    @classmethod
    def from_json(cls, path, logger=None):
        """
        Build a configurator from a JSON file.

        Args:
            path: Path of the JSON file. Its decoded content is passed as ``network_args``.
            logger: Optional logger forwarded to the constructor.

        Returns:
            A new instance of ``cls``.
        """
        # JSON text is UTF-8 by specification (RFC 8259); don't depend on the platform locale.
        with open(path, encoding="utf-8") as f:
            config = json.load(f)

        return cls(network_args=config, logger=logger)

    def info_config(self):
        """Log every public (non-underscore) instance attribute at verbose level."""
        prefix = f"[{type(self).__name__}]"
        public_items = ((name, value) for name, value in self.__dict__.items() if not name.startswith("_"))
        for name, value in public_items:
            self._logger.verbose(f"{prefix} {name}: {value}")

    def cosine_ts(self, native_res, numeric_res):
        """
        Cosine-distance loss between native and numeric results along the last axis.

        Both inputs are L2-normalized on axis -1. ``tf.compat.v1.losses.cosine_distance`` is then
        applied with its default reduction, and the result is averaged.
        """
        normalize = tf.nn.l2_normalize
        distance = tf.compat.v1.losses.cosine_distance(
            normalize(native_res, axis=-1),
            normalize(numeric_res, axis=-1),
            axis=-1,
        )
        return tf.reduce_mean(input_tensor=distance)

    def l2rel_chw_ts(self, native_res, numeric_res, channel_factors=None):
        """
        Relative L2 error between native and numeric results after per-channel weighting.

        Both inputs are multiplied by ``channel_factors``. The result is
        ``sqrt(mean((native - numeric)^2) / mean(native^2))`` over the weighted values.

        Args:
            native_res: Reference (native) result tensor.
            numeric_res: Result tensor compared against the reference.
            channel_factors: Optional per-channel weights; defaults to ``self.out_channel_factors``.
        """
        # `is None` rather than `or`: a truthiness test on a multi-element numpy array (such as
        # the one built by _get_out_channel_factors) or on a graph tensor raises.
        if channel_factors is None:
            channel_factors = self.out_channel_factors
        native_res_chw = native_res * channel_factors
        numeric_res_chw = numeric_res * channel_factors
        diffnorm = tf.reduce_mean(input_tensor=tf.square(native_res_chw - numeric_res_chw))
        native_norm = tf.reduce_mean(input_tensor=tf.square(native_res_chw))
        return tf.sqrt(diffnorm / native_norm)

    def _get_out_channel_factors(self, settings=None):
        """
        Build per-output-channel loss weights that boost "localization" channels.

        Only the YOLO meta-architecture is supported. The result is a 255-long vector of ones,
        viewed as 3 blocks of 85 channels. The first 5 channels of every block (presumably the
        box + objectness outputs) are set to ``settings["localization_boost"]``.

        Args:
            settings: Optional dict. A missing ``"localization_boost"`` key falls back to 10.

        Raises:
            NotImplementedError: for any meta architecture other than YOLO.
        """
        if self.meta_arch != MetaArchType.yolo:
            raise NotImplementedError(f"Can't _get_out_channel_factors with meta_arch={self.meta_arch.value}")

        # .get with a default so a partial settings dict does not raise KeyError.
        loc_boost = (settings or {}).get("localization_boost", 10)
        num_anchor_blocks, channels_per_block, boosted_per_block = 3, 85, 5
        total_channels = num_anchor_blocks * channels_per_block
        yoloboostloc = np.ones(total_channels)
        for block_start in range(0, total_channels, channels_per_block):
            yoloboostloc[block_start : block_start + boosted_per_block] = loc_boost
        return yoloboostloc


class BasicFinetuneMonitor:
    """
    Provides a basic, batch-by-batch monitoring callback. It can be subclassed for additional
    functionality (e.g., for Tensorboard-style real-time plotting).

    For every loss component it keeps the raw per-batch values, an exponential moving average
    (MA), and the history of that MA. It logs the MA values every ``loss_log_period`` batches.
    """

    def __init__(self, logger=None, loss_log_period=32, loss_ma_coeff=0.98):
        """
        Args:
            logger: Logger to report to; defaults to ``default_logger()``.
            loss_log_period (int): Print every so and so batches.
            loss_ma_coeff (float): How much to smooth the "moving average" (weight kept on the
                previous MA value at each update).

        """
        self._batch_ind = 0
        self._ma_losses_d = {}
        self._losses_vals_history_d = {}
        self._ma_losses_vals_history_d = {}
        self._loss_ma_coeff = loss_ma_coeff
        self._loss_log_period = loss_log_period
        self._logger = logger or default_logger()

    def process_batch_results(
        self,
        net_out_native_vals,
        net_out_ft_vals,
        image_info_val,
        losses_vals_d,
        reg_losses_vals_d=None,
        **kwargs,
    ):
        """
        Record one batch's loss values and update their moving averages.

        Args:
            net_out_native_vals: Native network outputs (unused here; available to subclasses).
            net_out_ft_vals: Fine-tuned network outputs (unused here).
            image_info_val: Image info of the batch (unused here).
            losses_vals_d (dict): Loss component name -> value for this batch. Not modified.
            reg_losses_vals_d (dict, optional): Regularization losses (unused here).
            **kwargs: Ignored.
        """
        if self._batch_ind == 0:
            # Copy: the MA dict is updated in place on later batches and must not alias the
            # caller's dict.
            self._ma_losses_d = dict(losses_vals_d)
            self._losses_vals_history_d = {k: [v] for k, v in losses_vals_d.items()}
            self._ma_losses_vals_history_d = {k: [v] for k, v in self._ma_losses_d.items()}
        else:
            coeff = self._loss_ma_coeff
            for k, v in losses_vals_d.items():
                if k in self._ma_losses_d:
                    self._ma_losses_d[k] = coeff * self._ma_losses_d[k] + (1 - coeff) * v
                else:
                    # Component absent on the first batch: seed its MA with the current value
                    # (its history then starts at this batch).
                    self._ma_losses_d[k] = v
                self._losses_vals_history_d.setdefault(k, []).append(v)
                self._ma_losses_vals_history_d.setdefault(k, []).append(self._ma_losses_d[k])
        if self._batch_ind % self._loss_log_period == 0:
            loss_rep_str = ", ".join([f"{k}={v:.3f}" for k, v in self._ma_losses_d.items()])
            self._logger.debug(f"MA loss components at batch {self._batch_ind}: " + loss_rep_str)

        self._batch_ind += 1


def get_loss_factors(ft_cfg, layer_names_to_4bit):
    """
    Compute automatic loss factors for the loss layers when none were configured.

    Returns the factors produced by ``ft_cfg._get_auto_loss_factors`` when auto loss factors are
    enabled and ``ft_cfg.loss_factors`` is unset; otherwise returns ``None``.
    """
    if not (ft_cfg.do_auto_loss_factors and ft_cfg.loss_factors is None):
        return None
    loss_factors = ft_cfg._get_auto_loss_factors(layer_names_to_4bit)
    default_logger().debug(f"Setting loss factors of {loss_factors} to loss layers {ft_cfg.loss_layer_names}")
    return loss_factors


def fine_tune_from_feed(
    runner,
    train_dataset,
    ft_cfg: FineTuneConfig,
    results_by_layer,
    batch_size,
    work_dir=None,
    use_gpu=True,
):
    """
    Improve the quantized model's accuracy by Quantization-Aware Finetuning (QFT). This function
    uses an existing client runner and modifies it. This function is called internally by
    :func:`~hailo_sdk_client.quantization.quantize.run_quantization`, no use case for user-side
    invocation yet.

    Args:
        runner (:class:`~hailo_sdk_client.sdk_backend.sdk_backend.SDKBackend`): SDK backend to
            work with. It should already contain the translated parameters to correct. The corrected
            weights are loaded into this runner.
        train_dataset: Dataset object for finetune training, which is a
           :obj:`tf.data.Dataset` object. Required argument, has to be unbatched dataset
        ft_cfg (:class:`FineTuneConfig`): Fine tune configuration; wrapped internally in a
            FineTuneConfigurator.
        results_by_layer (dict): Use these data for quantization statistics instead of collecting
           new calibration set.
        batch_size: batch size for dataset during train
        work_dir: Optional, path to dump debug files
        use_gpu: boolean, indicates whether to use gpu for the training or not. Should always be True when possible

    Raises:
        FTException: if training fails on the CPU-only dilation-rate limitation (use a GPU instead).

    """
    ft_cfg = FineTuneConfigurator(ft_cfg.dict())
    params_orig = runner.get_params_fp_optimized()

    batched_dataset = train_dataset.batch(batch_size)

    hn_model = runner.model

    hn_layers_4b_weights = [
        layer
        for layer in hn_model.stable_toposort()
        if layer.precision_config.precision_mode.reduce() == PrecisionMode.a8_w4
    ]
    hn_layer_names_4b_weights = [ll.name for ll in hn_layers_4b_weights]

    batches_per_epoch = int(ft_cfg.dataset_size / batch_size)

    sess_config = None if use_gpu else tf.compat.v1.ConfigProto(device_count={"GPU": 0})
    session = tf.compat.v1.Session(config=sess_config, graph=tf.Graph())
    with session.graph.as_default() as graph, session.as_default(), np.printoptions(precision=3):
        global_step = tf.Variable(0, trainable=False)
        global_step_f = tf.cast(global_step, tf.float32)
        dataset_v1 = rebuild_dataset_v2(batched_dataset)
        iterator = tf.compat.v1.data.make_initializable_iterator(dataset_v1)
        preprocessed_data, image_info = iterator.get_next()
        sdk_export_native = runner._get_tf_graph(SdkFPOptimized(), custom_session=session, nodes=preprocessed_data)

        # - enable continued training from a saved state!
        #   note we do it after creation of 'native' sdk_export but BEFORE 'finetune' sdk_export twin,
        #   to have only the latter leg of twin graph get the pre-tuned params..
        if ft_cfg.pre_tuned_params_npz is not None:
            runner.load_params(ft_cfg.pre_tuned_params_npz, params_kind=ParamsKinds.FP_OPTIMIZED)
            params_orig = runner.get_params_fp_optimized()

        target = SdkFineTune()

        target.set_fine_tune_params(
            should_quantize_activations=ft_cfg.should_quantize_activations,
            should_quantize_weights=True,
            should_relax_weights=ft_cfg.relaxed_weight_quant,
        )

        sdk_export_ft = runner._get_tf_graph(
            target,
            twin_mode=True,
            custom_session=sdk_export_native.session,
            native_layers=ft_cfg.native_layers,
            nodes=preprocessed_data,
        )

        param_name_to_deltavar_d = ft_cfg._get_nodes_to_train(
            graph,
            params_orig.keys(),
            sdk_export_native,
            sdk_export_ft,
        )

        # For alpha-blend implementation, we connect to the tensors created on backend,
        #   and feed them with other values. Here we start near 1 (0.99) instead of 0 (no-alpha)
        #   and anneal toward 0 while training (see the annealing in the training loop).
        fd = {}
        if ft_cfg.alpha_blend_cfg is not None:
            alpha_keys = []
            for lname in hn_layer_names_4b_weights:  # TODO what about alpha-blend for 8-bit quant?
                alpha_t = sdk_export_ft.ft_alpha_tensors[lname]
                alpha_keys.append(alpha_t)
                fd[alpha_t] = 0.99

        loss_factors = get_loss_factors(ft_cfg, hn_layer_names_4b_weights)
        total_loss, losses_d = ft_cfg._get_losses_and_wtotal(
            sdk_export_ft,
            sdk_export_native,
            global_step,
            batch_size,
            factors=loss_factors,
        )

        tot_reg_loss, reg_losses = ft_cfg._get_aux_losses(
            sdk_export_ft,
            global_step_f,
            batches_per_epoch,
            hn_layer_names_4b_weights,
        )
        total_loss += tot_reg_loss

        optimizer_op, opt = ft_cfg._build_optimizer(
            list(param_name_to_deltavar_d.values()),
            global_step,
            batch_size,
            total_loss,
        )

        with sdk_export_native.session.as_default() as sess, sdk_export_native.graph.as_default():
            sess.run([v.initializer for v in sdk_export_ft.kernels_delta + sdk_export_ft.biases_delta])
            sess.run([v.initializer for v in opt.variables()])
            sess.run([tf.compat.v1.local_variables_initializer(), global_step.initializer])
            default_logger().debug("Starting Fine-tune run...")

    # Pre-bind the progress counters so the KeyboardInterrupt handler can report them even when
    # the interrupt arrives before the first batch completes (it would otherwise raise NameError).
    epoch = 0
    step = 0
    try:
        for epoch in range(int(ft_cfg.epochs)):
            sess.run(iterator.initializer)
            for batch_ind in range(batches_per_epoch):
                full_batch_ind = batches_per_epoch * epoch + batch_ind

                # Anneal the alpha values (if enabled)
                if ft_cfg.alpha_blend_cfg is not None:
                    adecay_part = full_batch_ind / (ft_cfg.alpha_blend_cfg.a_decay_epochs * batches_per_epoch)
                    alpha_val = max(0, (1 - adecay_part)) ** ft_cfg.alpha_blend_cfg.a_decay_power
                    for k in alpha_keys:
                        fd[k] = alpha_val

                # (!) Here the actual forward-backward graph eval happens..
                with sdk_export_native.session.as_default() as sess, sdk_export_native.graph.as_default():
                    (
                        net_out_native_vals,
                        net_out_ft_vals,
                        image_info_val,
                        total_loss_val,
                        _opto,
                        step,
                        losses_vals_d,
                        reg_losses_vals_d,
                    ) = sess.run(
                        [
                            sdk_export_native.output_tensors,
                            sdk_export_ft.output_tensors,
                            image_info,
                            total_loss,
                            optimizer_op,
                            global_step,
                            losses_d,
                            reg_losses,
                        ],
                        feed_dict=fd,
                    )

                ft_cfg.post_batch_callback(
                    net_out_native_vals,
                    net_out_ft_vals,
                    image_info_val,
                    losses_vals_d,
                    reg_losses_vals_d=reg_losses_vals_d,
                )

            default_logger().info(f"Finished {epoch + 1} epochs")

    except KeyboardInterrupt:
        default_logger().debug(
            f"Forced to stop after {epoch} epochs and {step} steps, will try to proceed with what we got thus far",
        )

    except tf.errors.OutOfRangeError:
        default_logger().warning(
            "Out of data - check your dataset_size argument vs. your data source (e.g tfrecord) size",
        )
    except tf.errors.InvalidArgumentError as err:
        if "Current libxsmm and customized CPU implementations do not yet support dilation rates larger than 1" in str(
            err,
        ):
            raise FTException("dilation rates larger than 1 is not support on CPU, please use GPU instead") from err
        raise

    finally:
        # ..we're thru with training and into the crucial step of train-to-inference translations.
        #  We need to either exactly replicate the backend transforms (e.g. kernel += kernel_delta),
        #  or let backend help us by defining the "final kernel as used" which we'll now get the value of..
        final_kernel_tensors = sdk_export_ft.ft_final_kernel_tensors
        with sdk_export_native.session.as_default() as sess, sdk_export_native.graph.as_default():
            final_kernel_values = sess.run(final_kernel_tensors, feed_dict=fd)
            varname_to_delta_value = sess.run(param_name_to_deltavar_d, feed_dict=fd)
    if work_dir is not None:
        hailo_np_savez(os.path.join(work_dir, "ma_loss_history.npz"), **ft_cfg.basic_monitor._ma_losses_vals_history_d)
        hailo_np_savez(os.path.join(work_dir, "loss_history.npz"), **ft_cfg.basic_monitor._losses_vals_history_d)

    params_modified = copy.deepcopy(params_orig)

    for param, delta in varname_to_delta_value.items():
        lname = "/".join(param.split("/")[:2])
        if "kernel" in param and lname in final_kernel_tensors:
            # Use the backend's "final kernel as used" rather than re-deriving kernel + delta.
            params_modified[param] = np.reshape(final_kernel_values[lname], params_modified[param].shape)
        else:
            params_modified[param] += delta.astype(params_modified[param].dtype)

    runner.load_params(params_modified, ParamsKinds.FP_OPTIMIZED)

    runner.translate_params_by_inference_results(results_by_layer)

    # (!) Note the runner's state at this point, to inform any subsequent external manipulations:
    #
    # (A) the "full-precision param set" has got the tuned params (by virtue of load_params(params_modified) above)
    #   which are actually grid-pinned because of "final_kernel_tensors" mechanism
    #
    # (B) the quantized params set has got a fully Hailo-quantized param set based on tuning results,
    #      and having reused the 'results_by_layer' to ensure full train-test consistency of limvals etc.

    default_logger().debug("QFT Done")


def prep_imagenet(images, image_info, label_offset=0):
    """
    Pair images with one-hot ImageNet labels.

    ``label_offset`` is added to ``image_info["label_index"]``, and the one-hot depth is widened
    from 1000 by the same amount.
    """
    # TODO: The train label is off by one, it needs to be fixed
    shifted_labels = image_info["label_index"] + label_offset
    num_classes = 1000 + label_offset
    return images, tf.one_hot(shifted_labels, depth=num_classes)


def prep_label_less(images, image_info):
    """Discard the image metadata and pair each image with a constant dummy label of ``0``."""
    dummy_label = 0
    return images, dummy_label


class DummyScope:
    """
    No-op context manager, used in place of a distribution-strategy scope when not distributed.

    Entering yields ``None``; exiting never suppresses exceptions.
    """

    def __enter__(self):
        return None

    def __exit__(self, exc_type, exc_val, exc_tb):
        return None


def _verify_ft_cfg_acceleras(ft_cfg: FineTuneConfig):
    """
    Reject finetune options that the acceleras QFT path does not implement.

    Checks run in a fixed order, and the first unsupported option found is reported.

    Raises:
        AccelerasImplementationError: for bias-only training, frozen layers, a schedule other than
            cosine-restarts, an optimizer other than Adam, or native layers.
    """
    unsupported = None
    if ft_cfg.bias_only:
        unsupported = "Finetune - bias only is not supported"
    elif ft_cfg.layers_to_freeze:
        unsupported = "Finetune - layers_to_freeze are not supported"
    elif ft_cfg.lr_schedule_type != ScheduleType.COSINE_RESTARTS:
        unsupported = f"Finetune - ScheduleType {ft_cfg.lr_schedule_type} is not supported"
    elif ft_cfg.optimizer != Optimizer.adam:
        unsupported = f"Finetune - Optimizer {ft_cfg.optimizer} is not supported"
    elif ft_cfg.native_layers:
        unsupported = "Finetune - native_layers are not supported"

    if unsupported is not None:
        raise AccelerasImplementationError(unsupported)


def acceleras_fine_tune(
    backend,
    calib_dataset,
    ft_cfg: FineTuneConfig,
    model_config: ModelOptimizationConfig,
    distributed=False,
):
    """
    Run Quantization-Aware Fine-Tuning (QFT) with the acceleras backend and load the results into ``backend``.

    The function translates the legacy ``ft_cfg`` flow to acceleras, using settings that are
    passed separately (e.g. train_batch_size, which was implicit in legacy code). It proceeds as
    follows:

    - Build two HailoModel instances from the backend's HN: ``model_fp`` and ``model_quant``.
      In ``model_quant``, only BaseHailoConv layers are set quantized.
    - Run QftRunner on the two models.
    - Write the tuned kernels/biases into copies of both the TRANSLATED and FP_OPTIMIZED param
      sets, and load them back into ``backend``.

    Args:
        backend: SDK backend providing the model, the FP-optimized and translated params, and the
            logger. Its params are overwritten with the tuned values.
        calib_dataset: Dataset used for QFT. It is batched to 1 below for the eagerization
            sample, so it is presumably unbatched. TODO confirm.
        ft_cfg (FineTuneConfig): Legacy finetune config. Here it is only validated by
            ``_verify_ft_cfg_acceleras``.
        model_config (ModelOptimizationConfig): Imported into both models and passed to QftRunner.
        distributed (bool): If True, build and run everything under a
            ``tf.distribute.MirroredStrategy`` scope.

    Raises:
        AccelerasImplementationError: if ``ft_cfg`` requests an option acceleras does not support.
    """
    # TODO Find a better way to pass Batch size 3.14 https://hailotech.atlassian.net/browse/SDK-23261

    model_name = backend.model_name
    hn = backend.model.to_hn(model_name, json_dump=False)
    npz = backend.get_params_fp_optimized().params
    qnpz = backend.get_params_translated().params

    _verify_ft_cfg_acceleras(ft_cfg)

    # The scope factory is called below; DummyScope is a no-op stand-in when not distributed.
    if distributed:
        strategy = tf.distribute.MirroredStrategy()
        strategy_scope = strategy.scope
    else:
        strategy_scope = DummyScope

    with strategy_scope():
        # Model with the FP-optimized weights (not set quantized here); handed to QftRunner
        # alongside model_quant, presumably as the full-precision reference.
        model_fp = HailoModel(hn)
        model_fp.import_config(model_config)
        model_fp.import_weights(npz)
        model_fp.import_hw_params_from_qnpz(qnpz, force_legacy=True)

        from hailo_model_optimization.acceleras.hailo_layers.base_hailo_conv import BaseHailoConv
        from hailo_model_optimization.acceleras.hailo_layers.hailo_deconv import HailoDeconv

        # Model to be tuned: only conv-family (BaseHailoConv) layers are set quantized.
        model_quant = HailoModel(hn)
        model_quant.import_config(model_config)
        quantized_layers = [layer.name for layer in model_quant.layers.values() if isinstance(layer, BaseHailoConv)]

        model_quant.import_weights(npz)
        model_quant.import_hw_params_from_qnpz(qnpz, force_legacy=True)
        model_quant.set_quantized(quantized_layers)

        # Non-conv layers (and deconv layers) run with all atomic ops fully native; an error is
        # logged if any of them still exposes trainable variables.
        for layer in model_quant.layers.values():
            if not isinstance(layer, BaseHailoConv) or isinstance(layer, HailoDeconv):
                for op in layer.atomic_ops:
                    op.fully_native = True
                if len(layer.trainable_variables) > 0:
                    backend._logger.error(f"{layer.name} has trainable variables: {layer.trainable_variables}")
        # var_freeze_cond selects variables whose name mentions normalization/avgpool
        # (presumably to be kept frozen during QFT, per the argument name).
        qftr = QftRunner(
            model_quant,
            model_config,
            logging.INFO,
            model_fp,
            calib_dataset,
            unbatched_eval_dataset=None,
            var_freeze_cond=lambda s: "normalization" in s or "avgpool" in s,
        )  # TODO try remove this

        qftr.run()

        # TODO - consider removing this,
        #  there's now a proper eagerization in run_qft itself (by calling infer encodings)
        # Run a single-sample forward pass eagerly (numpy input); the sample may be a dict
        # for multi-input models.
        imagenet_train_ds = calib_dataset.map(prep_label_less)
        sample = next(iter(imagenet_train_ds.batch(1)))[0]
        if isinstance(sample, dict):
            sample = {key: val.numpy() for key, val in sample.items()}
            _ = model_quant(sample)
        else:
            _ = model_quant(sample.numpy())  # little hack to eagerize after operations in graph form..

        # Export: write the tuned kernels/biases into copies of both param sets and load them back.
        model_quant.set_quantized()
        ft_qnpz = QNpzWrap(dict(qnpz).copy())
        ft_qnpz.write_all_kernels_biases(model_quant)
        ft_qnp = NpzWrap(dict(npz).copy())
        ft_qnp.write_all_kernels_biases(model_quant)
        backend.load_params(ModelParams(ft_qnpz.params), params_kind=ParamsKinds.TRANSLATED)
        backend.load_params(ModelParams(ft_qnp.params), params_kind=ParamsKinds.FP_OPTIMIZED)
