import copy
from abc import ABC, abstractmethod

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.lossy_elements.clip_element import ClipElement
from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.statistics import stats_collector
from hailo_model_optimization.acceleras.utils.acceleras_definitions import DEFAULT_OPTIMIZATION_TARGET
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import (
    AccelerasImplementationError,
    AccelerasValueError,
)
from hailo_model_optimization.acceleras.utils.dataset_util import rebuild_dataset_v2
from hailo_model_optimization.acceleras.utils.stats_export import extract_stats_sdk_format
from hailo_sdk_common.logger.logger import default_logger
from hailo_sdk_common.model_params.model_params import ModelParams
from hailo_sdk_common.targets.inference_targets import EmulationInferenceTargets


class GraphBuildException(Exception):
    """Raised by the graph wrappers, e.g. on a dataset/input shape mismatch or when a target has no params."""


class GraphWrapper(ABC):
    """
    Base class of the graph wrappers used for statistics collection.

    Keeps the graph exports per target name and the batched calibration dataset, together with the
    number of batches that dataset yields.
    """

    def __init__(self):
        # target name -> graph export registered by the concrete wrapper
        self._graphs = {}
        # batched calibration dataset and its number of batches, both set by set_calib_dataset
        self._dataset = None
        self._batch_count = None

    @abstractmethod
    def init_graph(self, target, use_gpu, **kwargs):
        """
        Build the graph for the given target.

        It behaves slightly differently between the two wrappers, but as long as graph_wrapper is only
        used for stats collection it should be fine.
        """

    def set_calib_dataset(self, dataset, hn_model, calibset_size, batch_size):
        """
        Store the batched calibration dataset and validate its signature against the network inputs.

        Args:
            dataset: dataset object supporting ``take`` and ``batch``; the ``element_spec`` of the
                batched result is read by validate_shapes.
            hn_model: model used to look up the input layers for the shape validation.
            calibset_size: number of dataset elements to take for calibration.
            batch_size: number of elements per batch.

        Raises:
            GraphBuildException: if a dataset input shape doesn't match the layer's input shape.

        """
        self._dataset = dataset.take(calibset_size).batch(batch_size)
        # ``batch`` is called without dropping the remainder, so a trailing partial batch is still
        # yielded; use ceiling division so it is counted (and so a calibration set smaller than one
        # batch counts as one batch instead of zero).
        self._batch_count = -(-calibset_size // batch_size)
        self.validate_shapes(hn_model)

    def get_calib_dataset(self):
        """Return the batched calibration dataset (None until set_calib_dataset is called)."""
        return self._dataset

    def validate_shapes(self, hn_model):
        """
        Validate every input signature of the calibration dataset against its network input layer.

        The first component of the dataset's ``element_spec`` is taken as the input data. When it
        isn't a dict, it is attributed to the first input layer of ``hn_model``.
        """
        input_data = self._dataset.element_spec[0]
        if not isinstance(input_data, dict):
            layer_name = hn_model.get_input_layers()[0].name
            input_data = {layer_name: input_data}
        for layer_name, data_spec in input_data.items():
            self._validate_single_shape(hn_model.get_layer_by_name(layer_name), data_spec)

    @staticmethod
    def _validate_single_shape(layer, input_data):
        """
        Compare the shape of a single dataset input with the input shape of its layer.

        The leading dimension is ignored on both sides. Validation is skipped with a warning when the
        dataset signature has no shape, or has unknown (None) dimensions but the same rank.

        Raises:
            GraphBuildException: if the shapes differ.

        """
        expected_shape = tuple(layer.input_shape[1:])
        if layer.transposed and len(expected_shape) == 3:
            # transposed layers expect the first two remaining dimensions swapped
            expected_shape = (expected_shape[1], expected_shape[0], expected_shape[2])
        if not input_data.shape:
            default_logger().warning("Dataset signature has missing information, skipping shape validation")
            return
        data_shape = tuple(input_data.shape[1:])
        if expected_shape == data_shape:
            return
        if None in data_shape and len(expected_shape) == len(data_shape):
            default_logger().warning("Dataset signature has missing information, skipping shape validation")
            return
        raise GraphBuildException(
            f"Data shape {data_shape} for layer {layer.name} doesn't match network's input shape {expected_shape}",
        )

    @abstractmethod
    def update_graph_params(self, target_name, params):
        """
        Update the params. This method assigns the new values.

        It behaves slightly differently between the two wrappers, but as long as graph_wrapper is only
        used for stats collection it should be fine:
           - the tf_model wrapper updates the graph itself
           - the acceleras wrapper only updates the stored params; they are applied in collect_stats

        """

    def get_graph(self, target_name):
        """Return the graph export registered for ``target_name`` (KeyError if there is none)."""
        return self._graphs[target_name]

    @abstractmethod
    def collect_stats(self, target_name, layers_to_clip=None, double_stream=True, run_eagerly=False):
        """
        Collect results of all the layers in the network. This method accumulates native
        statistics over a number of batches.

        Returns:
            dictionary of calibration statistics per layer

        Args:
            target_name: the target name we want to collect stats on
            layers_to_clip: a list of layers whose activations we want to clip
            double_stream: True for (data, image_info) tf.dataset sources.
              if False, will assume a "single-stream" / "just-data" input, e.g. numpy array.
            run_eagerly: indicates if we need to run eagerly

        """

    @staticmethod
    def get_results_by_layer(calibration_stats_tensors, inference_results, prev_result_by_layer=None):
        """Placeholder with no implementation in the base class; it returns None."""


class GraphWrapperTfModel(GraphWrapper):
    """Graph wrapper that collects statistics by running session-based (``tf.compat.v1``) graph exports."""

    def __init__(self, get_tf_graph_callback, logger=None):
        """
        Args:
            get_tf_graph_callback: callable used by init_graph to build the graph export of a target.
            logger: optional logger; the default logger is used when omitted.

        """
        super().__init__()
        # target name -> initializer op of the dataset iterator feeding that target's graph
        self._initializers = {}
        self._get_tf_graph = get_tf_graph_callback
        self._logger = logger or default_logger()

    def init_graph(self, target, use_gpu, **kwargs):
        """
        Build the graph export of ``target`` in a new session, fed by the calibration dataset.

        The export and the iterator initializer are stored under ``target.name``, overriding any
        previous entry. Extra keyword arguments are forwarded to the graph callback.
        """
        # With use_gpu disabled the session is configured with a GPU device count of 0.
        sess_config = None if use_gpu else tf.compat.v1.ConfigProto(device_count={"GPU": 0})
        session = tf.compat.v1.Session(config=sess_config, graph=tf.Graph())
        with session.as_default(), session.graph.as_default():
            dataset_v1 = rebuild_dataset_v2(self._dataset)
            iterator = tf.compat.v1.data.make_initializable_iterator(dataset_v1)
            # only the first component of each dataset element is fed to the graph
            input_data, _ = iterator.get_next()
            graph_export, _, _ = self._get_tf_graph(target, custom_session=session, nodes=input_data, **kwargs)
        if target.name in self._graphs:
            self._logger.debug(f"Overriding {target.name} graph")
        self._graphs[target.name] = graph_export
        self._initializers[target.name] = iterator.initializer

    def update_graph_params(self, target_name, params):
        """
        Update the new graph parameters. This method assigns the new values of all the tensors in the graph.

        Raises:
            ValueError: if more than one variable matches a param name.

        """
        # TODO: Is this update params good enough or should load_params be mimicked here?
        #  We can use TFModel's load_params, maybe this is the best solution?
        hailo_export = self._graphs[target_name]
        with hailo_export.session.graph.as_default(), hailo_export.session.as_default():
            assign_ops = []
            for tensor_name, value in params.items():
                if tensor_name == ModelParams.PARAMS_KIND_STR:
                    continue
                if any(const_name in tensor_name.split("/")[-1] for const_name in ModelParams.CONSTS_NAMES):
                    # ignoring constant values that shouldn't be updated after equalization - based on the
                    # assumption that consts don't actually change during equalization
                    continue
                var = tf.compat.v1.global_variables(tensor_name)
                if len(var) > 1:
                    raise ValueError("Multiple tensors with the same name")
                # NOTE(review): an unmatched name raises IndexError here; callers presumably check
                # has_variable first - confirm.
                tensor = var[0]
                assign_ops.append(tensor.assign(value))
            hailo_export.session.run(assign_ops)

    def get_graph(self, target_name):
        """Return the graph export registered for ``target_name`` (KeyError if there is none)."""
        return self._graphs[target_name]

    def get_initializer(self, target_name):
        """Return the dataset iterator initializer stored for ``target_name``."""
        return self._initializers[target_name]

    def has_variable(self, target_name, key):
        """Return True when the graph of ``target_name`` has at least one global variable matching ``key``."""
        with self._graphs[target_name].graph.as_default():
            var = tf.compat.v1.global_variables(key)
        return len(var) != 0

    def collect_stats(self, target_name, layers_to_clip=None, double_stream=True, run_eagerly=False):
        """
        Re-initialize the dataset iterator of the target and collect its statistics over all batches.

        When ``layers_to_clip`` is a non-empty list the activation histograms of the export are
        collected, otherwise its calibration statistics. ``double_stream`` and ``run_eagerly`` are
        not used by this wrapper.

        Returns:
            dictionary of calibration statistics per tensor name

        """
        hailo_export = self._graphs[target_name]
        # not None or list with 1 or more items
        custom_tensors = hailo_export.activations_histograms if layers_to_clip else None
        with hailo_export.session.as_default(), hailo_export.graph.as_default():
            hailo_export.session.run(self._initializers[target_name])
        return self._collect_stats(hailo_export, self._batch_count, custom_tensors)

    @classmethod
    def _collect_stats(cls, model, calib_num_batch, custom_tensor_list=None):
        """
        Collect native results of all the layers in the network. This method accumulates native
        statistics over a number of batches.

        Args:
            model: graph export providing ``session``, ``graph`` and ``calibration_stats``.
            calib_num_batch: number of batches to run.
            custom_tensor_list: tensors to fetch instead of ``model.calibration_stats``.

        Returns:
            dictionary of calibration statistics per layer

        Raises:
            ValueError: if no batch was processed (``calib_num_batch`` is zero or negative).

        """
        tensor_list = custom_tensor_list if custom_tensor_list else model.calibration_stats
        sess = model.session
        results_by_layer = None
        for _ in range(calib_num_batch):
            with model.graph.as_default(), sess.as_default():
                native_results = sess.run(tensor_list)

            results_by_layer = cls.get_results_by_layer(
                calibration_stats_tensors=tensor_list,
                inference_results=native_results,
                prev_result_by_layer=results_by_layer,
            )

        if results_by_layer is None:
            # Without this guard the loop below would fail with an opaque TypeError on None.
            raise ValueError(
                f"No calibration batches were processed (calib_num_batch={calib_num_batch}); "
                "the calibration set must provide at least one batch",
            )

        # These statistics are summed across batches by get_results_by_layer; turn them into a mean.
        averaged_markers = ("_energy_features", "stats_non_zero_percent_features")
        for key in results_by_layer:
            if any(marker in key for marker in averaged_markers):
                results_by_layer[key] /= calib_num_batch
        return results_by_layer

    @staticmethod
    def get_results_by_layer(calibration_stats_tensors, inference_results, prev_result_by_layer=None):
        """
        Prepare model statistics for
        :func:`~hailo_sdk_client.runner.client_runner.ClientRunner.translate_params`.

        Args:
            calibration_stats_tensors (list of :obj:`tf.Tensor`): List of tensors requested by the
                SDK for statistics gathering. This list can be obtained by calling
                :func:`~hailo_sdk_client.runner.client_runner.ClientRunner.get_tf_graph` and
                accessing the required tensor list via the
                :attr:`~hailo_sdk_common.export.hailo_graph_export.HailoGraphExport.calibration_stats`
                property.
            inference_results (list of :obj:`numpy.ndarray`): List of inference results
                corresponding to the ``calibration_stats_tensors`` given.
            prev_result_by_layer (dict): A previous return value of this function. If used, the
                statistics will be based on both previous and current input batches.

        Returns:
            dict: A dict where the keys are the tensor names and the values are the results.

        """
        tensor_names = [x.name for x in calibration_stats_tensors]
        # TODO: SDK-10099
        res_by_layer = dict(zip(tensor_names, inference_results))

        if prev_result_by_layer is None:
            return res_by_layer

        # Merge with the previous batches: element-wise max/min for the max/min statistics and a
        # running sum for histograms and the statistics that are averaged later.
        for key in res_by_layer:
            if "/stats_max_" in key:
                res_by_layer[key] = np.maximum(res_by_layer[key], prev_result_by_layer[key])
            elif "/stats_min_" in key:
                res_by_layer[key] = np.minimum(res_by_layer[key], prev_result_by_layer[key])
            elif "/activation_hist" in key or "_energy_features" in key or "stats_non_zero_percent_features" in key:
                res_by_layer[key] += prev_result_by_layer[key]
            else:
                default_logger().warning(f"the {key} is not updated")

        return res_by_layer


class GraphWrapperAcceleras(GraphWrapper):
    """Graph wrapper that collects statistics with an Acceleras ``HailoModel`` built from the HN and params."""

    def __init__(self, hn_dict, params, model_config, get_tf_graph_callback, logger=None):
        """
        Store the model description and try to build the Acceleras model for the FP-optimized target.

        Args:
            hn_dict: HN description passed to ``HailoModel``.
            params: params of the FP-optimized target; a deep copy is kept.
            model_config: configuration imported into every model that is built.
            get_tf_graph_callback: callable used by init_graph to build the graph export of a target.
            logger: optional logger; the default logger is used when omitted.

        """
        super().__init__()
        fp_target = EmulationInferenceTargets.SDK_FP_OPTIMIZED
        self._accerleras_build = True
        self._hn_dict = hn_dict
        self._model_config = model_config
        # target name -> params used when the model of that target is built
        self._params_by_target = {fp_target: copy.deepcopy(params)}
        self._logger = logger or default_logger()
        try:
            self._build_model(fp_target)
        except AccelerasImplementationError as err:
            # not fatal: the failure is only logged and reported through accerleras_build
            self._logger.debug(str(err))
            self._accerleras_build = False
        self._get_tf_graph = get_tf_graph_callback

    @property
    def accerleras_build(self):
        """False when building the model in the constructor raised AccelerasImplementationError."""
        return self._accerleras_build

    def _build_model(self, target_name):
        """
        Build a new ``HailoModel`` with the params of ``target_name`` and keep it as the last model.

        Raises:
            GraphBuildException: if there are no params stored for ``target_name``.

        """
        if target_name not in self._params_by_target:
            raise GraphBuildException(f"cant build model {target_name} with out params")
        model = HailoModel(self._hn_dict, optimization_target=DEFAULT_OPTIMIZATION_TARGET)
        # registered before configuring, so the attribute points at the new model from here on
        self._last_acceleras_model = model
        model.import_config(self._model_config)
        model.build_with_params(self._params_by_target[target_name])

    def init_graph(self, target, use_gpu, **kwargs):
        """Build only the graph export; kept until the conv inference of the graph export isn't needed."""
        self._build_graph_export_only_for_now(target, use_gpu, **kwargs)

    def update_graph_params(self, target_name, params):
        """Update the stored params of the target; they are applied by collect_stats."""
        stored_params = self._params_by_target[target_name]
        stored_params.update(params)

    def _get_last_model(self, target_name, layers_to_clip):
        """
        Make sure the last model fits the request.

        Without clipped layers the model is rebuilt. With clipped layers the last model that was
        already built is kept, since the stats are on it.
        """
        if layers_to_clip:
            return
        self._build_model(target_name)

    def collect_stats(self, target_name, layers_to_clip=None, double_stream=True, run_eagerly=False):
        """
        Collect the statistics of the target over the calibration dataset.

        The model is rebuilt unless ``layers_to_clip`` is given, the clipping values found in the
        target's params are applied, and the stats collector is run on the calibration dataset.

        Returns:
            the statistics extracted from the model by ``extract_stats_sdk_format``

        """
        self._get_last_model(target_name, layers_to_clip)
        self._set_clipping(target_name)
        model = self._last_acceleras_model
        # NOTE(review): an empty list rebuilds the model above but still enables histogram_ctx - confirm.
        use_histograms = layers_to_clip is not None
        stats_collector.collect_stats(
            model,
            self._dataset,
            double_stream=double_stream,
            layers_to_handle=layers_to_clip,
            run_eagerly=run_eagerly,
            histogram_ctx=use_histograms,
        )
        return extract_stats_sdk_format(self._last_acceleras_model, layers_to_clip)

    def _build_graph_export_only_for_now(self, target, use_gpu, **kwargs):
        """Build the graph export of ``target`` in a new session and store it under ``target.name``."""
        # TODO: We need an export graph for the self.con_inference
        if use_gpu:
            sess_config = None
        else:
            sess_config = tf.compat.v1.ConfigProto(device_count={"GPU": 0})
        session = tf.compat.v1.Session(config=sess_config, graph=tf.Graph())
        with session.as_default(), session.graph.as_default():
            graph_export, _, _ = self._get_tf_graph(target, custom_session=session, **kwargs)
        if target.name in self._graphs:
            self._logger.debug(f"Overriding {target.name} graph")
        self._graphs[target.name] = graph_export

    def _set_clipping(self, target_name):
        """Apply the activation clipping and then the weight clipping of the target's params."""
        self._set_activation_clipping(target_name)
        self._set_weight_clipping(target_name)

    @staticmethod
    def _layer_name_from_key(key):
        """Return the layer name of a param key: its first two '/'-separated parts when it has three or more."""
        return "/".join(key.split("/", 2)[:-1])

    def _set_activation_clipping(self, target_name):
        """Set an enabled clip element on the output ops of every layer with activation clipping values."""
        params = self._params_by_target[target_name]
        if params is None:
            params = {}

        model = self._last_acceleras_model
        selected = [(name, limits) for name, limits in params.items() if "activation_clipping_values" in name]
        for param_key, limits in selected:
            layer = model.layers[self._layer_name_from_key(param_key)]
            # Changing numeric configuration clip only on the output_op
            clip_element = self.get_activation_clip_element(limits)
            clip_element.enable()
            for output_op, _ in layer.iterate_output_ops():
                output_op.set_output_lossy_element(clip_element)

    def _set_weight_clipping(self, target_name):
        """
        Set an enabled clip element on the conv kernel of every layer with weight clipping values.

        Raises:
            AccelerasValueError: if a layer with weight clipping values has no ``conv_op``.

        """
        params = self._params_by_target[target_name]
        if params is None:
            params = {}
        model = self._last_acceleras_model
        selected = [(name, limits) for name, limits in params.items() if "weights_clipping_values" in name]
        for param_key, limits in selected:
            layer = model.layers[self._layer_name_from_key(param_key)]
            if not hasattr(layer, "conv_op"):
                raise AccelerasValueError(f"Can't clip weights {layer.name} because its not supported..")
            # Changing numeric configuration clip only..
            # depthwise convs get the new axis appended last, other convs at position 1
            new_axis = -1 if layer.conv_op.is_depthwise else 1
            expanded_limits = np.expand_dims(limits, new_axis)
            clip_element = self.get_weights_clip_element(expanded_limits)
            clip_element.enable()
            layer.conv_op.weight_lossy_elements.kernel = clip_element

    def has_variable(self, target_name, key):
        """Always return True (a hack in acceleras)."""
        return True

    @staticmethod
    def get_activation_clip_element(values):
        """Build the activation clip element from the first two entries of ``values``."""
        first, second = values[0], values[1]
        return ClipElement(first, second)

    @staticmethod
    def get_weights_clip_element(values):
        """Build the weights clip element from the first two entries of ``values``."""
        first, second = values[0], values[1]
        return ClipElement(first, second)
