from abc import abstractmethod
from collections import namedtuple
from enum import Enum
from typing import Union

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasValueError

# Dtypes of the keras weights that back the statistics: metrics whose initial value is a Python int
# get an integer weight, every other metric gets a float weight (see StatsBase._init_accumulated_statistic).
STATS_TYPE_FLOAT, STATS_TYPE_INT = tf.float32, tf.int32


class TypeStats(Enum):
    """
    Identifiers of the statistics collected per layer.

    Each member's value doubles as a field name of the ``Statistics`` namedtuple.
    """

    MIN = "min"
    MAX = "max"
    # Scaled by ``scale**2`` in ``scale_stats``, i.e. a quadratic quantity (presumably the mean square - confirm).
    ENERGY = "energy"
    MEAN = "mean"
    # Left unchanged by ``scale_stats``.
    NON_ZERO_PERCENT = "non_zero_percent"
    HISTOGRAM = "histogram"
    DYNAMIC_HISTOGRAM = "dynamic_histogram"


# All statistic types except the two histogram ones, in enum definition order.
BasicTypeTuple = tuple(
    stat_type for stat_type in TypeStats if stat_type not in (TypeStats.HISTOGRAM, TypeStats.DYNAMIC_HISTOGRAM)
)
# Container with one field per TypeStats member, named after the member's value.
Statistics = namedtuple("Statistics", [stat_type.value for stat_type in TypeStats])


def scale_stats(stats: Statistics, scale: float):
    """
    Scales the statistics, in place, by the given scale factor.

    min, max and mean are multiplied by ``scale`` and energy by ``scale**2``; non_zero_percent is
    scale invariant. Both histograms are cleared instead of being rescaled.

    Args:
        stats: the statistics to scale; its arrays are modified in place
        scale: the scale factor (scalar or broadcastable per-channel values; may be negative)
    Returns:
        the scaled statistics (the same object that was passed in)
    """
    # Compute both products before writing anything back, since the arrays are updated in place.
    scaled_min = stats.min * scale
    scaled_max = stats.max * scale
    # A negative scale reverses the ordering: the new minimum comes from the old maximum and vice versa.
    # For non-negative scales this is exactly ``min * scale`` / ``max * scale``.
    is_negative = np.asarray(scale) < 0
    stats.min[...] = np.where(is_negative, scaled_max, scaled_min)
    stats.max[...] = np.where(is_negative, scaled_min, scaled_max)
    stats.energy[...] *= scale**2
    stats.mean[...] *= scale
    # stats.non_zero_percent doesn't need to be scaled
    clear_filed(stats, TypeStats.HISTOGRAM)  # histogram can be scaled but it's not used in the code
    clear_filed(stats, TypeStats.DYNAMIC_HISTOGRAM)  # histogram can be scaled but it's not used in the code
    return stats


def update_stats(stats: Statistics, arr: np.ndarray, feild: TypeStats, clear_cannot_update: bool = False):
    """
    Overwrites one field of the statistics, in place, with the given array.

    Args:
        stats: the statistics to update; the target field's array is modified in place
        arr: the new values; must be broadcastable to the shape of the target field
        feild: the field to update
        clear_cannot_update: if True, also clear the energy, mean, histogram and dynamic histogram
            fields (via ``clear_filed``), since this function does not recompute them from the new values
    Returns:
        the updated statistics (the same object that was passed in)
    """
    target = getattr(stats, feild.value)
    # ``[...]`` works for arrays of any rank, including 0-d ones such as the scalar non_zero_percent
    # statistic (``[:]`` raises IndexError on 0-d arrays). The assignment copies ``arr`` into ``target``,
    # so no explicit copy is needed.
    target[...] = arr

    if clear_cannot_update:
        clear_filed(stats, TypeStats.ENERGY)
        clear_filed(stats, TypeStats.MEAN)
        clear_filed(stats, TypeStats.HISTOGRAM)
        clear_filed(stats, TypeStats.DYNAMIC_HISTOGRAM)

    return stats


def clear_filed(stats: Statistics, feild: TypeStats):
    """
    Clears the given field of the statistics, in place, by filling it with ``None``.

    Fields that are ``None`` themselves are left as they are.

    Args:
        stats: the statistics to clear
        feild: the field to clear
    Returns:
        the cleared statistics (the same object that was passed in)
    """
    arr = getattr(stats, feild.value)
    if arr is not None:
        # ``[...]`` works for arrays of any rank, including 0-d ones (``[:]`` raises IndexError on 0-d arrays).
        # NOTE(review): numpy stores None as NaN in float arrays and as None in object arrays; integer arrays
        # reject it with a TypeError.
        arr[...] = None
    return stats


# We separate metrics into two types: those holding a value that is gradually aggregated, and those that are ratio-based
class MetricType(Enum):
    """How a StatsBase metric turns its accumulated state into a result."""

    # Aggregated by scanning the data and updating the metric in place, e.g. minimum, maximum and histogram.
    # result() returns the accumulated value as is.
    AGGREGATE = "aggregate"
    # Summed over batches and divided by the number of batches (ratio-based), e.g. mean, mean square and
    # non zero percent.
    RATIO = "RATIO"


class StatsBase(tf.keras.metrics.Metric):
    """

    Statistics implementation using keras metrics

    In this class, we implement metrics that are attached to each atomic op.
    It is used by basic_atomic_op to compute some basic statistics for every layer in the network.
    Current statistics that inherit from this class are minimum, maximum, mean, mean square, non zero percent,
    and histogram. They are used for calibration.

    The metrics are vectorized, meaning that each metric is grouped into buckets by axis of their given tensor.
    The implementation of the metrics is based on keras metrics. Keras metrics already implement the init, update_state
    and reset_state API, and they also subclass the keras layer class. As a keras layer, the metrics state is
    represented using tensorflow variables, with the 'add_weight' method.

    The implementations below follow the 'graph mode' instructions, and therefore can run in tensorflow graph mode.
    This means that their code is traced once called, and deferred to a run in an optimized tensorflow graph.
    At the moment, the metrics are running from 'call' function of base_atomic_op, and this is not a good practice
    in keras.

    The better way to work is by creating and initializing the metrics using keras layer 'build' function.
    Also, for now, the metrics state machine and management are handled by our specialized code.

    Inside basic_atomic_op, there is a state machine for handling the metrics (Moving states to RESET,
    COMPILETED, etc.). The state machine is meaningless when called with graph mode, as the code is running only
    once. It is only relevant when running in eager mode, where the tensorflow run is done together with the python run.
    In graph mode, this doesn't hold.

    Examples:
        >>> metric = MinimalValueByFeature(axis_to_accumulate=(1,), metric_length=3, name="minimum_by_row")
        >>> metric.update_state(tf.constant([[1., 2., 3., 4., 5.], [6., 7., 8., 9., 10.], [11., 12., 13., 14., 15.]]))
        >>> metric.result()  # minimum of each row, i.e. reduced over axis 1 - [1., 6., 11.]

    The init method is the initialization function of the metrics. It initializes the metric with the initial value
    and, in case of a ratio metric, initializes the denominator (the number of batches).

    Args:
            axis_to_accumulate: the axes the subclasses reduce each data batch over
            metric_length: the metric size; 0 creates a scalar metric
            initializer_value: the initial value of the metric; a Python int creates an integer metric,
                anything else a float metric
            name: the metric name
            metric_type: the type of the metric, representing aggregated vs ratio metrics


    """

    def __init__(
        self,
        axis_to_accumulate: tuple,
        metric_length: int,
        initializer_value: Union[int, np.float32],
        name="stats_base",
        metric_type=MetricType.AGGREGATE,
        **kwargs,
    ):
        super(StatsBase, self).__init__(name=name, **kwargs)
        self._axis_to_accumulate = axis_to_accumulate
        self._metric_length = metric_length
        self._metric_type = metric_type
        self._initializer_value = initializer_value
        self._init_accumulated_statistic()
        if self._metric_type == MetricType.RATIO:
            self._init_accumulated_num_batches()

    def update_state(self, data_batch):
        """
        This method calls the specific metric update function. In case of a ratio metric, the denominator is
        incremented by one.

        Args:
            data_batch: the tensor with the next data to compute the metric on.

        """
        self._update_accumulated_statistic(data_batch)
        if self._metric_type == MetricType.RATIO:
            self._accumulated_num_batches.assign_add(1)

    def result(self):
        """
        This method returns the metric value. In case of aggregate metrics it just returns the metric computed, and
        in case of a ratio metric, it divides the computed metric by the denominator (the number of batches).

        """
        if self._metric_type == MetricType.AGGREGATE:
            return self._accumulated_statistic
        return tf.math.divide(self._accumulated_statistic, self._accumulated_num_batches)

    def reset_state(self, **kwargs):
        """
        This method resets the metric (and, for ratio metrics, the batch counter) to its initial value.
        """
        self._reset_accumulated_statistic()
        if self._metric_type == MetricType.RATIO:
            self._reset_accumulated_num_batches()

    def _init_accumulated_statistic(self):
        self._accumulated_statistic_initializer = tf.keras.initializers.Constant(self._initializer_value)
        # shape is computed to differentiate between number based metrics (non zero percent) and channel based metrics
        shape = () if self._metric_length == 0 else (self._metric_length,)
        # dtype is computed to differentiate between integer metrics (histogram) and float metrics
        dtype = STATS_TYPE_INT if isinstance(self._initializer_value, int) else STATS_TYPE_FLOAT
        self._accumulated_statistic = self.add_weight(
            name="accumulated_statistic",
            initializer=self._accumulated_statistic_initializer,
            shape=shape,
            dtype=dtype,
        )

    def _init_accumulated_num_batches(self):
        self._accumulated_num_batches_initializer = tf.keras.initializers.Constant(0)
        # Scalar batch counter; created without an explicit dtype, so it uses the metric's default dtype.
        self._accumulated_num_batches = self.add_weight(
            name="_accumulated_num_batches",
            initializer=self._accumulated_num_batches_initializer,
        )

    @abstractmethod
    def _update_accumulated_statistic(self, data_batch):
        """
        This method is the specific part that each metric implements for itself.
        For instance, minimum updates the current value to be the minimum between itself and the minimum of the next
        data batch. Mean is updated by summing the channel value and counting the number of batches.
        """

    def _reset_accumulated_statistic(self):
        self._accumulated_statistic.assign(
            self._accumulated_statistic_initializer(
                shape=self._accumulated_statistic.shape,
                dtype=self._accumulated_statistic.dtype,
            ),
        )

    def _reset_accumulated_num_batches(self):
        # Use the counter's own dtype: it is created without an explicit dtype and need not match the dtype of
        # the accumulated statistic.
        self._accumulated_num_batches.assign(
            self._accumulated_num_batches_initializer(
                shape=self._accumulated_num_batches.shape,
                dtype=self._accumulated_num_batches.dtype,
            ),
        )


class MinimalValueByFeature(StatsBase):
    """
    Keeps the minimum value channelwise.

    Each batch is reduced with ``tf.reduce_min`` over ``axis_to_accumulate`` and flattened to a vector
    (expected to have ``metric_length`` entries), which is merged into the stored value with an
    elementwise minimum.
    """

    def __init__(self, axis_to_accumulate, metric_length, name="minimal_value_by_feature", **kwargs):
        super(MinimalValueByFeature, self).__init__(
            axis_to_accumulate=axis_to_accumulate,
            metric_length=metric_length,
            # Start at the largest value of the stats dtype so the first batch always replaces it. Derived from
            # STATS_TYPE_FLOAT (like MaximalValueByFeature) so the sentinel stays valid if that dtype changes.
            initializer_value=STATS_TYPE_FLOAT.max,
            name=name,
            **kwargs,
        )

    def _update_accumulated_statistic(self, data_batch):
        statistic = tf.reduce_min(data_batch, self._axis_to_accumulate)
        statistic = tf.reshape(statistic, [-1], "Flatten_Min")
        self._accumulated_statistic.assign(tf.math.minimum(self._accumulated_statistic, statistic))


class MaximalValueByFeature(StatsBase):
    """
    Tracks, per channel, the largest value observed across all batches.
    """

    def __init__(self, axis_to_accumulate, metric_length, name="maximal_value_by_feature", **kwargs):
        # The lowest finite value of the stats dtype is the neutral start for a running maximum.
        lowest = STATS_TYPE_FLOAT.min
        super().__init__(
            axis_to_accumulate=axis_to_accumulate,
            metric_length=metric_length,
            initializer_value=lowest,
            name=name,
            **kwargs,
        )

    def _update_accumulated_statistic(self, data_batch):
        batch_max = tf.reshape(
            tf.reduce_max(data_batch, self._axis_to_accumulate),
            [-1],
            "Flatten_Max",
        )
        running_max = tf.math.maximum(self._accumulated_statistic, batch_max)
        self._accumulated_statistic.assign(running_max)


class MeanByFeature(StatsBase):
    """
    Keeps the mean value channelwise.

    Each batch contributes its ``tf.reduce_mean`` over ``axis_to_accumulate`` (flattened to a vector), and
    ``result()`` divides the sum of these per-batch means by the number of batches. Every batch is therefore
    weighted equally, regardless of how many elements it holds.
    """

    def __init__(self, axis_to_accumulate, metric_length, name="mean_value_by_feature", **kwargs):
        super(MeanByFeature, self).__init__(
            axis_to_accumulate=axis_to_accumulate,
            metric_length=metric_length,
            initializer_value=0.0,
            name=name,
            metric_type=MetricType.RATIO,
            **kwargs,
        )

    def _update_accumulated_statistic(self, data_batch):
        statistic = tf.reduce_mean(data_batch, self._axis_to_accumulate)
        statistic = tf.reshape(statistic, [-1], "Flatten_Mean")
        self._accumulated_statistic.assign_add(statistic)


class MeanSquareByFeature(StatsBase):
    """
    Running mean of the squared values, per channel.

    Per-batch means of ``data_batch**2`` are summed and ``result()`` divides by the batch count,
    so each batch carries the same weight no matter how many elements it holds.
    """

    def __init__(self, axis_to_accumulate, metric_length, name="mean_square_value_by_feature", **kwargs):
        ratio_kwargs = dict(
            axis_to_accumulate=axis_to_accumulate,
            metric_length=metric_length,
            initializer_value=0.0,
            name=name,
            metric_type=MetricType.RATIO,
        )
        super().__init__(**ratio_kwargs, **kwargs)

    def _update_accumulated_statistic(self, data_batch):
        squared = data_batch**2
        per_channel = tf.reshape(tf.reduce_mean(squared, self._axis_to_accumulate), [-1], "Flatten_Mean_Square")
        self._accumulated_statistic.assign_add(per_channel)


class NonZeroPercentByFeature(StatsBase):
    """
    Keeps the fraction of non-zero values in the whole tensor, averaged over batches.

    Despite the name, the statistic is a single scalar: ``metric_length`` is ignored and
    ``axis_to_accumulate`` is not used for the reduction. The value is a fraction in [0, 1],
    not a percentage.
    """

    def __init__(self, axis_to_accumulate, metric_length, name="non_zero_percent_by_feature", **kwargs):
        super(NonZeroPercentByFeature, self).__init__(
            axis_to_accumulate=axis_to_accumulate,
            metric_length=0,
            initializer_value=0.0,
            name=name,
            metric_type=MetricType.RATIO,
            **kwargs,
        )

    def _update_accumulated_statistic(self, data_batch):
        num_nonzero = tf.cast(tf.math.count_nonzero(data_batch), dtype=STATS_TYPE_FLOAT)
        # tf.size gives the exact element count for any dtype. Summing tf.ones_like(data_batch) instead
        # would overflow for narrow integer dtypes (e.g. int8) and fail for bool tensors.
        num_elements = tf.cast(tf.size(data_batch, out_type=tf.int64), dtype=STATS_TYPE_FLOAT)
        self._accumulated_statistic.assign_add(num_nonzero / num_elements)


class HistogramByFeatures(StatsBase):
    """
    Keeps a fixed-range histogram of all the values of the tensor.

    The histogram covers the whole tensor (``axis_to_accumulate`` and ``metric_length`` are not used for the
    reduction) with 1000 integer bins over ``hist_range``. The range must be provided through
    ``reset_state(hist_range=(min, max))`` before the first update.
    """

    def __init__(self, axis_to_accumulate, metric_length, name="histogram_by_feature", **kwargs):
        self._hist_range = None
        # Must be set before StatsBase.__init__, which sizes the histogram variable from metric_length.
        self._nbins = 1000
        super(HistogramByFeatures, self).__init__(
            axis_to_accumulate=axis_to_accumulate,
            metric_length=self._nbins,
            initializer_value=0,
            name=name,
            **kwargs,
        )

    def _update_accumulated_statistic(self, data_batch):
        if self._hist_range is None:
            # Without a range tf.histogram_fixed_width fails with an opaque conversion error on None.
            raise AccelerasValueError("hist_range is not set, call reset_state(hist_range=(min, max)) first")
        statistic = tf.histogram_fixed_width(data_batch, self._hist_range, self._nbins)
        self._accumulated_statistic.assign_add(statistic)

    def reset_state(self, hist_range=None, nbins=1000, **kwargs):
        """
        Zeroes the histogram and sets the range used by subsequent updates.

        Args:
            hist_range: tuple of (min, max) values covered by the histogram
            nbins: number of bins; must equal the size of the histogram variable, which is fixed at
                construction time
        Raises:
            AccelerasValueError: if hist_range is not a 2-tuple, or nbins does not match the variable size
        """
        if not (isinstance(hist_range, tuple) and len(hist_range) == 2):
            raise AccelerasValueError("hist_range must be of type tuple")
        if nbins != self._metric_length:
            # The histogram variable has shape (self._metric_length,); any other nbins would make
            # tf.histogram_fixed_width return a tensor that cannot be added to it.
            raise AccelerasValueError(f"nbins must be {self._metric_length}, got {nbins}")
        self._hist_range = hist_range
        self._nbins = nbins
        super(HistogramByFeatures, self).reset_state(**kwargs)


class DynamicHistogram(StatsBase):
    """
    Keeps a histogram of all the values seen so far, over a range that grows with the data.

    The histogram has 1000 float bins; the ``metric_length`` argument is ignored. The range is stored in
    ``_current_range`` and starts as (nan, nan), meaning "no data seen yet". When a batch falls outside the
    current range, the range is widened and the existing counts are redistributed onto the new bins
    (see ``_cast_hist``). Updates run inside a ``tf.CriticalSection``, so concurrent updates are serialized.

    NOTE(review): ``_current_range`` is a plain ``tf.Variable`` that ``StatsBase.reset_state`` does not touch
    (it resets only the accumulated statistic for this AGGREGATE metric). After a reset the histogram is zeroed
    but the range learned from previous data is kept - confirm this is intended.
    """

    def __init__(self, axis_to_accumulate: tuple, metric_length: int, name="dynamic_histogram_by_feature", **kwargs):
        super().__init__(axis_to_accumulate, metric_length=1000, initializer_value=0.0, name=name, **kwargs)
        # (nan, nan) marks "no range yet"; the first batch sets the range to its own min/max.
        self._current_range = tf.Variable((np.nan, np.nan))
        # Serializes access to _current_range and the histogram between concurrent updates.
        self._critical_section = tf.CriticalSection()

    @tf.function
    def _update_accumulated_statistic(self, data_batch):
        """
        Adds the values of ``data_batch`` to the histogram, widening the range first if needed.
        """
        # This code is overcomplicated for two reasons:
        # 1. tf didn't handle a race condition properly and I had conflicting access of the _current range,
        # that is, 2 batches tried to access self._current_range at the same time.
        # 2. tf.histogram_fixed_width didn't work properly with our standard batch size (8), so I'm using a workaround.
        # Everything worked fine in eager mode, but broke badly in graph mode...

        # data_batch can be inf in histogram only if we have a softmax layer with a mask. Then tf.math.log(1/2**(15+1))
        # (about -11.09) will always be a small enough number to be the minimum of the histogram, without affecting
        # the results. Note that both +inf and -inf are replaced by this value.
        data_batch = tf.where(tf.math.is_inf(data_batch), tf.math.log(1 / 2 ** (15 + 1)), data_batch)
        curr_min = tf.reduce_min(data_batch)
        curr_max = tf.reduce_max(data_batch)

        @tf.function
        def _update_hist():
            """Updates the range and adds the batch counts; executed inside the critical section."""
            # First batch: take its range. Batch outside the range: widen it (this also rebins the old counts).
            # Otherwise keep the current range.
            if tf.math.is_nan(self._current_range[0]) and tf.math.is_nan(self._current_range[1]):
                new_range = (curr_min, curr_max)
            elif curr_min < self._current_range[0] or curr_max > self._current_range[1]:
                new_range = self._update_hist_range(curr_min, curr_max)
            else:
                new_range = self._current_range[0], self._current_range[1]
            self._current_range.assign(new_range)

            # Degenerate range (all values seen so far are equal): every value goes to bin 0.
            if self._current_range[0] == self._current_range[1]:
                bins = tf.zeros_like(data_batch, dtype=tf.int32)
            else:
                bins = tf.histogram_fixed_width_bins(data_batch, self._current_range, self._metric_length)
            # Count the occurrences of each bin index and scatter-add the counts into the histogram.
            bins, _, count = tf.unique_with_counts(tf.reshape(bins, [-1]))
            bins = tf.reshape(bins, [-1, 1])
            new_stats = tf.tensor_scatter_nd_add(self._accumulated_statistic, bins, tf.cast(count, STATS_TYPE_FLOAT))

            self._accumulated_statistic.assign(new_stats)
            return self._accumulated_statistic

        self._critical_section.execute(_update_hist)

    def _update_hist_range(self, curr_min, curr_max):
        """
        Widens the range to include [curr_min, curr_max] and rebins the existing histogram onto the new bins.

        Must run before ``_current_range`` is reassigned, because ``_cast_hist`` reads the old range from it.

        Returns:
            the new (min, max) range
        """
        old_bins = self._get_bins(self._current_range[0], self._current_range[1])
        new_max = tf.maximum(curr_max, self._current_range[1])
        new_min = tf.minimum(curr_min, self._current_range[0])
        new_bins = self._get_bins(new_min, new_max)
        self._cast_hist(old_bins, new_bins)
        return new_min, new_max

    def _get_bins(self, range_min, range_max):
        """
        Returns a (metric_length, 2) tensor with the [lower, upper] edges of ``metric_length`` equal-width
        bins spanning [range_min, range_max].
        """
        bin_size = (range_max - range_min) / self._metric_length
        bins_offset = tf.range(0, self._metric_length, dtype=STATS_TYPE_FLOAT) * bin_size
        bins_min = bins_offset + range_min
        bins_max = bin_size + bins_offset + range_min
        return tf.stack([bins_min, bins_max], -1)

    def _cast_hist(self, old_bins, new_bins):
        """
        Redistributes the current histogram counts from ``old_bins`` onto ``new_bins``.

        Builds a (new_bin, old_bin) weight matrix and sums ``weight * old_count`` over the old bins.
        Normally the weight is the length of the overlap between the two bins divided by the old bin size.
        If the old range is degenerate (zero width), each old bin is a single point. A new bin then takes the
        whole old count when ``lower <= point < upper``.

        NOTE(review): in the degenerate case, if the old point equals the new range maximum, no new bin may
        satisfy ``upper > point`` (depending on floating-point rounding of the last upper edge), so the old counts
        could be dropped - confirm / consider treating the last bin as closed.
        """
        old_hist = self._accumulated_statistic
        # new_bins -> (N, 2, 1), old_bins -> (1, N, 2): edge comparisons below broadcast to (N_new, N_old).
        new_bins = tf.expand_dims(new_bins, -1)
        old_bins = tf.expand_dims(old_bins, 0)
        # calculate overlap ratio
        if self._current_range[0] == self._current_range[1]:
            # in case old bin size is 0, add all to relevant bin
            mat_min = tf.math.less_equal(new_bins[:, 0], old_bins[..., 0])
            mat_max = tf.math.greater(new_bins[:, 1], old_bins[..., 1])
            ratios = tf.cast(tf.math.logical_and(mat_min, mat_max), dtype=STATS_TYPE_FLOAT)
        else:
            mat_min = tf.maximum(new_bins[:, 0], old_bins[..., 0])
            mat_max = tf.minimum(new_bins[:, 1], old_bins[..., 1])
            old_bin_size = (self._current_range[1] - self._current_range[0]) / self._metric_length
            ratios = tf.maximum(0.0, mat_max - mat_min) / old_bin_size

        new_hist = tf.reduce_sum(ratios * old_hist, -1)
        self._accumulated_statistic.assign(new_hist)
