import tensorflow as tf


def reshape_input_by_windows(inp, input_windows):
    """
    Split a rank-4 tensor into non-overlapping windows folded into the batch axis.

    Arguments:
    - inp: Rank-4 tensor laid out as (batch, height, width, features). The
      height, width and feature dimensions must be statically known.
    - input_windows: Sequence of three positive integers (n_h, n_w, n_f) giving
      the NUMBER of windows along height, width and features (not the window
      sizes).

    Returns:
    - Tensor of shape (batch * n_h * n_w * n_f, height // n_h, width // n_w,
      features // n_f). For each original batch element, its windows appear
      consecutively, ordered by (height-window, width-window, feature-window).

    Raises:
    - ValueError: if a window count is not positive or does not evenly divide
      the corresponding dimension. Without this check a non-divisible size can
      be silently absorbed by the -1 batch dimension, mixing data across
      batch elements instead of failing.
    """
    n_h, n_w, n_f = input_windows[0], input_windows[1], input_windows[2]
    height, width, features = inp.shape[1], inp.shape[2], inp.shape[3]

    for axis_name, size, count in (
        ("height", height, n_h),
        ("width", width, n_w),
        ("features", features, n_f),
    ):
        if count <= 0 or size % count != 0:
            raise ValueError(
                f"{axis_name} dimension {size} cannot be split into {count} "
                f"equal windows"
            )

    win_h, win_w, win_f = height // n_h, width // n_w, features // n_f

    # Expose each spatial/feature axis as (window index, position in window).
    windowed = tf.reshape(inp, [-1, n_h, win_h, n_w, win_w, n_f, win_f])
    # Group the three window-index axes right after batch so they can be
    # merged into it: (batch, n_h, n_w, n_f, win_h, win_w, win_f).
    windowed = tf.transpose(windowed, perm=[0, 1, 3, 5, 2, 4, 6])
    return tf.reshape(windowed, [-1, win_h, win_w, win_f])


def reshape_output_by_windows(out, output_windows):
    """
    Reverse the reshaping performed in reshape_input_by_windows.

    Un-folds the window index from the batch axis and tiles the windows back
    together along height, width and features. The per-window sizes are taken
    from `out` itself, so they may differ from those produced by
    reshape_input_by_windows (e.g. after processing that changes them).

    Arguments:
    - out: Rank-4 tensor of shape
      (batch * n_h * n_w * n_f, win_h, win_w, win_f), where windows of the same
      original batch element are consecutive and ordered by
      (height-window, width-window, feature-window). The three trailing
      dimensions must be statically known.
    - output_windows: Sequence of exactly three integers (n_h, n_w, n_f) giving
      the NUMBER of windows along height, width and features (not the window
      sizes).

    Returns:
    - Tensor of shape (batch, n_h * win_h, n_w * win_w, n_f * win_f).
    """
    n_h, n_w, n_f = output_windows
    # Size of each window as it arrives (may differ from the input windows).
    win_h, win_w, win_f = out.shape[1], out.shape[2], out.shape[3]

    # Split the window index back out of the batch axis:
    # (batch, n_h, n_w, n_f, win_h, win_w, win_f).
    out = tf.reshape(out, [-1, n_h, n_w, n_f, win_h, win_w, win_f])

    # Interleave each window-count axis with its within-window axis:
    # (batch, n_h, win_h, n_w, win_w, n_f, win_f).
    out = tf.transpose(out, perm=[0, 1, 4, 2, 5, 3, 6])

    # Merge each (count, size) pair into one full-length axis.
    return tf.reshape(out, [-1, n_h * win_h, n_w * win_w, n_f * win_f])
