from typing import Iterable

import tensorflow as tf
from tqdm import tqdm

from hailo_model_optimization.acceleras.hailo_layers.base_hailo_none_nn_core_layer import BaseHailoNonNNCoreLayer
from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.statistics.statistics_base import TypeStats
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasResourceError


class StatsCollectCtx:
    """
    Context manager that collects layer statistics for all data fed while it is active.

    On enter the model is switched to native mode and statistics collection is started on
    every selected layer; layers that are ``BaseHailoNonNNCoreLayer`` instances are skipped.
    On exit collection is stopped on the same layers. When no explicit layer list was given,
    the model is also switched via ``set_lossless(native_act=True)``.

    Args:
        acceleras_model: model whose ``layers`` mapping (name -> layer) is instrumented.
        layers_to_handle: names of the layers to collect on; ``None`` means every model layer.
        output_histogram_layers: names of layers that also collect an output histogram
            (``None`` means none).
        preact_histogram_layers: names of layers that also collect a pre-activation histogram
            (``None`` means none).
    """

    def __init__(self, acceleras_model, layers_to_handle, output_histogram_layers, preact_histogram_layers):
        self.acceleras_model = acceleras_model
        # Switching the model to lossless on exit only happens when the whole model is handled.
        self._set_lossless = layers_to_handle is None
        if self._set_lossless:
            layers_to_handle = list(acceleras_model.layers.keys())
        self.layers_to_handle = layers_to_handle
        self.output_histogram_layers = set() if output_histogram_layers is None else output_histogram_layers
        self.preact_histogram_layers = set() if preact_histogram_layers is None else preact_histogram_layers

    def _core_layers(self):
        """Yield ``(name, layer)`` for every handled layer that is not a non-NN-core layer."""
        for name in self.layers_to_handle:
            candidate = self.acceleras_model.layers[name]
            if not isinstance(candidate, BaseHailoNonNNCoreLayer):
                yield name, candidate

    def __enter__(self):
        self.acceleras_model.set_native()
        for name, layer in self._core_layers():
            layer.start_stats_collection(
                output_hist=name in self.output_histogram_layers,
                preact_hist=name in self.preact_histogram_layers,
            )

    def __exit__(self, exc_type, exc_val, exc_tb):
        for _, layer in self._core_layers():
            layer.stop_stats_collection()
        if self._set_lossless:
            self.acceleras_model.set_lossless(native_act=True)


class HistogramCtx:
    """
    Context manager that collects output histograms for all data fed while it is active.

    On enter the model is switched to native mode. For every output op of every handled
    layer, output limvals are set and histogram-only statistics collection is started on the
    op's output. On exit collection is stopped on the same ops and the model is switched via
    ``set_lossless(native_act=True)``.

    Args:
        acceleras_model: model whose ``layers`` mapping (name -> layer) is instrumented.
        layers_to_handle: names of the layers to collect on; ``None`` means every model layer.
        output_histogram_layers: stored on the instance. Note that every handled layer collects
            a histogram regardless of this value.
        preact_histogram_layers: optional and unused. It is accepted so this context can be
            constructed with the same arguments as ``StatsCollectCtx``.
    """

    def __init__(self, acceleras_model, layers_to_handle, output_histogram_layers, preact_histogram_layers=None):
        self.acceleras_model = acceleras_model
        # ``None`` means "all layers"; iterating ``None`` in __enter__ would otherwise crash.
        self.layers_to_handle = (
            layers_to_handle if layers_to_handle is not None else list(acceleras_model.layers.keys())
        )
        self.output_histogram_layers = output_histogram_layers
        self.preact_histogram_layers = preact_histogram_layers

    def __enter__(self):
        self.acceleras_model.set_native()
        for lname in self.layers_to_handle:
            layer = self.acceleras_model.layers[lname]
            for op, out_index in layer.iterate_output_ops():
                op.set_output_limvals(out_index)
                op.start_stats_collection(
                    stats_cfg=(TypeStats.HISTOGRAM,),
                    collect_inputs=False,
                    collect_output=True,
                )

    def __exit__(self, exc_type, exc_val, exc_tb):
        for lname in self.layers_to_handle:
            layer = self.acceleras_model.layers[lname]
            for op, _ in layer.iterate_output_ops():
                op.stop_stats_collection()
        self.acceleras_model.set_lossless(native_act=True)


def collect_stats(
    acceleras_model: HailoModel,
    calib_dataset: Iterable,
    double_stream=True,
    layers_to_handle=None,
    run_eagerly: bool = False,
    steps_per_execution: int = 1,
    total_entries=None,
    output_histogram_layers=None,
    preact_histogram_layers=None,
    histogram_ctx=False,
):
    """
    High-level stats collection utility.

    Runs the model on the provided calibration set. Statistics are accumulated inside the
    model's layer objects as a side effect; nothing is returned.

    Args:
        acceleras_model: acceleras model to collect statistics on.
        calib_dataset: ``tf.data.Dataset``-like iterable of calibration batches.
        double_stream: True for Zoo-standard ``(data, image_info)`` tf.data sources; the
            ``image_info`` part is dropped via ``calib_dataset.map``. If False, each element
            is assumed to be model input only.
        layers_to_handle: names of layers to collect stats on. If None, all layers are used.
        run_eagerly: forwarded to ``acceleras_model.compile``; True runs the model eagerly.
        steps_per_execution: forwarded to ``acceleras_model.compile``.
        total_entries: expected number of entries; only used as the progress bar total.
        output_histogram_layers: layer names that should also collect output histograms.
        preact_histogram_layers: layer names that should also collect pre-activation
            histograms. Ignored when ``histogram_ctx`` is True.
        histogram_ctx: if True, use ``HistogramCtx`` (histogram-only collection) instead of
            ``StatsCollectCtx``.

    Raises:
        AccelerasResourceError: if TensorFlow raises ``ResourceExhaustedError`` during inference.
    """
    if double_stream:
        calib_dataset = calib_dataset.map(lambda image, info: image)

    # Set all ops shapes before calling start_stats_collection
    shapes = [(None,) + shape for shape in acceleras_model.get_input_shapes()]
    acceleras_model.compute_output_shape(shapes)

    # The two contexts take different constructor arguments, so build each one explicitly.
    if histogram_ctx:
        # Resolve "all layers" here so HistogramCtx always receives a concrete list.
        if layers_to_handle is None:
            layers_to_handle = list(acceleras_model.layers.keys())
        stats_ctx = HistogramCtx(acceleras_model, layers_to_handle, output_histogram_layers)
    else:
        stats_ctx = StatsCollectCtx(
            acceleras_model, layers_to_handle, output_histogram_layers, preact_histogram_layers
        )

    # Using tqdm as a context manager guarantees the bar is closed even if inference raises.
    with tqdm(total=total_entries, dynamic_ncols=True, unit="entries", desc="Calibration") as pbar, stats_ctx:
        acceleras_model.compile(
            run_eagerly=run_eagerly,
            steps_per_execution=steps_per_execution,
            jit_compile=False,  # JIT compile is not supported for stats collection
        )
        for preprocessed_data in calib_dataset:
            try:
                acceleras_model.predict_on_batch(preprocessed_data)
            except tf.errors.ResourceExhaustedError as err:
                raise AccelerasResourceError(
                    "GPU memory has been exhausted. "
                    "Please try to use lower batch size for calibration or run on CPU.",
                ) from err
            pbar.update(_count_batch_entries(preprocessed_data))
        pbar.refresh()  # flush the final pbar state


def _count_batch_entries(batch):
    """
    Return the leading-dimension size of a model-input batch.

    For dict inputs the first value is used; for list/tuple inputs the first element is used.
    This assumes all inputs of a multi-input batch share the same leading dimension.
    """
    if isinstance(batch, dict):
        batch = next(iter(batch.values()))
    elif isinstance(batch, (list, tuple)):
        batch = batch[0]
    return batch.shape[0]
