import os
from datetime import datetime
from itertools import combinations

import networkx as nx
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.activation_op import ActivationOp
from hailo_model_optimization.acceleras.atomic_ops.bias_add_op import AddBiasOp
from hailo_model_optimization.acceleras.atomic_ops.conv_stripped_op import ConvStrippedOp
from hailo_model_optimization.acceleras.atomic_ops.element_wise_add_op import ElementwiseAddDirectOp, ElementwiseAddOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_conv import BaseHailoConv
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_none_nn_core_layer import BaseHailoNonNNCoreLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_const import HailoConst
from hailo_model_optimization.acceleras.lossy_elements.quant_element import QuantTrainMode
from hailo_model_optimization.acceleras.model import distiller
from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    DEFAULT_BATCH_SIZE,
    DEFAULT_LEARNING_RATE,
    FinetunePolicy,
    PrecisionMode,
    QFTWriterMode,
    ThreeWayPolicy,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import (
    AccelerasResourceError,
    AccelerasUnsupportedError,
)
from hailo_model_optimization.acceleras.utils.dataset_util import verify_dataset_size
from hailo_model_optimization.algorithms.bias_correction.bias_correction_v2 import BiasCorrection
from hailo_model_optimization.algorithms.finetune.build_scale_variables import BuildScaleVariables
from hailo_model_optimization.algorithms.finetune.deep_distill_loss import DeepDistillLoss
from hailo_model_optimization.algorithms.finetune.optimizers import MultiOptimizer, get_optimizer_gen
from hailo_model_optimization.algorithms.finetune.schedulers import (
    QFTScheduler,
    get_train_scheduler,
    get_warmup_scheduler,
)
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm


class FullHistory(tf.keras.callbacks.Callback):
    """Keras callback that records the ``logs`` object reported at the end of every training batch.

    The collected entries are stored, in order, in ``per_batch_logs``; the list is reset at the
    start of each ``fit`` call.
    """

    def __init__(self):
        super().__init__()
        # Defined here too, so the attribute exists even if training never starts.
        self.per_batch_logs = []

    def on_train_begin(self, logs=None):
        self.per_batch_logs = []

    def on_batch_end(self, batch, logs=None):
        self.per_batch_logs.append(logs)


class TensorboardWriter(tf.keras.callbacks.Callback):
    """Keras callback that attaches a TensorBoard summary writer to a distiller when training starts.

    The writer is created under ``<workdir>/train_log/<name>/<YYYYmmdd_HHMM>`` and stored on
    ``distiller.writer``.

    Args:
        distiller: object that receives the created writer as its ``writer`` attribute.
        name: sub-directory name for this run's logs.
        workdir: root directory for the logs; defaults to the current directory when None.
    """

    def __init__(self, distiller, name, workdir=None):
        super().__init__()
        self._distiller = distiller
        self._workdir = workdir if workdir is not None else "."
        self.name = name

    def on_train_begin(self, logs=None):
        timefmt = datetime.now().strftime("%Y%m%d_%H%M")
        log_dir = os.path.join(self._workdir, "train_log", self.name, timefmt)
        self._distiller.writer = tf.summary.create_file_writer(log_dir)


def add_label_to_data(images, image_info):
    """Pair the input images with a constant dummy label ``0``.

    This is the default dataset ``map`` function used when no custom processing function is
    given. ``image_info`` is accepted for signature compatibility and ignored.
    """
    # TODO: The train label is off by one, it needs to be fixed
    dummy_label = 0
    return (images, dummy_label)


def _train_encoding_freeze(quant_model, name):
    if "source_" in name:
        real_name = next(quant_model.model_encoding.flow.successors(name.split(":")[0]))
        type_ = real_name.split("/")[-1].split(":")[0]
        return type_ in {
            "kernel_zero_point",
            "input_zero_point",
            "input_scale",
            "desired_factors",
        }
        # 'output_factor_by_group', 'output_scale',   'mac_shift', 'kernel_scale', 'output_zero_point',
    return False


class QftRunner(OptimizationAlgorithm):
    """
    Implementing Quantization-aware Fine Tuning (QFT)
    (TODO link to Application Note in Developer Zone)

    The quantized "student" model is trained to reproduce the outputs of the full-precision
    "teacher" model at the configured loss layers (distillation). A supervised loss term can
    optionally be added.

    Args:
        model: the quantized (student) model that is fine-tuned.
        model_config: model configuration; the ``finetune`` section holds this algorithm's parameters.
        logger_level: logging level forwarded to the base algorithm.
        model_native: the full-precision (teacher) model used as the distillation reference.
        unbatched_train_dataset: Source (tf.dataset) of images (at least ``dataset_size`` of them) used for training.
        unbatched_eval_dataset: [optional] Source (tf.dataset) of images used for monitoring metrics on a
            held-out dataset.
        process_dataset: [optional] function mapped over the training dataset; defaults to ``add_label_to_data``.
        process_eval: [optional] function mapped over the eval dataset; defaults to ``add_label_to_data``.
        work_dir: [optional] root directory for TensorBoard logs (current directory when None).
        var_freeze_cond: [optional] predicate on a variable name. It is OR-ed with the built-in ``bias_only``
            and ``train_encoding`` rules and passed to the distiller as ``var_freeze_cond``.
        supervised_loss_fn: Uni-variate function applied on student's prediction, to produce a non-distillation loss
            (using ground-truth labels). Defaults to a constant 0 loss.

        ==== Config params Advanced ====

        distillation_loss : Object specifying the final segment of the training graph -
                            computing "reconstruction loss" from tensors of teacher and student.

        base_learning_rate : The base learning rate used for the schedule calculation
                             (e.g., starting point for the decay).
                             Main parameter to experiment with; to ensure convergence, start from small values
                             for architectures substantially different from well-performing zoo examples.

        batch_size : Number of images used together in each training step;
                     driven by GPU memory constraints (may need to be reduced to meet them)
                     but also by the algorithmic impact opposite to that of learning_rate.

        dataset_size : Number of distinct images required for training (repeated in each epoch).
                       Exception is thrown if the supplied train_dataset falls short of that amount.

        epochs : Total number of training epochs, warmup epochs included.

        warmup_epochs : duration of the warmup phase, in epochs,
                        before starting the main schedule phase (e.g. cosine-restarts).
        warmup_lr : constant learning rate to be applied during the warmup phase.
                    Defaults to 1/4 of the base learning rate.

        decay_epochs : duration of the "decay period" in epochs.
                       In the default case of cosine restarts, the rate decays to zero (with cosine functional form)
                       across this period, and is then restarted for the next period.
        decay_rate : decay factor of the learning rate at the beginning of each "decay period", from one to the next.
                     In the default case of cosine restarts, the ratio between the value the learning rate
                     is restarted to and the rate at the beginning of the previous period.

        val_images : Number of held-out/validation images for evaluation between epochs.
        val_batch_size : batch size for the inter-epoch validation.

        bias_only : when set, every variable whose name contains "kernel" is frozen.
        optimizer : set to 'sgd' to use simple Momentum, otherwise Adam will be used.

        supervised_proportion : factor for weighting the supervised loss against the distillation.

    """

    NAME = "Quantization-Aware Fine-Tuning"
    DEFAULT_LEARNING_RATE = DEFAULT_LEARNING_RATE

    def __init__(
        self,
        model,
        model_config,
        logger_level,
        model_native,
        unbatched_train_dataset,
        unbatched_eval_dataset=None,
        process_dataset=None,
        process_eval=None,
        work_dir=None,
        var_freeze_cond=None,
        supervised_loss_fn=None,
        **kwargs,
    ):
        super().__init__(model, model_config, name=self.NAME, logger_level=logger_level, **kwargs)
        self._work_dir = work_dir
        self._model_native = model_native
        self.unbatched_train_dataset = unbatched_train_dataset
        self.unbatched_eval_dataset = unbatched_eval_dataset

        ft_cfg = self.get_algo_config()
        self.ft_cfg = ft_cfg
        # Scale training and encoding training are mutually exclusive; encoding training wins.
        self.ft_cfg.train_scales = self.ft_cfg.train_scales and not self.ft_cfg.train_encoding

        # init var_freeze_cond
        var_freeze_cond = var_freeze_cond or (lambda s: False)

        def var_freeze_cond_function(s):
            return (
                var_freeze_cond(s)
                or (ft_cfg.bias_only and "kernel" in s)
                or (ft_cfg.train_encoding and _train_encoding_freeze(self._model, s))
            )

        self.var_freeze_cond = var_freeze_cond_function

        self.supervised_loss_fn = supervised_loss_fn or (lambda y, y_pred: 0)
        self._metrics = None
        self.process_dataset = process_dataset if process_dataset is not None else add_label_to_data
        self.process_eval = process_eval if process_eval is not None else add_label_to_data
        self._layers_info = {f"{layer}/successfully_run": False for layer in self._model.flow.toposort()}
        # Per-run state captured in _prep_run and restored at the end of run_qft.
        self._train_modes = None
        self._original_trainable = {}

    def export_statistics(self):
        """Return the per-layer ``<layer>/successfully_run`` flags."""
        return self._layers_info

    @property
    def steps_per_epoch(self):
        """Number of training batches per epoch (``dataset_size // batch_size``)."""
        return self.ft_cfg.dataset_size // self.ft_cfg.batch_size

    @property
    def validation_steps(self):
        """Number of validation batches (``val_images // val_batch_size``)."""
        return self.ft_cfg.val_images // self.ft_cfg.val_batch_size

    @property
    def schedule_epochs(self):
        """Number of epochs left for the main schedule once the warmup is done."""
        return self.ft_cfg.epochs - self.ft_cfg.warmup_epochs

    def get_hyperparams(self):
        """Actual params used for the run (whether by default or by config)"""
        params = dict(self.ft_cfg.__dict__)
        params["train_batch_size"] = self.ft_cfg.batch_size
        params["warmup_rate"] = self.ft_cfg.warmup_lr
        params["loss_types"] = [lt.value for lt in params["loss_types"]]
        params["optimizer"] = params["optimizer"].value
        return params

    def get_default_loss_layers(self, add_lca_default):
        """
        Search the last differentiable layers in the graph.

        The model is scanned in topological order. Every non-differentiable layer and all of its
        descendants are excluded. The loss layers are then the remaining layers that are either
        sinks of the remaining sub-graph or direct predecessors of a non-differentiable layer.

        Args:
            add_lca_default: when True, also add the lowest common ancestors of every pair of loss layers.

        Return:
            list with the names of the last differentiable layers in the model

        """
        flow = self._model.flow
        blocking_nodes = set()
        # A set: it is only used for membership tests below.
        blockers_preds = set()

        for layer in flow.toposort():
            acceleras_layer = self._model.layers[layer]
            if not acceleras_layer.is_differentiable():
                blocking_nodes.add(layer)
                blocking_nodes.update(flow.descendants(layer))
                blockers_preds.update(flow.predecessors_sorted(layer))
        trained_nodes = [x for x in flow.nodes if x not in blocking_nodes]
        trained_graph = flow.subgraph(trained_nodes)
        loss_layers = [x for x in trained_graph.nodes() if trained_graph.out_degree(x) == 0 or x in blockers_preds]
        for loss_lname in loss_layers:
            loss_layer = self._model.layers[loss_lname]
            if (loss_layer.num_outputs != 1) and (trained_graph.out_degree(loss_lname) != 0):
                self._logger.warning(
                    f"The layer {loss_lname} was selected as loss layer, alongside with some of its descendants. "
                    f"Since the layer has multiple outputs, the loss term might be incorrect. "
                    f"Please consider setting up loss layers manually.",
                )

        if add_lca_default:  # Dead code, add as default after DR
            additional_loss_layers = self._find_layers_lowest_common_anscestors(self._model.flow, loss_layers)
            loss_layers.extend(additional_loss_layers)

        return loss_layers

    def _freeze_16bit_layers(self):
        """Freeze non-conv 16-bit layers; reject ``train_scales`` when a 16-bit conv layer exists.

        Raises:
            AccelerasUnsupportedError: a 16-bit ``BaseHailoConv`` layer is present while ``train_scales`` is set.
        """
        sixteen_bit_modes = {PrecisionMode.a16_w16_a16, PrecisionMode.a16_w16_a8}
        for layer in self._model.layers.values():
            if isinstance(layer, (BaseHailoNonNNCoreLayer, HailoConst)):
                continue
            if layer.get_precision_mode() not in sixteen_bit_modes:
                continue
            if not isinstance(layer, BaseHailoConv):
                # freeze weights of non-conv 16-bit layers
                layer.trainable = False
            elif self.ft_cfg.train_scales:
                # train_scales is not allowed if there is any 16-bit quantized conv layer
                raise AccelerasUnsupportedError("16 bit finetune is currently not supported with train_scales")

    @staticmethod
    def _find_layers_lowest_common_anscestors(model_flow, loss_layers):
        """
        Search for the lowest common ancestor of each pair of loss layers.
        Currently used for additional loss layers in case of multi output model.

        Returns:
            list of the distinct LCA nodes (empty when fewer than two loss layers are given).
        """
        if len(loss_layers) < 2:
            # No pairs to examine.
            return []
        pairs = list(combinations(loss_layers, 2))
        lca_generator = nx.all_pairs_lowest_common_ancestor(model_flow, pairs=pairs)
        lca = {v for _, v in lca_generator}
        return list(lca)

    def _setup(self):
        """Resolve the loss layers and build the loss, optimizer/scheduler factories and datasets."""
        self._logger.debug("Starting FineTune")
        self._freeze_16bit_layers()
        ft_cfg = self.get_algo_config()
        ft_cfg.loss_layer_names = ft_cfg.loss_layer_names or self.get_default_loss_layers(ft_cfg.add_lca_default)

        self.distillation_loss = DeepDistillLoss(
            ft_cfg.loss_layer_names,
            ft_cfg.loss_types,
            ft_cfg.loss_factors,
            default_loss=ft_cfg.def_loss_type,
            meta_arch=ft_cfg.meta_arch,
        )

        # init optimizer
        self.optimizer_gen = get_optimizer_gen(ft_cfg.optimizer)
        self.train_scheduler_gen = get_train_scheduler(
            lr_schedule_type=ft_cfg.lr_schedule_type,
            decay_rate=ft_cfg.decay_rate,
            steps_per_epoch=self.steps_per_epoch,
            decay_epochs=ft_cfg.decay_epochs,
            dataset_size=ft_cfg.dataset_size,
            t_mul=ft_cfg.t_mul,
        )
        self.warmup_scheduler_gen = get_warmup_scheduler(
            ft_cfg.warmup_strategy, ft_cfg.warmup_epochs, self.steps_per_epoch
        )

        train_dataset = self.unbatched_train_dataset.map(self.process_dataset)
        verify_dataset_size(train_dataset, ft_cfg.dataset_size, warning_if_larger=True, logger=self._logger)
        self.train_dataset = (
            train_dataset.take(ft_cfg.dataset_size)
            .shuffle(ft_cfg.shuffle_buffer_size)
            .repeat(ft_cfg.epochs)
            .batch(ft_cfg.batch_size)
        )

        self.eval_dataset = (
            None
            if (self.unbatched_eval_dataset is None or self.process_eval is None)
            else (
                self.unbatched_eval_dataset.map(self.process_eval)
                .take(ft_cfg.val_images)
                .repeat(ft_cfg.epochs)
                .batch(ft_cfg.val_batch_size)
            )
        )
        self._logger.info(f"Using dataset with {ft_cfg.dataset_size} entries for finetune")

        # set model lossy
        self._model.set_lossy(native_act=ft_cfg.native_activations == ThreeWayPolicy.enabled)

    def should_skip_algo(self):
        """Skip when the finetune policy is disabled."""
        ft_cfg = self.get_algo_config()
        return ft_cfg.policy == FinetunePolicy.disabled

    def get_algo_config(self):
        """Return the ``finetune`` section of the model config."""
        return self._model_config.finetune

    def _compile(self, qft_distiller: distiller.Distiller, metrics):
        """Build the (multi-)optimizer and compile the distiller with the configured losses."""
        ft_cfg = self.get_algo_config()

        optimizers_and_layers = []

        if ft_cfg.train_weights:
            weights_lr = QFTScheduler(
                ft_cfg.warmup_epochs,
                self.steps_per_epoch,
                self.train_scheduler_gen(ft_cfg.learning_rate),
                self.warmup_scheduler_gen(ft_cfg.warmup_lr, ft_cfg.learning_rate),
            )
            weights_optimizer = self.optimizer_gen(weights_lr)
            optimizers_and_layers.append((weights_optimizer, list(self._model.layers.values())))

        if ft_cfg.train_encoding:
            # When weights are trained too, the encoding uses a 10x smaller learning rate.
            learning_rate = ft_cfg.learning_rate / 10 if ft_cfg.train_weights else ft_cfg.learning_rate
            warmup_rate = ft_cfg.warmup_lr / 10 if ft_cfg.train_weights else ft_cfg.warmup_lr
            encoding_lr = QFTScheduler(
                warmup_epochs=ft_cfg.warmup_epochs,
                steps_per_epoch=self.steps_per_epoch,
                train_scheduler=self.train_scheduler_gen(learning_rate),
                warmup_scheduler=self.warmup_scheduler_gen(warmup_rate, learning_rate),
            )
            encoding_optimizer = self.optimizer_gen(encoding_lr)
            optimizers_and_layers.append((encoding_optimizer, self._model.model_encoding))

        optimizer = MultiOptimizer(optimizers_and_layers)
        qft_distiller.compile(
            optimizer=optimizer,
            metrics=metrics,
            supervised_loss_fn=self.supervised_loss_fn,
            distillation_loss_fn=self.distillation_loss.compare,
            loss_layers=ft_cfg.loss_layer_names,
            supervised_proportion=ft_cfg.supervised_proportion,
            var_freeze_cond=self.var_freeze_cond,
            stop_graident_at_loss_layers=ft_cfg.stop_gradient_at_loss,
            wraparound_factor=ft_cfg.wraparound_factor,
            train_weights=ft_cfg.train_weights,
            writer_mode=ft_cfg.log_debug_data,
        )

    def _run_int(self):
        self.run_qft(self._model_native, self._model, metrics=self.metrics)

    def _prep_run(self, hmodel_fp, hmodel_quant):
        """Prepare the student model for training and wrap teacher/student in a ``Distiller``.

        Records in ``self._train_modes`` and ``self._original_trainable`` the state that
        ``run_qft`` restores after training.
        """
        if not self.ft_cfg.train_scales:
            hmodel_quant.disable_internal_encoding(force_endocing_layers=self.ft_cfg.loss_layer_names)
        else:
            hmodel_quant.enforce_encoding(training=True)
            BuildScaleVariables(hmodel_quant, self._model_config, self._logger_level, logger=self._logger).run()

        if self.ft_cfg.online_quantization_bias_fix:
            self._logger.info("Applying online bias correction")
            for layer in hmodel_quant.layers.values():
                if BiasCorrection.is_correctable(layer, hmodel_fp):
                    for conv_op in self._get_conv_ops(layer):
                        conv_op.launch_online_bias_correction()

        if self.ft_cfg.force_pruning:
            for layer in hmodel_quant.layers.values():
                if isinstance(layer, BaseHailoNonNNCoreLayer):
                    continue
                layer.enable_force_pruning()
        for layer in hmodel_quant.layers.values():
            if isinstance(layer, BaseHailoNonNNCoreLayer):
                continue
            if layer.built:
                layer._tracker.locked = False  # unlock the layer tracker to allow changes
                layer.enforce_internal_encoding()
                layer._tracker.locked = True  # lock the layer tracker after changes
        self.configure_wraparound_loss()
        self._train_modes = {}
        if self.ft_cfg.train_encoding:
            self._train_modes = self._update_quant_elements(self._train_modes)
            hmodel_quant.build([(self.ft_cfg.batch_size,) + shape for shape in hmodel_quant.get_input_shapes()])
            hmodel_quant.enable_encoding_infer()
        # Explicit build is required for the multi-optimizer to work
        hmodel_quant.build([(self.ft_cfg.batch_size,) + shape for shape in hmodel_quant.get_input_shapes()])

        # Capture before modifying, so run_qft can restore the original flags.
        self._original_trainable = {lname: layer.trainable for lname, layer in hmodel_quant.layers.items()}
        for layer in hmodel_quant.layers.values():
            layer.trainable = self.ft_cfg.train_weights and layer.trainable

        qft_distiller = distiller.Distiller(
            teacher=hmodel_fp,
            student=hmodel_quant,
            train_scales=self.ft_cfg.train_scales,
        )

        return qft_distiller

    def _set_wraparound_loss(self, enabled):
        """Set the wraparound-loss flag on every ``BaseHailoLayer`` of the model."""
        for layer in self._model.layers.values():
            if not isinstance(layer, BaseHailoLayer):
                continue
            layer.set_wraparound_loss(enabled)

    def configure_wraparound_loss(self):
        """Enable the wraparound loss on all layers iff ``wraparound_factor > 0``."""
        ft_cfg = self.get_algo_config()
        self._set_wraparound_loss(ft_cfg.wraparound_factor > 0)

    def disable_wraparound_loss(self):
        """Disable the wraparound loss on all layers."""
        self._set_wraparound_loss(False)

    def _update_quant_elements(self, train_modes):
        """Switch quant elements to encoding-training modes, recording the previous modes.

        When weights are trained, nothing is changed. Otherwise every layer's elements are set to
        "native", and the bias/kernel/activation-output/elementwise-add-factor elements are then
        set to ``QuantTrainMode.NOISE``.

        Args:
            train_modes: dict updated in place with the previous modes (as returned by
                ``layer.get_quant_element_train_mode()``).

        Returns:
            The same ``train_modes`` dict.
        """
        if self.ft_cfg.train_weights:
            return train_modes
        train_mode = QuantTrainMode.NOISE
        for layer in self._model.layers.values():
            if not isinstance(layer, BaseHailoLayer):
                continue
            train_modes.update(layer.get_quant_element_train_mode())
            layer.update_quant_elements_train_mode("native")
            # TODO: do we want to add additional noise / ste elements?
            for op in layer.atomic_ops:
                if isinstance(op, AddBiasOp):
                    qelem = op.weight_lossy_elements.bias_decompose
                elif isinstance(op, ConvStrippedOp):
                    qelem = op.weight_lossy_elements.kernel
                elif isinstance(op, ActivationOp):
                    qelem = op.output_lossy_element
                elif isinstance(op, (ElementwiseAddDirectOp, ElementwiseAddOp)):
                    qelem = op.weight_lossy_elements.factor
                else:
                    continue
                qelem.train_mode = train_mode
        return train_modes

    def _restore_quant_elements(self, train_modes):
        """Restore quant-element train modes recorded by ``_update_quant_elements`` (no-op if empty/None)."""
        if not train_modes:
            return
        for qe, train_mode in train_modes.items():
            qe.train_mode = train_mode

    def run_qft(self, hmodel_fp: HailoModel, hmodel_quant: HailoModel, metrics=None):
        """
        Runs the quantization-aware training as configured.

        Note - this can be a lengthy process depending on config (train_images, epochs, etc.),
              but can be cut short at any time by the user (KeyboardInterrupt), with the ensuing quantized
              model technically usable (though less accurate than it would be at the end of the full
              schedule), with its variables reflecting the modifications incurred until that time.

        Args:
            hmodel_fp: the "teacher", full-precision model serving as reference
            hmodel_quant: the "student", STE-quantized model undergoing training;
                           its variables undergoing gradient-driven modifications.
            metrics: keras metrics to track during training (e.g. see the accuracy improve)

        Raises:
            AccelerasResourceError: TensorFlow ran out of (GPU) memory during training.

        """
        self.main_train_history = FullHistory()
        # Reset per-run state so the cleanup below never sees missing or stale values, even when
        # training is interrupted before _prep_run records them.
        self.main_train_summary_per_epoch = None
        self._train_modes = {}
        self._original_trainable = {}

        try:
            qft_distiller = self._prep_run(hmodel_fp, hmodel_quant)
            self._compile(qft_distiller, metrics)
            callbacks = [self.main_train_history]
            if self.ft_cfg.log_debug_data is not QFTWriterMode.disabled:
                tensorboard_writer = TensorboardWriter(
                    distiller=qft_distiller, name=f"{self._name}", workdir=self._work_dir
                )
                callbacks.append(tensorboard_writer)
            qft_distiller.build([(self.ft_cfg.batch_size,) + shape for shape in hmodel_quant.get_input_shapes()])
            self.main_train_summary_per_epoch = qft_distiller.fit(
                self.train_dataset,
                verbose=1,
                epochs=self.ft_cfg.epochs,
                steps_per_epoch=self.steps_per_epoch,
                validation_data=self.eval_dataset,
                validation_steps=self.validation_steps,
                # NOTE(review): hard-coded; Keras validates only every 100th epoch (was self.decay_epochs).
                validation_freq=100,
                callbacks=callbacks,
            )

        except tf.errors.ResourceExhaustedError as e:
            self._logger.debug(e)
            raise AccelerasResourceError(
                f"GPU memory has been exhausted. Please try to use {self._name} with lower"
                f" batch size or run on CPU.",
            ) from e

        except KeyboardInterrupt:
            self._logger.warning("Training cut by the user, proceed at your own peril")

        # (!!) EAGERIZE - recompute the modified scale tensors from variables in eager mode, so they are now numbers.
        # Otherwise, they'd stay tensors of training graph, as it's discarded they're orphans throwing
        # "leaked from graph" assertion.
        self.disable_wraparound_loss()
        self._restore_quant_elements(self._train_modes)
        if self.ft_cfg.train_encoding:
            hmodel_quant.disable_encoding_infer()
            hmodel_quant.enforce_encoding()

        if self.ft_cfg.online_quantization_bias_fix:
            for layer in hmodel_quant.layers.values():
                if BiasCorrection.is_correctable(layer, hmodel_fp):
                    for conv_op in self._get_conv_ops(layer):
                        conv_op.finalize_online_bias_correction()

        hmodel_quant.enable_internal_encoding()

        if self.ft_cfg.train_scales:  # QFT++ / QFT PAPER
            hmodel_quant.enforce_encoding(train_scales=True)

        for layer in hmodel_quant.layers.values():
            if isinstance(layer, BaseHailoNonNNCoreLayer):
                continue
            layer.enforce_internal_encoding()

        # Restore the trainable flags recorded in _prep_run (empty if it was interrupted early).
        for lname, trainable in self._original_trainable.items():
            hmodel_quant.layers[lname].trainable = trainable

        for loss_layer in self.ft_cfg.loss_layer_names:
            self._layers_info[f"{loss_layer}/successfully_run"] = True
            for lname in hmodel_quant.flow.ancestors(loss_layer):
                self._layers_info[f"{lname}/successfully_run"] = True

    @staticmethod
    def _get_conv_ops(acceleras_layer):
        """Return the ``ConvStrippedOp`` atomic ops of a layer."""
        return [op for op in acceleras_layer.atomic_ops if isinstance(op, ConvStrippedOp)]

    @property
    def metrics(self):
        return self._metrics

    @metrics.setter
    def metrics(self, metrics):
        self._metrics = metrics

    def finalize_global_cfg(self, algo_config):
        """Fill in derived defaults (batch size, learning rates, shuffle buffer) and validate sizes."""
        if algo_config.batch_size is None:
            algo_config.batch_size = self._model_config.calibration.batch_size
        if algo_config.learning_rate is None:
            # Linear scaling of the default learning rate with the batch size.
            algo_config.learning_rate = self.DEFAULT_LEARNING_RATE / DEFAULT_BATCH_SIZE * algo_config.batch_size
        if algo_config.warmup_lr is None:
            algo_config.warmup_lr = algo_config.learning_rate / 4
        if algo_config.shuffle_buffer_size == 0:
            algo_config.shuffle_buffer_size = algo_config.dataset_size
        if not self.should_skip_algo():
            self.check_dataset_length(algo_config, "dataset_size", self.unbatched_train_dataset)
            self.check_batch_size(algo_config, "dataset_size", "batch_size")