import os

import tensorflow as tf
from tqdm import tqdm

from hailo_model_optimization.acceleras.hailo_layers.hailo_matmul import HailoMatmul
from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import (
    AccelerasImplementationError,
    AccelerasResourceError,
    InferenceError,
)
from hailo_model_optimization.acceleras.utils.distributed_utils import DistContextInfo
from hailo_model_optimization.acceleras.utils.flow_state.updater import FlowCommands, update_model_state
from hailo_model_optimization.acceleras.utils.logger import default_logger
from hailo_model_optimization.acceleras.utils.numpy_utils import to_numpy


def get_build_inputs(dataset):
    """Derive batch-size-1 build shapes from a dataset sample.

    Args:
        dataset: Either a dict mapping input names to array-like samples, or an
            iterable whose first element is a batch. A tuple/list element is
            treated as ``(data, ...)`` and only its first entry is used; any
            other element is used as the batch itself.

    Returns:
        For dict samples, a dict mapping each key to ``[1, *shape[1:]]``.
        Otherwise a tuple ``(1, *shape[1:])``.
    """
    if isinstance(dataset, dict):
        dataset_sample = dataset
    else:
        dataset_sample = next(iter(dataset))
        # Unwrap only (data, info)-style elements. Indexing a bare batch with [0]
        # would drop its batch axis (or raise KeyError for a dict batch).
        if isinstance(dataset_sample, (tuple, list)):
            dataset_sample = dataset_sample[0]

    if isinstance(dataset_sample, dict):
        # One shape per named input; the leading (batch) dimension is replaced by 1.
        return {name: [1, *value.shape[1:]] for name, value in dataset_sample.items()}

    return (1, *dataset_sample.shape[1:])


class TQDMCallback(tf.keras.callbacks.Callback):
    """
    A Keras callback that advances a tqdm progress bar after every predict batch.

    Args:
        data_count: total number of entries, passed to tqdm as ``total``
        batch_size: nominal number of entries processed per predict batch
    """

    def __init__(self, data_count, batch_size):
        super().__init__()
        self._pbar = tqdm(total=data_count, dynamic_ncols=True, unit="entries", desc="Inference")
        self._batch_size = batch_size

    def on_predict_batch_end(self, batch, logs=None):
        """Advance the bar by one batch, never going past the configured total."""
        step = self._batch_size
        total = self._pbar.total
        if total:
            # The last batch may be partial; clamp so the bar does not overshoot.
            step = max(0, min(step, total - self._pbar.n))
        self._pbar.update(step)

    def pbar_close(self):
        """Flush the final state of the bar and close it."""
        self._pbar.refresh()  # flush the final pbar state
        self._pbar.close()


class InferenceModel(tf.keras.Model):
    """
    Base Keras model for inference-only execution over a dataset.

    Subclasses are expected to provide ``call`` and ``is_jit_compile_supported``;
    the default ``build`` additionally relies on a ``_model`` attribute.
    """

    # Distributed-context information; this class never assigns it.
    dist_info: DistContextInfo = None

    def __init__(self):
        super().__init__()
        self.trainable = False

    @property
    def strategy(self):
        """The last strategy assigned through the setter, or ``None`` if none was assigned."""
        return getattr(self, "_strategy", None)

    @strategy.setter
    def strategy(self, value):
        # `_strategy` is not initialised by this class, so read it defensively:
        # the first assignment must not fail with AttributeError.
        if getattr(self, "_strategy", None) is not value:
            self._strategy = value
            self.strategy_changed = True
        else:
            self.strategy_changed = False

    @property
    def is_trainable(self):
        """Mirror of the ``trainable`` attribute."""
        return self.trainable

    def build(self, dataset):
        """
        build the wrapped ``_model`` from a sample of the given dataset, unless it is already built
        """
        if not self._model.built:
            self._model.build(get_build_inputs(dataset))

    def run(self, unbatched_dataset, batch_size, data_count):
        """
        Run inference over a dataset and return the concatenated outputs.

        Args:
            unbatched_dataset: dataset exposing ``take`` and ``batch``; each element is either the
                model input or a tuple/list whose first entry is the model input
            batch_size: batch size used for inference
            data_count: number of entries to process; a falsy value processes the whole dataset

        Returns:
            The per-batch outputs concatenated along axis 0 and passed through ``to_numpy``. A single
            tensor output stays single, a dict output stays a dict, any other sequence becomes a list.

        Raises:
            AccelerasResourceError: when TensorFlow reports resource exhaustion
            KeyboardInterrupt: when the user interrupts the inference loop
            InferenceError: on any other failure in the inference loop, or when no batch was processed
        """
        if data_count:
            unbatched_dataset = unbatched_dataset.take(data_count)
        dataset = unbatched_dataset.batch(batch_size)

        infer_outputs = []
        # The context manager closes the bar on every exit path, including the raises below.
        with tqdm(total=data_count, desc="Processed", unit="images", disable=None) as pbar:
            self.build(dataset)
            jit_compile = self.is_jit_compile_supported()
            self.compile(jit_compile=jit_compile)

            @tf.function(jit_compile=jit_compile, reduce_retracing=True)
            def predict_function(data):
                # call self since HWInferenceModel has no _model, works for all inference models
                return self(data, training=False)

            try:
                for item in dataset:
                    batch = item[0] if isinstance(item, (tuple, list)) else item
                    output_tensors = predict_function(batch)
                    infer_outputs.append(output_tensors)
                    # Outputs may be a tensor, list, tuple or dict; flatten so the batch size can be
                    # read from the first tensor whatever the structure is.
                    pbar.update(tf.nest.flatten(output_tensors)[0].shape[0])
                    pbar.refresh()  # flush the final pbar state

            except tf.errors.ResourceExhaustedError as err:
                raise AccelerasResourceError(
                    "GPU memory has been exhausted. "
                    "Please try to use lower batch size for inference or run on CPU.",
                ) from err
            except KeyboardInterrupt:
                raise KeyboardInterrupt("Inference interrupted by user, displaying partial results")
            except Exception as err:
                raise InferenceError(
                    f"An error occurred during inference: {err}. Please check the model and input data.",
                ) from err

        if not infer_outputs:
            raise InferenceError("No data was processed during inference. Please check the input dataset.")

        # Handle different output structures
        first_output = infer_outputs[0]
        if isinstance(first_output, tf.Tensor):
            concatenated_outputs = tf.concat(infer_outputs, axis=0)
        elif isinstance(first_output, dict):
            concatenated_outputs = {
                key: tf.concat([batch_out[key] for batch_out in infer_outputs], axis=0) for key in first_output
            }
        else:  # tuple / list of tensors
            concatenated_outputs = [
                tf.concat([batch_out[index] for batch_out in infer_outputs], axis=0)
                for index in range(len(first_output))
            ]

        return to_numpy(concatenated_outputs)

    def fit(self, *args, **kwargs):
        """Always raise: inference models cannot be trained, whatever arguments are given."""
        raise InferenceError("Training is not supported in inference model")


class HWInferenceModel(InferenceModel):
    """
    Keras-model facade over hardware (HW) inference.
    ``run`` must be invoked inside the appropriate HailoRT context.

    Args:
            hailo_export: (:class: `~hailo_sdk_common.export.hailo_graph_export.DummyGraphExport)
                hailo export wrapper for HW inference

    """

    def __init__(self, hailo_export):
        super().__init__()
        self._hw_infer_wrapper = hailo_export.hef_infer_wrapper
        self._output_types = hailo_export.output_types

    def _tf_function_hw_inference(self, input_tensors):
        """Invoke the wrapper's ``tf_infer`` via ``tf.numpy_function``; a single result is unwrapped."""
        results = tf.numpy_function(
            self._hw_infer_wrapper.tf_infer,
            input_tensors,
            self._output_types,
            name="infer_hw_py_func",
        )
        if len(results) == 1:
            return results[0]
        return results

    def build(self, inputs):
        """Intentionally a no-op: no build is required for HW inference."""

    def call(self, inputs):
        """Run HW inference on a single input, or on a dict of inputs ordered by the wrapper's input names."""
        if not isinstance(inputs, dict):
            return self._tf_function_hw_inference([inputs])

        wrapper = self._hw_infer_wrapper
        if wrapper.infer_model:
            ordered_names = wrapper.infer_model.input_names
        else:
            # remove this branch when async api on hailo8 is enabled SDK-51150
            ordered_names = wrapper._input_names

        ordered_inputs = [inputs[name] for name in ordered_names]
        return self._tf_function_hw_inference(ordered_inputs)

    def is_jit_compile_supported(self, training=False):
        """jit_compile is reported as unsupported for HW inference, regardless of ``training``."""
        return False


class SimulationInferenceModel(InferenceModel):
    """
    Inference wrapper around an acceleras model that controls how the model is emulated.

    The ``set_*`` methods toggle per-op flags (``fully_native``, lossy mode, ``bit_exact``) and the
    model's internal encoding. The private ``_set_*`` research helpers are selected through the
    ``TASK_TYPE`` environment variable when ``ONLY_FOR_RESEARCH`` is set (see ``set_quantized``).

    Args:
        model: Mutable, model for inference

    """

    # Layer names targeted by the "lossless softmax" research helpers: softmax1 .. softmax12.
    _RESEARCH_SOFTMAX_LAYERS = tuple(f"softmax{index}" for index in range(1, 13))

    def __init__(self, model: HailoModel):
        super().__init__()
        self._model = model

    @property
    def model(self):
        """The wrapped model."""
        return self._model

    def call(self, inputs, **kwargs):
        """Forward ``inputs`` (and any keyword arguments) to the wrapped model."""
        self.build(inputs)  # make sure the model is built
        return self._model(inputs, **kwargs)

    # ---- shared per-op state helpers -------------------------------------------------------

    @staticmethod
    def _mark_ops_native(ops):
        """Set ``fully_native`` and disable lossy mode on every op in ``ops``."""
        for op in ops:
            op.fully_native = True
            op.disable_lossy()

    @staticmethod
    def _mark_ops_lossy(ops):
        """Clear ``fully_native`` and enable lossy mode on every op in ``ops``."""
        for op in ops:
            op.fully_native = False
            op.enable_lossy()

    @staticmethod
    def _mark_ops_lossless(ops):
        """Clear ``fully_native`` and disable lossy mode on every op in ``ops``."""
        for op in ops:
            op.fully_native = False
            op.disable_lossy()

    # ---- public configuration --------------------------------------------------------------

    def _online_zp_compensation(self, online_zp_compensation=False):
        """
        Set ``online_zp_compensation`` on the matmul op of every ``HailoMatmul`` layer whose
        ``zp_comp_added`` flag is set (the default, False, means no online zp compensation)
        """
        for lname in self._model.flow.toposort():
            layer = self._model.layers[lname]
            if isinstance(layer, HailoMatmul) and layer.zp_comp_added:
                layer.matmul_op.online_zp_compensation = online_zp_compensation

    def _set_quantized(self):
        """Put every op of every layer in lossy (non-native) mode."""
        for layer in self._model.layers.values():
            self._mark_ops_lossy(layer.atomic_ops)

    def _set_internal_encoding(self, internal_encoding):
        """Enable or disable the model's internal encoding according to ``internal_encoding``."""
        if internal_encoding:
            self._model.enable_internal_encoding()
        else:
            self._model.disable_internal_encoding()

    def set_model_quantized(self, online_zp_compensation=False, internal_encoding=False):
        """Quantize all ops, then apply the internal-encoding and online zp-compensation settings."""
        self._set_quantized()
        self._set_internal_encoding(internal_encoding)
        self._online_zp_compensation(online_zp_compensation=online_zp_compensation)

    def set_native(self, value=True):
        """Put every op in native mode when ``value`` is truthy, otherwise in lossy mode."""
        mark_ops = self._mark_ops_native if value else self._mark_ops_lossy
        for layer in self._model.layers.values():
            mark_ops(layer.atomic_ops)

    def set_bit_exact(self, value=True):
        """
        Set the model's ``bit_exact`` flag and recompile with ``run_eagerly=value``.

        Raises:
            AccelerasImplementationError: when enabling on a model that reports no bit exact emulation support
        """
        if value and not self._model.bit_exact_emulation_supported:
            raise AccelerasImplementationError("Not all layers in the models support bit exact")
        if value:
            self._model.disable_internal_encoding()
        self._model.bit_exact = value
        # eager run is more accurate which is needed to bit exact emulation
        self.compile(run_eagerly=value)

    @property
    def bit_exact_supported(self) -> bool:
        """Forwarded from the wrapped model."""
        return self._model.bit_exact_supported

    @property
    def bit_exact_emulation_supported(self) -> bool:
        """Forwarded from the wrapped model."""
        return self._model.bit_exact_emulation_supported

    def summary(self, **kwargs):
        """Forward to the wrapped model's ``summary``."""
        return self._model.summary(**kwargs)

    def set_quantized(self, online_zp_compensation=False, internal_encoding=False):
        """
        Quantize the model. When the ``ONLY_FOR_RESEARCH`` environment variable is set to anything
        other than ``"false"``, the research flow (``set_quantized_norm_research``) is used instead.
        """
        run_norm = os.getenv("ONLY_FOR_RESEARCH", default="false")
        if run_norm != "false":
            self.set_quantized_norm_research(
                online_zp_compensation=online_zp_compensation,
                internal_encoding=internal_encoding,
            )
        else:
            self.set_model_quantized(online_zp_compensation=online_zp_compensation, internal_encoding=internal_encoding)

    def set_lossless(self, exclude_activation=False):
        """Enable internal encoding and switch the wrapped model to lossless mode."""
        self._model.enable_internal_encoding()
        self._model.set_lossless(native_act=exclude_activation)

    def custom_infer_config(self, flow_commands: FlowCommands):
        """Apply ``flow_commands`` to the wrapped model via ``update_model_state``."""
        update_model_state(self._model, flow_commands, default_logger())

    # ---- research helpers: layer normalization ---------------------------------------------

    def _set_native_only_norm_numeric(self):
        """Layer-normalization layers run lossy and bit exact; every other layer runs native."""
        for layer in self._model.layers.values():
            if "layer_normalization" in layer.full_name:
                for op in layer.atomic_ops:
                    op.fully_native = False
                    op.enable_lossy()
                    op.bit_exact = True
            else:
                self._mark_ops_native(layer.atomic_ops)
                print(layer.full_name, "native")

    def _set_quantized_only_norm_native(self):
        """Layer-normalization layers run native; every other layer runs lossy."""
        for layer in self._model.layers.values():
            if "layer_normalization" in layer.full_name:
                self._mark_ops_native(layer.atomic_ops)
                print(layer.full_name, "############################################ native")
            else:
                self._mark_ops_lossy(layer.atomic_ops)
                print(layer.full_name, "numeric")

    def _set_quantized_norm(self):
        """Every layer runs lossy; ops of layer-normalization layers are additionally bit exact."""
        for layer in self._model.layers.values():
            is_norm = "layer_normalization" in layer.full_name
            for op in layer.atomic_ops:
                op.fully_native = False
                op.enable_lossy()
                if is_norm:
                    op.bit_exact = True
            print(layer.full_name, "#########numeric")

    # ---- research helpers: softmax ---------------------------------------------------------

    def _softmax_lossy_with_native_ops(self, native_markers):
        """
        Softmax layers (by ``layer.name``) are set lossy at layer level, except ops whose name contains
        one of ``native_markers``, which are set native. Every other layer runs native.
        """
        for layer in self._model.layers.values():
            if "softmax" not in layer.name:
                self._mark_ops_native(layer.atomic_ops)
                continue
            layer.fully_native = False
            layer.enable_lossy()
            print(f"############################################ {layer.full_name} lossy ")
            for op in layer.atomic_ops:
                if any(marker in op.name for marker in native_markers):
                    op.fully_native = True
                    op.disable_lossy()
                    print(f"############################################ {op.full_name} native with op!")

    def _softmax_split_ops(self, lossy_markers):
        """
        Inside softmax layers (by ``layer.name``), ops whose name contains one of ``lossy_markers`` run
        lossy and the remaining ops run native. Every other layer runs native.
        """
        for layer in self._model.layers.values():
            if "softmax" not in layer.name:
                self._mark_ops_native(layer.atomic_ops)
                continue
            for op in layer.atomic_ops:
                if any(marker in op.name for marker in lossy_markers):
                    op.fully_native = False
                    op.enable_lossy()
                    print(f"############################################ {op.full_name} lossy op!")
                else:
                    op.fully_native = True
                    op.disable_lossy()
                    print(f"############################################ {op.full_name} native op!")

    def _softmax_lossless_layers(self, layer_names):
        """Layers whose name is in ``layer_names`` run lossless; every other layer runs native."""
        for layer in self._model.layers.values():
            if layer.name in layer_names:
                self._mark_ops_lossless(layer.atomic_ops)
                print(f"############################################ {layer.full_name} lossless with ")
            else:
                self._mark_ops_native(layer.atomic_ops)

    def _set_quantized_softmax_special(self):
        """Lossy softmax layers, keeping ops whose name contains "de" or "passthru_op" native."""
        self._softmax_lossy_with_native_ops(("de", "passthru_op"))

    def _set_quantized_softmax_no_output_and_act(self):
        """Lossy softmax layers, keeping the "passthru_op" and "act_op" ops native."""
        self._softmax_lossy_with_native_ops(("passthru_op", "act_op"))

    def _set_quantized_softmax_io(self):
        """Only the "passthru_op" and "input_op" ops of softmax layers run lossy."""
        self._softmax_split_ops(("passthru_op", "input_op"))

    def _set_quantized_softmax_output(self):
        """Only the "passthru_op" ops of softmax layers run lossy."""
        self._softmax_split_ops(("passthru_op",))

    def _set_quantized_softmax_input(self):
        """Only the "input_op" ops of softmax layers run lossy."""
        # This used to match "passthru_op", which made it identical to the output variant.
        self._softmax_split_ops(("input_op",))

    def _set_quantized_softmax_output_native(self):
        """Lossy softmax layers, keeping the "passthru_op" ops native."""
        self._softmax_lossy_with_native_ops(("passthru_op",))

    def _set_quantized_softmax_special_inverse(self):
        """Softmax layers are set native at layer level, except "passthru_op" ops which run lossy."""
        for layer in self._model.layers.values():
            if "softmax" not in layer.name:
                self._mark_ops_native(layer.atomic_ops)
                continue
            layer.fully_native = True
            layer.disable_lossy()
            print(f"############################################ {layer.full_name} native ")
            for op in layer.atomic_ops:
                if "passthru_op" in op.name:
                    op.fully_native = False
                    op.enable_lossy()
                    print(f"############################################ {op.full_name} lossy op!")

    def _set_lossless_softmax_num(self):
        """softmax1..softmax12 run lossless except their "exp_numerator" ops, which run native."""
        for layer in self._model.layers.values():
            if layer.name in self._RESEARCH_SOFTMAX_LAYERS:
                for op in layer.atomic_ops:
                    if "exp_numerator" in op.name:
                        op.fully_native = True
                        op.disable_lossy()
                        print(f"############################################ {op.full_name} native with op!")
                    else:
                        op.fully_native = False
                        op.disable_lossy()
                print(f"############################################ {layer.full_name} lossless ")
            else:
                self._mark_ops_native(layer.atomic_ops)

    def _set_lossless_softmax(self):
        """softmax1..softmax12 run lossless; every other layer runs native."""
        self._softmax_lossless_layers(self._RESEARCH_SOFTMAX_LAYERS)

    def _set_lossless_softmax1(self):
        """Only softmax1 runs lossless; every other layer runs native."""
        self._softmax_lossless_layers(("softmax1",))

    def _set_quantized_softmax(self):
        """Layers with "softmax" in their full name run lossy; every other layer runs native."""
        for layer in self._model.layers.values():
            if "softmax" in layer.full_name:
                self._mark_ops_lossy(layer.atomic_ops)
                print(layer.full_name, "############################################ numeric")
            else:
                self._mark_ops_native(layer.atomic_ops)

    def _set_quantized_softmax_1(self):
        """Only the layer named softmax10 runs lossy; every other layer runs native."""
        lossy_layer_names = ("softmax10",)
        for layer in self._model.layers.values():
            if layer.name in lossy_layer_names:
                self._mark_ops_lossy(layer.atomic_ops)
                print(layer.full_name, "############################################ numeric")
            else:
                self._mark_ops_native(layer.atomic_ops)

    def _set_quantized_softmax_exact(self):
        """Layers with "softmax" in their full name run lossy and bit exact; the rest run native."""
        for layer in self._model.layers.values():
            if "softmax" in layer.full_name:
                for op in layer.atomic_ops:
                    op.fully_native = False
                    op.enable_lossy()
                    op.bit_exact = True
                print(
                    layer.full_name,
                    "bit excatm",
                    layer.bit_exact,
                    "############################################ numeric bits_exact",
                )
            else:
                self._mark_ops_native(layer.atomic_ops)
                print(layer.full_name, "native")

    def set_quantized_norm_research(self, online_zp_compensation=False, internal_encoding=False):
        """
        Research-only quantization flow selected by the ``TASK_TYPE`` environment variable
        (default ``"quantized"``). An unrecognized ``TASK_TYPE`` leaves the model untouched.
        """
        task_type = os.getenv("TASK_TYPE", default="quantized")
        print("#########################", task_type)

        # Tasks that only pick a per-op configuration and then force internal encoding on.
        internal_encoding_tasks = {
            "softmax_numeric": self._set_quantized_softmax,
            "softmax_numeric_no_output_and_act": self._set_quantized_softmax_no_output_and_act,
            "softmax_numeric_output_native": self._set_quantized_softmax_output_native,
            "softmax_numeric_inverse": self._set_quantized_softmax_special_inverse,
            "softmax_numeric_special": self._set_quantized_softmax_special,
            "softmax_io_quant": self._set_quantized_softmax_io,
            "softmax_output_quant": self._set_quantized_softmax_output,
            "softmax_input_quant": self._set_quantized_softmax_input,
            "softmax_lossless": self._set_lossless_softmax,
            "softmax_lossless1": self._set_lossless_softmax1,
        }
        # Tasks that honour the caller's internal-encoding / zp-compensation arguments.
        caller_configured_tasks = {
            "quantized": self._set_quantized_norm,
            "softmax_exact": self._set_quantized_softmax_exact,
            "quantized_only_norm_native": self._set_quantized_only_norm_native,
        }

        if task_type in internal_encoding_tasks:
            internal_encoding_tasks[task_type]()
            self._model.enable_internal_encoding()
        elif task_type in caller_configured_tasks:
            caller_configured_tasks[task_type]()
            self._set_internal_encoding(internal_encoding)
            self._online_zp_compensation(online_zp_compensation=online_zp_compensation)
            if task_type == "softmax_exact":
                self.compile(run_eagerly=True)
        elif task_type == "native_only_norm_numeric":
            self._set_native_only_norm_numeric()
        elif task_type == "native":
            self.set_native()

    def is_jit_compile_supported(self, training=False):
        """EXPERIMENTAL: should indicate if jit_compile works for the model.
        Will return False if the model has layers that doesn't support jit_compile.

        Args:
            training (bool, optional): if training will be applied (jit_compile gradients). Defaults to False.

        Returns:
            bool: indication if jit_compile works
        """
        return self._model.is_jit_compile_supported(training=training)


class SimulationTrainingModel(SimulationInferenceModel):
    """
    Trainable flavour of the simulation model, used to obtain a training model for QAT.

    Args:
        model (HailoModel): Acceleras model for training
    """

    def __init__(self, model: HailoModel):
        super().__init__(model)
        # InferenceModel.__init__ marks the wrapper as non-trainable; re-enable it here.
        self.trainable = True

    def compile(self, *args, **kwargs):
        """Forward compilation to the wrapped acceleras model and return its result."""
        wrapped_model = self._model
        return wrapped_model.compile(*args, **kwargs)

    def fit(self, *args, **kwargs):
        """Forward training to the wrapped acceleras model and return its result."""
        wrapped_model = self._model
        return wrapped_model.fit(*args, **kwargs)
