import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.hailo_layers.base_hailo_none_nn_core_layer import BaseHailoNonNNCoreLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_external_pad import HailoExternalPad
from hailo_model_optimization.acceleras.hailo_layers.hailo_softmax_mask import HailoSoftmaxMask
from hailo_model_optimization.acceleras.hailo_layers.hailo_softmax_mask_on_mac import HailoSoftmaxMaskOnMac
from hailo_model_optimization.acceleras.model.hailo_model import HailoModel


class SNRMetric(tf.keras.metrics.Metric):
    """
    Accumulate signal and noise energies and report the signal-to-noise ratio in dB.

    ``native`` is treated as the reference signal and ``native - numeric`` as the noise.
    Both energy accumulators start at ``epsilon`` rather than zero.

    The metric name must have the form ``<base_name>:<index>``: ``result`` splits it on ``":"``
    to build the keys of the auxiliary energy entries (the default name ``"snr"`` has no ``":"``
    and therefore cannot be used with ``result``).

    Args:
        name: the metric name, expected to contain exactly one ``":"``.
        epsilon: initial value of the signal and noise energy accumulators.
        **kwargs: forwarded to ``tf.keras.metrics.Metric``.

    """

    def __init__(self, name="snr", epsilon=1e-10, **kwargs):
        super().__init__(name=name, **kwargs)
        self.energy_initializer = tf.keras.initializers.Constant(epsilon)
        self.signal_energy = self.add_weight(name="signal_energy", initializer=self.energy_initializer)
        self.noise_energy = self.add_weight(name="noise_energy", initializer=self.energy_initializer)
        self.count = self.add_weight(name="count", initializer="zeros", dtype=tf.float32)

    def update_state(self, native, numeric, **kwargs):
        """
        Add the mean squared signal and the mean squared error of one call to the accumulators.

        Args:
            native: the reference tensor.
            numeric: the tensor compared against ``native``.
            **kwargs: ignored; accepted so all metrics can be updated with the same call.

        """
        self.count.assign_add(1)

        self.signal_energy.assign_add(tf.reduce_mean(native**2))
        self.noise_energy.assign_add(tf.reduce_mean((native - numeric) ** 2))

    def result(self):
        """
        Return a dict holding the SNR in dB and the per-update mean signal and noise energies.

        The SNR is ``10 * log10(signal_energy / noise_energy)`` computed on the accumulated sums;
        the two energy entries are the accumulated sums divided by the number of updates.
        """
        base_name, i = self.name.split(":")
        # The energy entries are stored under "<base_name>_info/..." instead of "<name>/..." to allow
        # serialization using the h5 format: the same key can't be used to store both a value and an
        # additional internal dict.
        return {
            f"{self.name}": 10.0 * tf.math.log(self.signal_energy / self.noise_energy) / tf.math.log(10.0),
            f"{base_name}_info/signal_energy:{i}": self.signal_energy / self.count,
            f"{base_name}_info/noise_energy:{i}": self.noise_energy / self.count,
        }

    def reset_states(self):
        """Restore both energy accumulators to ``epsilon`` and the update counter to zero."""
        # Keras initializers must be called with a shape; the accumulators are scalars.
        self.signal_energy.assign(self.energy_initializer(shape=()))
        self.noise_energy.assign(self.energy_initializer(shape=()))
        self.count.assign(0)

class PartialSNRMetric(SNRMetric):
    """
    SNR metric that measures the noise of ``partial_numeric`` instead of ``numeric``.

    The accumulation and the reported keys are those of ``SNRMetric``; only the tensor that is
    compared against ``native`` differs.
    """

    def update_state(self, native, numeric, partial_numeric, **kwargs):
        """
        Update the SNR accumulators using ``partial_numeric`` as the noisy tensor.

        Args:
            native: the reference tensor.
            numeric: ignored; accepted so all metrics can be updated with the same call.
            partial_numeric: the tensor compared against ``native``.
            **kwargs: ignored.

        """
        super().update_state(native, numeric=partial_numeric)


class HistogramMetric(tf.keras.metrics.Metric):
    """
    Accumulate a fixed-width histogram of the first tensor passed to ``update_state``.

    The metric name must have the form ``<base_name>:<index>``; ``result`` splits it on ``":"``
    to build the output keys.

    Args:
        limvals: pair of (lower, upper) limits of the histogram range. When both limits are equal the
            range is widened by one on each side.
        name: the metric name, expected to contain exactly one ``":"``.
        nbins: number of histogram bins.
        **kwargs: forwarded to ``tf.keras.metrics.Metric``.

    """

    def __init__(self, limvals, name="histogram", nbins=512, **kwargs):
        super().__init__(name=name, **kwargs)
        # A zero-width range is widened by one on each side.
        self._limvals = (limvals[0] - 1, limvals[1] + 1) if limvals[0] == limvals[1] else limvals
        self._nbins = nbins

        self.hist_initializer = tf.keras.initializers.Constant(0)
        self.hist = self.add_weight(
            name="hist",
            initializer=self.hist_initializer,
            shape=(self._nbins,),
            dtype=tf.int64,
        )

    def update_state(self, input, *args, **kwargs):
        """
        Add the histogram of ``input`` to the accumulated bin counts.

        Args:
            input: the tensor whose values are binned.
            *args: ignored; accepted so all metrics can be updated with the same call.
            **kwargs: ignored.

        """
        hist = tf.histogram_fixed_width(input, self._limvals, self._nbins, dtype=tf.int64)
        self.hist.assign_add(hist)

    def result(self):
        """Return a dict with the bin counts (cast to int32) and the ``nbins + 1`` bin edges."""
        bins = tf.linspace(*self._limvals, self._nbins + 1)
        base_name, i = self.name.split(":")
        return {
            f"{base_name}/hist:{i}": tf.cast(self.hist, dtype=tf.int32),
            f"{base_name}/bin_edges:{i}": bins,
        }

    def reset_states(self):
        """Zero all bin counts."""
        # The reset value must match the int64 dtype of the ``hist`` variable.
        self.hist.assign(self.hist_initializer(shape=(self._nbins,), dtype=tf.int64))


class SampleMetric(tf.keras.metrics.Metric):
    """
    Keep a running sample of (native, numeric, channel) triplets taken from the observed tensors.

    The stored sample has ``sample_size`` uniformly drawn rows followed by four "special" rows that
    track, in this order: the largest native value, the largest numeric value, the smallest native
    value and the smallest numeric value seen so far.

    The "channel" column is the coordinate along axis 3 of the sampled element, so the tensors are
    expected to have at least four dimensions.

    The metric name must have the form ``<base_name>:<index>``; ``result`` splits it on ``":"``
    to build the output keys.

    Args:
        shape: full shape of the tensors passed to ``update_state``; used to draw and unravel the
            random flat indices, so every dimension must be known.
        name: the metric name, expected to contain exactly one ``":"``.
        sample_size: number of uniformly drawn rows kept in the sample.
        **kwargs: forwarded to ``tf.keras.metrics.Metric``.

    """

    _NUM_OF_SPECIAL_SAMPLES = 4

    # Not referenced inside this class.
    _GET_MAX_INDEX = 0
    _GET_MIN_INDEX = 1

    def __init__(self, shape, name="sample", sample_size=1000, **kwargs):
        super().__init__(name=name, **kwargs)
        self._shape = shape
        self._sample_size = sample_size
        self._final_sample_size = self._sample_size + self._NUM_OF_SPECIAL_SAMPLES
        self._size = np.prod(self._shape)
        self._range = tf.range(self._final_sample_size)

        self.count = self.add_weight(name="count", initializer="zeros", dtype=tf.float32)
        self.samples = self.add_weight(
            name="samples",
            initializer="zeros",
            shape=(self._final_sample_size, 3),
            dtype=tf.float32,
        )

    @staticmethod
    def _get_index(func, tensor):
        """
        Return the multi-dimensional index of the first element of ``tensor`` equal to ``func(tensor)``.

        tf.argmax/argmin is much slower on gpu compared to taking reduce_max/reduce_min, equal, and where operations.
        """
        value = func(tensor)
        equal_tensor = tf.math.equal(tensor, value)
        return tf.where(equal_tensor)[0]

    def update_state(self, native, numeric, **kwargs):
        """
        Merge a fresh sample of ``native`` / ``numeric`` into the stored sample.

        Args:
            native: the reference tensor, of shape ``shape``.
            numeric: the tensor compared against ``native``, of shape ``shape``.
            **kwargs: ignored; accepted so all metrics can be updated with the same call.

        """
        self.count.assign_add(1)

        # 1. Take random uniform sample size from the new batch
        random_indices = tf.random.uniform(shape=(self._sample_size,), minval=0, maxval=self._size, dtype=tf.int64)
        unravel_random_indices = tf.transpose(tf.unravel_index(random_indices, self._shape))

        index_max_native = self._get_index(tf.reduce_max, native)
        index_max_numeric = self._get_index(tf.reduce_max, numeric)
        index_min_native = self._get_index(tf.reduce_min, native)
        index_min_numeric = self._get_index(tf.reduce_min, numeric)

        # The four special rows are appended after the uniform rows, in the order documented on the class.
        unravel_indices = tf.concat(
            [unravel_random_indices, [index_max_native, index_max_numeric, index_min_native, index_min_numeric]],
            0,
        )

        sample_native = tf.gather_nd(native, unravel_indices)
        sample_numeric = tf.gather_nd(numeric, unravel_indices)
        sample_channels = tf.cast(tf.gather(unravel_indices, 3, axis=1), tf.float32)
        new_sample = tf.stack([sample_native, sample_numeric, sample_channels], axis=1)

        # 2. Combine both old sample and new sample with probability of (n - 1) / n, and 1 / n respectively
        switch_uniform_samples = tf.random.categorical(
            tf.math.log([[1 - 1 / self.count, 1 / self.count]]),
            self._sample_size,
            dtype=tf.int32,
        )[0]

        # On the first update after construction or reset the stored special rows are not real
        # observations (they are the zero initial values, or leftovers from before the reset), so they
        # must always be replaced. Otherwise e.g. an all-negative tensor could never displace the
        # initial zero "maximum".
        first_update = tf.math.less_equal(self.count, 1.0)

        switch_max_native = tf.math.logical_or(
            first_update,
            tf.math.greater(new_sample[self._sample_size, 0], self.samples[self._sample_size, 0]),
        )
        switch_max_numeric = tf.math.logical_or(
            first_update,
            tf.math.greater(new_sample[self._sample_size + 1, 1], self.samples[self._sample_size + 1, 1]),
        )
        switch_min_native = tf.math.logical_or(
            first_update,
            tf.math.less(new_sample[self._sample_size + 2, 0], self.samples[self._sample_size + 2, 0]),
        )
        switch_min_numeric = tf.math.logical_or(
            first_update,
            tf.math.less(new_sample[self._sample_size + 3, 1], self.samples[self._sample_size + 3, 1]),
        )
        switch_special_samples = tf.cast(
            [switch_max_native, switch_max_numeric, switch_min_native, switch_min_numeric],
            tf.int32,
        )

        # switch[r] == 0 keeps the stored row r, switch[r] == 1 takes row r of the new sample.
        switch = tf.concat([switch_uniform_samples, switch_special_samples], 0)

        stack_sample = tf.stack([self.samples, new_sample], axis=1)
        stack_switch = tf.stack([self._range, switch], axis=1)
        self.samples.assign(tf.gather_nd(stack_sample, stack_switch))

    def result(self):
        """Return a dict with the native, numeric and channel columns of the stored sample."""
        sample_native, sample_numeric, sample_channels = tf.unstack(self.samples, axis=1)
        base_name, i = self.name.split(":")
        return {
            f"{base_name}/native:{i}": sample_native,
            f"{base_name}/numeric:{i}": sample_numeric,
            f"{base_name}/channel:{i}": sample_channels,
        }

    def reset_states(self):
        """
        Reset the update counter.

        The stored rows are not cleared here; the first ``update_state`` call after a reset replaces
        all of them.
        """
        self.count.assign(0)


class SparsityMetric(tf.keras.metrics.Metric):
    """
    Track the average fraction of zero elements in the native and numeric tensors.

    The metric name must have the form ``<base_name>:<index>``; ``result`` splits it on ``":"``
    to build the output keys.

    Args:
        shape: shape of the tensors passed to ``update_state``; its product is used as the element count.
        name: the metric name, expected to contain exactly one ``":"``.
        **kwargs: forwarded to ``tf.keras.metrics.Metric``.

    """

    def __init__(self, shape, name="sparsity", **kwargs):
        super().__init__(name=name, **kwargs)
        self._size = np.prod(shape)

        self.native_zero_count = self.add_weight(name="native_zero_count", initializer="zeros", dtype=tf.float64)
        self.numeric_zero_count = self.add_weight(name="numeric_zero_count", initializer="zeros", dtype=tf.float64)
        self.count = self.add_weight(name="count", initializer="zeros", dtype=tf.float64)

    def _zero_fraction(self, tensor):
        """Return one minus the ratio between the non-zero element count of ``tensor`` and the element count."""
        return 1 - tf.math.count_nonzero(tensor) / self._size

    def update_state(self, native, numeric, **kwargs):
        """
        Add the zero fractions of ``native`` and ``numeric`` to the accumulators.

        Args:
            native: the reference tensor.
            numeric: the tensor compared against ``native``.
            **kwargs: ignored; accepted so all metrics can be updated with the same call.

        """
        self.count.assign_add(1)
        self.native_zero_count.assign_add(self._zero_fraction(native))
        self.numeric_zero_count.assign_add(self._zero_fraction(numeric))

    def result(self):
        """Return a dict with the mean zero fraction of the native and numeric tensors (0.0 before any update)."""
        base_name, i = self.name.split(":")
        native_key = f"{base_name}/native:{i}"
        numeric_key = f"{base_name}/numeric:{i}"
        # Nothing was accumulated yet: report zeros instead of dividing by zero.
        if self.count == 0:
            return {native_key: 0.0, numeric_key: 0.0}
        return {
            native_key: self.native_zero_count / self.count,
            numeric_key: self.numeric_zero_count / self.count,
        }

    def reset_states(self):
        """Zero the update counter and both accumulators."""
        for accumulator in (self.count, self.native_zero_count, self.numeric_zero_count):
            accumulator.assign(0)


class QuantizedLimvalsMetric(tf.keras.metrics.Metric):
    """
    Track the running minimum and maximum of the numeric tensor in the quantized domain.

    The quantized value is ``round(zero_point + numeric / scale)``; the minimum and maximum are taken
    over all the elements of the tensor.

    Args:
        scale: the quantization scale; cast to float32.
        zero_point: the quantization zero point; cast to float32.
        shape: accepted for interface compatibility; not used by this metric.
        name: the metric name.
        **kwargs: forwarded to ``tf.keras.metrics.Metric``.

    """

    def __init__(self, scale, zero_point, shape, name="quantized_limvals", **kwargs):
        super().__init__(name=name, **kwargs)
        self._scale = tf.cast(scale, dtype=tf.float32)
        self._zero_point = tf.cast(zero_point, dtype=tf.float32)

        # The running maximum starts at the lowest float32 value and the running minimum at the highest,
        # so the first update always overrides both.
        self.max_initializer = tf.keras.initializers.Constant(tf.float32.min)
        self.min_initializer = tf.keras.initializers.Constant(tf.float32.max)

        self.max = self.add_weight(name="max", initializer=self.max_initializer, dtype=tf.float32)
        self.min = self.add_weight(name="min", initializer=self.min_initializer, dtype=tf.float32)

    def update_state(self, native, numeric, **kwargs):
        """
        Update the running limits with the quantized values of ``numeric``.

        Args:
            native: ignored; accepted so all metrics can be updated with the same call.
            numeric: the tensor that is mapped to the quantized domain.
            **kwargs: ignored.

        """
        quant_numeric = tf.round(self._zero_point + numeric / self._scale)
        upper_lim = tf.math.reduce_max(quant_numeric)
        lower_lim = tf.math.reduce_min(quant_numeric)
        self.max.assign(tf.math.maximum(self.max, upper_lim))
        self.min.assign(tf.math.minimum(self.min, lower_lim))

    def result(self):
        """Return a tensor holding ``[min, max]`` of the quantized values seen so far."""
        # cannot concat KerasVariable, casting to tf.Variable
        return tf.stack([tf.Variable(self.min), tf.Variable(self.max)])

    def reset_states(self):
        """Restore the running limits to their initial values."""
        # Keras initializers must be called with a shape; both limits are scalars.
        self.max.assign(self.max_initializer(shape=(), dtype=tf.float32))
        self.min.assign(self.min_initializer(shape=(), dtype=tf.float32))


class QuantHistogramMetric(HistogramMetric):
    """
    Histogram of the numeric tensor in the quantized domain.

    The histogram range is ``(0, 2**bits)``. Up to ``BIT_LIMIT`` bits there are ``2**bits`` bins;
    above it the number of bins is capped at ``2**BIT_LIMIT``.

    Args:
        scale: the quantization scale; cast to float32.
        zero_point: the quantization zero point; cast to float32.
        bits: number of bits of the quantized values.
        name: the metric name, expected to contain exactly one ``":"``.
        **kwargs: forwarded to ``HistogramMetric``.

    """

    BIT_LIMIT = 8

    def __init__(self, scale, zero_point, bits, name="quantized_histogram", **kwargs):
        levels = 2**bits
        # With more than BIT_LIMIT bits the bin count is capped below, so bins no longer map to
        # individual quantized values.
        self._reduce_resolution = bits > self.BIT_LIMIT
        super().__init__((0, levels), name=name, nbins=min(levels, 2**self.BIT_LIMIT), **kwargs)
        self._scale = tf.cast(scale, dtype=tf.float32)
        self._zero_point = tf.cast(zero_point, dtype=tf.float32)

    def update_state(self, native, numeric, **kwargs):
        """
        Add the histogram of the quantized ``numeric`` values to the accumulated bin counts.

        Args:
            native: ignored; accepted so all metrics can be updated with the same call.
            numeric: the tensor that is mapped to the quantized domain and binned.
            **kwargs: ignored.

        """
        super().update_state(tf.round(self._zero_point + numeric / self._scale))

    def result(self):
        """
        Return the parent's histogram dict with integer bin edges.

        When the resolution is not reduced, a ``unique_count`` entry holding the number of
        non-empty bins is added.
        """
        res = super().result()
        base_name, i = self.name.split(":")
        edges_key = f"{base_name}/bin_edges:{i}"
        res[edges_key] = tf.cast(res[edges_key], tf.int32)
        if self._reduce_resolution:
            return res
        res[f"{base_name}/unique_count:{i}"] = tf.math.count_nonzero(res[f"{base_name}/hist:{i}"])
        return res


class LATModel(tf.keras.Model):
    """
    A wrapper for HailoModel that calculates both the numeric and the native result of each layer.
    This wrapper calculates the noise analysis metrics of each layer on the go.

    Args:
        base_model: the basic HailoModel to wrap; its layers are called with ``fully_native=False``.
        native_model: the HailoModel whose layers are called with ``fully_native=True``.

    """

    def __init__(self, base_model: HailoModel, native_model: HailoModel):
        super().__init__()
        self._model = base_model
        self._native_model = native_model
        # Maps "<layer_name>:<output_index>" to the list of metrics tracked for that output.
        self._metrics_classes = {}
        for lname, layer in self._model.layers.items():
            # No metrics are collected for these layer types.
            if isinstance(layer, (BaseHailoNonNNCoreLayer, HailoExternalPad, HailoSoftmaxMask, HailoSoftmaxMaskOnMac)):
                continue
            limvals = layer.get_original_output_limvals()
            shapes = layer.output_shapes
            scales = layer.output_scales
            zero_points = layer.output_zero_points
            bits = [elem.bits for elem in layer.get_output_lossy_elements()]
            for i in range(layer.num_outputs):
                self._metrics_classes[f"{lname}:{i}"] = [
                    SNRMetric(name=f"{lname}/snr:{i}"),
                    PartialSNRMetric(name=f"{lname}/partial_snr:{i}"),
                    HistogramMetric(limvals[i], name=f"{lname}/histogram:{i}"),
                    SampleMetric(shapes[i], name=f"{lname}/sample:{i}"),
                    SparsityMetric(shapes[i], name=f"{lname}/sparsity:{i}"),
                    QuantizedLimvalsMetric(scales[i], zero_points[i], shapes[i], name=f"{lname}/quantized_limvals:{i}"),
                    QuantHistogramMetric(scales[i], zero_points[i], bits[i], name=f"{lname}/quantized_histogram:{i}"),
                ]

    @staticmethod
    def _get_build_inputs(input_data):
        """
        Return the input shape(s) of the first element of ``input_data``.

        The first item of the first dataset element is taken as the model input. A dict input yields
        a dict of shapes keyed like the input; any other input yields its ``shape``.
        """
        dataset_sample = next(iter(input_data))[0]
        if isinstance(dataset_sample, dict):
            # If the dataset is a dict, we need to extract the shape of each input
            build_inputs = {k: v.shape for k, v in dataset_sample.items()}
        else:
            build_inputs = dataset_sample.shape

        return build_inputs

    def build(self, input_data):
        """Build the native and the wrapped models (when not built yet) from a sample of ``input_data``."""
        # build native inner model to avoid xla errors
        build_inputs = self._get_build_inputs(input_data)
        if not self._native_model.built:
            self._native_model.build(build_inputs)
        if not self._model.built:
            self._model.build(build_inputs)

    def compile(self, **kwargs):
        """
        Compile both inner models, enabling jit compilation only when both of them support it.

        ``kwargs`` are accepted but not used.
        """
        jit_compile = self._model.is_jit_compile_supported() and self._native_model.is_jit_compile_supported()
        self._model.compile(jit_compile=jit_compile)
        self._native_model.compile(jit_compile=jit_compile)

    def metrics_result(self):
        """
        return a dict with the metrics results for each output of the model's layers.

        The dict's keys are in the following format: <layer_name>/<metric_name>:<output_index>

        The model metrics are:

        snr: The signal to noise ratio of the specific layer calculated on the output of the layer when all the layers
            are quantized.
        snr_info/signal_energy: The mean energy of the native signal.
        snr_info/noise_energy: The mean energy of the noise.

        partial_snr: The signal to noise ratio of the specific layer calculated on the output of the layer when only
            this layer is quantized.
        partial_snr_info/signal_energy: The mean energy of the native signal.
        partial_snr_info/noise_energy: The mean energy of the noise.

        histogram/hist: A full-precision histogram.
        histogram/bin_edges: The bin edges of the full-precision histogram.

        sample/native: A uniform distributed samples of the full-precision output.
        sample/numeric: The corresponding numeric value of the above samples.
        sample/channel: The corresponding channel index of the above samples.

        sparsity/native: The relative fraction of the number of zeros to the output size in the full-precision output.
        sparsity/numeric: The relative fraction of the number of zeros to the output size in the numeric output.

        quantized_limvals: The minimum and maximum values of the numeric output in the quantized domain.

        quantized_histogram/hist: A histogram of the output in the quantized domain.
        quantized_histogram/bin_edges: The bin edges of the quantized histogram.
        quantized_histogram/unique_count: The number of unique values of the numeric output.
        """
        res = {}
        for metrics in self._metrics_classes.values():
            for metric in metrics:
                metric_result = metric.result()
                if isinstance(metric_result, dict):
                    # Tensors are converted to numpy; any other value is stored as is.
                    for k, v in metric_result.items():
                        res[k] = v.numpy() if isinstance(v, tf.Tensor) else v
                else:
                    res[metric.name] = metric_result.numpy()
        return res

    @staticmethod
    def _lat_run_layer(model, lname, inputs, inferred, fully_native):
        """
        Call layer ``lname`` of ``model`` and record its outputs in ``inferred``.

        Returns the layer, the inputs it was called with and the list of its outputs.
        """
        layer = model.layers[lname]
        layer_inputs = model._get_layer_inputs(lname, inputs, inferred)
        outputs = layer(layer_inputs, fully_native=fully_native)
        # Single-output layers return a bare value; normalize to a list.
        outputs = [outputs] if layer.num_outputs == 1 else outputs
        inferred[str(lname)] = dict(enumerate(outputs))
        return layer, layer_inputs, outputs

    def call(self, inputs):
        """
        Run the native and the numeric flows side by side and update the metrics of every common layer.

        A layer is "common" when it exists in both models with the exact same type. For each common
        layer, in the topological order of the wrapped model, the not-yet-inferred ancestors are run
        first in each model, then the layer itself is run natively, numerically, and numerically on
        the native inputs (the "partial" output), and its metrics are updated.

        Returns:
            A dict with the model outputs under the keys "native" and "numeric".

        """
        inputs = self._model.inputs_as_dict(inputs)
        if self._model.preproc_cb is not None:
            inputs = self._model.preproc_cb(inputs)
        inferred_native = {}
        inferred_numeric = {}

        quant_layers_sorted = list(self._model.flow.toposort())
        native_layers_sorted = list(self._native_model.flow.toposort())
        # Position of each layer in its own topological order, used to sort the missing ancestors.
        quant_position = {name: index for index, name in enumerate(quant_layers_sorted)}
        native_position = {name: index for index, name in enumerate(native_layers_sorted)}

        for lname in quant_layers_sorted:
            nat_type = type(self._native_model.layers.get(lname))
            quant_type = type(self._model.layers[lname])
            if not (lname in self._native_model.layers and nat_type is quant_type):
                continue

            n_ancestors = self._native_model.flow.ancestors(lname)
            q_ancestors = self._model.flow.ancestors(lname)

            # Run missing native layers
            nblock = sorted(n_ancestors - (inferred_native.keys()), key=native_position.__getitem__)
            for missing_layer in nblock:
                self._lat_run_layer(self._native_model, missing_layer, inputs, inferred_native, fully_native=True)

            # Run missing quant layers
            qblock = sorted(q_ancestors - (inferred_numeric.keys()), key=quant_position.__getitem__)
            for missing_layer in qblock:
                self._lat_run_layer(self._model, missing_layer, inputs, inferred_numeric, fully_native=False)

            # Run current native layer
            _, current_inputs_native, output_native = self._lat_run_layer(
                self._native_model,
                lname,
                inputs,
                inferred_native,
                fully_native=True,
            )

            # Run current quant layer
            quant_layer, _, output_numeric = self._lat_run_layer(
                self._model,
                lname,
                inputs,
                inferred_numeric,
                fully_native=False,
            )

            # Run current quant layer with native input for noise analysis
            output_partial_numeric = quant_layer(current_inputs_native, fully_native=False)
            output_partial_numeric = (
                [output_partial_numeric] if quant_layer.num_outputs == 1 else output_partial_numeric
            )

            # log metrics
            for i, (native, numeric, partial_numeric) in enumerate(
                zip(output_native, output_numeric, output_partial_numeric),
            ):
                # Layers skipped at construction have no metrics registered.
                metrics = self._metrics_classes.get(f"{lname}:{i}", [])
                for metric in metrics:
                    metric.update_state(native, numeric, partial_numeric=partial_numeric)

        outputs_native = self._model._outputs_dict_to_list(inferred_native)
        outputs_numeric = self._model._outputs_dict_to_list(inferred_numeric)
        outputs_native = outputs_native[0] if len(outputs_native) == 1 else outputs_native
        outputs_numeric = outputs_numeric[0] if len(outputs_numeric) == 1 else outputs_numeric

        # The post-processing callback belongs to the wrapped model, not to this wrapper.
        if self._model._postproc_cb is not None:
            outputs_native = self._model._postproc_cb(outputs_native)
            outputs_numeric = self._model._postproc_cb(outputs_numeric)

        return {
            "native": outputs_native,
            "numeric": outputs_numeric,
        }
