import tensorflow as tf


def custom_gather2d(images, x, y, out_h, out_w):
    """Gather the same grid of pixels from every image in a batch.

    The identical (row, column) index pairs are used for every image, so
    instead of looping over the batch (``tf.map_fn``), the batch axis is moved
    behind the spatial axes and a single ``tf.gather_nd`` fetches the pixels
    of all images at once.

    Args:
        images: Tensor laid out as (batch, height, width, channels).
        x: Row (height-axis) indices, ``out_h * out_w`` values in total. They
            are cast to int32, so floating values are truncated toward zero.
        y: Column (width-axis) indices, same number of values as ``x``.
        out_h: Number of rows in the output grid.
        out_w: Number of columns in the output grid.

    Returns:
        Tensor of shape (batch, out_h, out_w, channels) with the dtype of
        ``images``. Indices are expected to be in range; the resize functions
        in this module clip them before calling.
    """
    indices = tf.keras.ops.stack([tf.keras.ops.cast(x, tf.int32), tf.keras.ops.cast(y, tf.int32)], axis=-1)
    # Shape the index tensor as the output grid so gather_nd yields
    # (out_h, out_w, ...) directly; no reshape against a possibly-unknown
    # channel dimension is needed afterwards.
    indices = tf.keras.ops.reshape(indices, [out_h, out_w, 2])

    # (batch, H, W, C) -> (H, W, batch, C): each (row, col) pair now selects
    # a (batch, C) slice covering every image in one gather.
    spatial_first = tf.transpose(images, [1, 2, 0, 3])
    gathered = tf.gather_nd(spatial_first, indices)  # (out_h, out_w, batch, C)

    # Restore the batch-first layout.
    return tf.transpose(gathered, [2, 0, 1, 3])


def CustomBilinear(inputs, target_shape, align_corners=False, half_pixel_centers=False):
    """Bilinearly resize a batch of images using only gathers and arithmetic.

    The source coordinate of each output pixel is computed per axis with the
    same mapping TensorFlow's ``resize_bilinear`` uses for the
    ``align_corners`` / ``half_pixel_centers`` flags. The four surrounding
    input pixels are then gathered and blended.

    Args:
        inputs: Tensor laid out as (batch, height, width, channels); height and
            width must be known statically (they are used in Python arithmetic).
        target_shape: ``(out_h, out_w)`` output size.
        align_corners: Map the first/last input and output pixels onto each
            other. Along an axis whose output size is 1 it has no effect; the
            ``in / out`` scale is used instead, avoiding a division by zero.
        half_pixel_centers: Sample at pixel centers, i.e. the source coordinate
            is ``(i + 0.5) * scale - 0.5``, clamped at 0.

    Returns:
        float32 tensor of shape (batch, out_h, out_w, channels).
    """
    in_h, in_w = inputs.shape[1:-1]
    out_h, out_w = target_shape
    # The interpolation weights are float32. Casting the images as well lets
    # integer / float16 / float64 inputs be blended instead of failing on a
    # dtype mismatch; this matches resize_bilinear's float32 output.
    inputs = tf.keras.ops.cast(inputs, tf.float32)

    def _axis_samples(in_size, out_size):
        """Return (lower index, upper index, fraction) per output position on one axis."""
        if align_corners and out_size > 1:
            scale = (in_size - 1) / (out_size - 1)
        else:
            scale = in_size / out_size
        coords = tf.range(out_size, dtype=tf.float32)
        if half_pixel_centers:
            # The first pixel centers can map slightly before the image start.
            coords = tf.keras.ops.maximum((coords + 0.5) * scale - 0.5, 0)
        else:
            coords = coords * scale
        lower = tf.keras.ops.floor(coords)
        # At the far edge the second neighbor collapses onto the last pixel.
        upper = tf.keras.ops.minimum(lower + 1, in_size - 1)
        return lower, upper, coords - lower

    row0, row1, row_frac = _axis_samples(in_h, out_h)
    col0, col1, col_frac = _axis_samples(in_w, out_w)

    # (out_h, 1, 1) and (1, out_w, 1) fractions broadcast to an (out_h, out_w, 1)
    # weight grid that lines up with (batch, out_h, out_w, channels).
    dr = row_frac[:, tf.newaxis, tf.newaxis]
    dc = col_frac[tf.newaxis, :, tf.newaxis]

    r0_grid, c0_grid = tf.meshgrid(row0, col0, indexing="ij")
    r1_grid, c1_grid = tf.meshgrid(row1, col1, indexing="ij")

    top_left = custom_gather2d(inputs, r0_grid, c0_grid, out_h, out_w)
    bottom_left = custom_gather2d(inputs, r1_grid, c0_grid, out_h, out_w)
    top_right = custom_gather2d(inputs, r0_grid, c1_grid, out_h, out_w)
    bottom_right = custom_gather2d(inputs, r1_grid, c1_grid, out_h, out_w)

    return (
        (1 - dr) * (1 - dc) * top_left
        + dr * (1 - dc) * bottom_left
        + (1 - dr) * dc * top_right
        + dr * dc * bottom_right
    )


def CustomNearest(inputs, target_shape, align_corners=False, half_pixel_centers=False):
    """Nearest-neighbor resize of a batch of images using a single gather.

    Source indices follow TensorFlow's ``resize_nearest_neighbor`` rules, with
    ``scale`` being ``in / out`` (or ``(in - 1) / (out - 1)`` for
    ``align_corners`` when ``out > 1``):

    * default: ``floor(i * scale)``
    * ``half_pixel_centers``: ``floor((i + 0.5) * scale)``
    * ``align_corners``: ``i * scale`` rounded to nearest, halves rounded up

    The indices are then clipped to ``[0, in - 1]``. (TensorFlow rejects
    setting both flags; here they simply combine.)

    Args:
        inputs: Tensor laid out as (batch, height, width, channels); height and
            width must be known statically.
        target_shape: ``(out_h, out_w)`` output size.
        align_corners: Map the first/last input and output pixels onto each other.
        half_pixel_centers: Sample at pixel centers.

    Returns:
        Tensor of shape (batch, out_h, out_w, channels) with the dtype of ``inputs``.
    """
    in_h, in_w = inputs.shape[1:-1]
    out_h, out_w = target_shape

    def _source_indices(in_size, out_size):
        """Return the clipped int32 source index for each output position on one axis."""
        if align_corners and out_size > 1:
            scale = tf.keras.ops.cast((in_size - 1) / (out_size - 1), tf.float32)
        else:
            scale = tf.keras.ops.cast(in_size / out_size, tf.float32)
        coords = tf.keras.ops.arange(out_size, dtype=tf.float32)
        if half_pixel_centers:
            coords = coords + 0.5
        coords = coords * scale
        if align_corners:
            # Round half up. tf round is half-to-even, which would make the
            # chosen pixel depend on the parity of the integer part.
            coords = coords + 0.5
        # floor (not round) is the nearest-neighbor source index otherwise;
        # rounding shifts e.g. a 4x upsample to 3/5 copies per pixel.
        indices = tf.keras.ops.cast(tf.keras.ops.floor(coords), tf.int32)
        return tf.keras.ops.clip(indices, 0, in_size - 1)

    rows = _source_indices(in_h, out_h)
    cols = _source_indices(in_w, out_w)
    # Default "xy" indexing: both grids have shape (out_h, out_w).
    col_grid, row_grid = tf.keras.ops.meshgrid(cols, rows)
    return custom_gather2d(inputs, row_grid, col_grid, out_h, out_w)
