"""
This module contains utilities for using a DALI dataset for model training.
DALI utilizes the GPU more efficiently than TensorFlow, which improves the training time.

DALI mostly improves the parallelization of the host-to-device memcpy.

DALI doesn't inherently support nested tuple shapes (it supports only a flat tuple),
therefore this utility offers a way to overcome this inherent flaw of DALI.

DALI has 2 flaws:
1. The data readers expect a single npy file for each data sample
2. It doesn't support multiple-input / multiple-output datasets for keras' fit function
This module treats these flaws in the following ways:
1. It takes the single npy file with multiple data samples and writes an npy file for each sample
2. It pads all the inputs and stacks them together (to a single tensor)
    so that DALI will be able to process multi-input / multi-output data
"""

import os
from typing import List

import numpy as np
import tensorflow as tf


def tf_unpad_input(inputs: tf.Tensor, x_shape: tuple) -> List[tf.Tensor]:
    """
    Split a stacked DALIDataset tensor back into its individual, unpadded inputs

    The stacked tensor is unstacked along axis 1. When that yields a single tensor it is
    returned directly (not wrapped in a list); otherwise every tensor is sliced back to its
    original shape.

    Args:
        inputs: Stacked (and padded) tensor to unpad
        x_shape: The unpadded shape of each of the inputs

    Returns:
        The single unstacked tensor, or a list of unpadded tensors when there are several

    """
    unstacked = tf.unstack(inputs, axis=1)
    # A lone input is returned without slicing
    if len(unstacked) == 1:
        return unstacked[0]
    return tf_slice_inputs(unstacked, x_shape)


def tf_slice_inputs(inputs: tf.Tensor, unpadded_shape: tuple):
    """
    Slice every input back to its original shape, removing the DALIDataset padding

    Args:
        inputs: Tensors to unpad, one per entry of unpadded_shape
        unpadded_shape: The unpadded shape of each input (without the leading axis); each
            entry must expose a ``dtype`` attribute, e.g. a row of a numpy array

    Returns:
        List of unpadded tensors

    """

    def _strip_padding(tensor, shape):
        # Start at the origin of every axis (the leading axis included)
        origin = np.zeros(len(shape) + 1, dtype=shape.dtype)
        # -1 asks tf.slice for all remaining elements of the leading axis
        extent = [-1, *shape]
        return tf.slice(tensor, origin, extent)

    return [_strip_padding(tensor, shape) for tensor, shape in zip(inputs, unpadded_shape)]


def np_pad_inputs(np_data: List[np.array], padded_shape, stack_axis: int = 0) -> np.array:
    """
    Zero-pad numpy arrays to a common shape and stack them together into a single array

    When more than one array is given, each of them is padded with zeros at the end of every
    axis until it reaches padded_shape. A single array is stacked as is, without any padding.

    Args:
        np_data: list of np array objects
        padded_shape: shape every array is padded to before stacking, one entry per array
            axis; may be an np array, a tuple or a list of integers
        stack_axis: axis to stack the data

    Returns:
        a single np array with the (padded) inputs stacked along stack_axis

    Raises:
        ValueError: raised by numpy when np_data is empty, or when one of several inputs is
            larger than padded_shape along some axis

    """
    if len(np_data) > 1:
        # Tuples and lists have no dtype and do not support the subtraction below
        target_shape = np.asarray(padded_shape)
        before_pad = np.zeros(len(target_shape), dtype=target_shape.dtype)
        padded_data = []
        for data_inp in np_data:
            # Amount of zeros appended at the end of each axis
            after_pad = target_shape - data_inp.shape
            # np.pad expects one (before, after) pair per axis
            pad_shape = np.stack([before_pad, after_pad], axis=1)
            single_padded_data = np.pad(data_inp, pad_shape, mode="constant", constant_values=0)
            padded_data.append(single_padded_data)
        np_data = padded_data
    return np.stack(np_data, axis=stack_axis)


def tf_pad_outputs(output_data: List[tf.Tensor], stack_axis: int = 0):
    """
    Zero-pad tensors to a common shape and stack them together into a single tensor

    The common shape is the element-wise maximum of the tensors' shapes, ignoring axis 0.
    When more than one tensor is given, each of them is padded with zeros at the end of every
    axis except axis 0, which is left untouched. A single tensor is stacked without padding.

    Args:
        output_data: list of tensors to pad and stack
        stack_axis: axis to stack the data

    Returns:
        a single tensor with the (padded) tensors stacked along stack_axis

    """
    trailing_shapes = [tensor.shape[1:].as_list() for tensor in output_data]
    padded_shape = tf.reduce_max(trailing_shapes, axis=0)
    if len(output_data) > 1:
        # One extra entry accounts for axis 0, which is never padded
        leading_pad = tf.zeros(len(padded_shape) + 1, dtype=padded_shape.dtype)
        padded_tensors = []
        for tensor in output_data:
            missing = padded_shape - tensor.shape[1:]
            trailing_pad = tf.concat([[0], missing], 0)
            # tf.pad expects one (before, after) pair per axis
            paddings = tf.stack([leading_pad, trailing_pad], axis=1)
            padded_tensors.append(tf.pad(tensor, paddings, mode="CONSTANT", constant_values=0))
        output_data = padded_tensors
    return tf.stack(output_data, axis=stack_axis)


def padding_info(input_sample: List[np.array], stack_axis: int):
    """
    Compute the shapes involved in padding and stacking a sample of arrays

    Args:
        input_sample: list of np array objects
        stack_axis: axis along which the arrays will be stacked

    Returns:
        tuple with (list holding the shape of the stacked array, np array with the shape all
        arrays are padded to, np array with the original unpadded shape of every array)

    """
    unpadded_shape = np.array([sample.shape for sample in input_sample])
    # Per-axis maximum over all the arrays
    padded_shape = unpadded_shape.max(axis=0)
    # Place the number of arrays at position stack_axis of the padded shape
    stacked_shape = [
        *padded_shape[:stack_axis],
        len(input_sample),
        *padded_shape[stack_axis:],
    ]
    return stacked_shape, padded_shape, unpadded_shape


def pad_data_generator(cache_list, count, padded_shape):
    """
    Take list of multiple caches (usually multiple inputs for a layer)
    Read matching index items for all the caches, pad them to the same shape and stack them on
    axis 0.

    Args:
        cache_list: list of cache directories (str or bytes paths) with npz files numbered
            from 0 to count - 1, 'arr' as key.
        count: number of items to iterate
        padded_shape: desired shape to pad all the items from the cache.

    Yields:
        np array with the items of index i from every cache, padded and stacked on axis 0

    """
    for i in range(count):
        np_data = []
        for curr_cache in cache_list:
            fname = f"{i}.npz"
            if isinstance(curr_cache, bytes):
                # os.path.join cannot mix bytes and str components
                fname = bytes(fname, "utf-8")
            # The npz archive stays open until closed, so release the file handle
            # as soon as the array has been read
            with np.load(os.path.join(curr_cache, fname)) as npz_file:
                np_data.append(npz_file["arr"])
        yield np_pad_inputs(np_data, padded_shape, 0)


def generator_output_signature(sample_data, stacked_shape):
    """
    Build the tensorflow spec describing the stacked data produced from the cache files

    Args:
        sample_data: iterable of items exposing a ``dtype`` attribute (e.g. np arrays)
        stacked_shape: shape of the stacked data

    Returns:
        tf.TensorSpec with the given shape and the data type shared by all the samples

    Raises:
        ValueError: if the samples do not all share exactly one data type

    """
    unique_dtypes = {item.dtype for item in sample_data}
    if len(unique_dtypes) == 1:
        (shared_dtype,) = unique_dtypes
        return tf.TensorSpec(shape=stacked_shape, dtype=tf.dtypes.as_dtype(shared_dtype))
    raise ValueError(f"Non-matching data types of stacked cache files {unique_dtypes}")
