Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
def shift_bhw1_into_bhwn(images1c, shifts):
    """Shifts images horizontally and back-fills with zeros.

    Channel k of output image b is images1c[b] moved by shifts[b, k] pixels
    along the width axis. Positive values move the content to the right,
    negative values move it to the left; pixels that would come from outside
    the image are zero.

    @param images1c: [batch_size, height, width, channels=1]
    @param shifts: [batch_size, n_shifts], integer pixel offsets
                   (NOTE(review): presumably int32, to match tf.shape -- confirm)
    @output [batch_size, height, width, channels=n_shifts]
    """
    # One copy of the image per requested shift, stored along the channel axis.
    images = tf.tile(images1c, [1, 1, 1, shifts.shape[1]])
    # Dynamic width, so images whose static width is unknown are supported too.
    width = tf.shape(images)[2]

    # Positive shifts (to the right) need zeros on the left; negative shifts
    # (to the left) need zeros on the right. Pad once for the largest of each.
    left = tf.maximum(0, tf.reduce_max(shifts))
    right = -tf.minimum(0, tf.reduce_min(shifts))
    # tf.pad zero-fills using the dtype of `images`, so non-float32 inputs work
    # (a bare tf.zeros would always produce float32 and break the concat).
    padded_images = tf.pad(images, [[0, 0], [0, 0], [left, right], [0, 0]])

    def shift_channel(channel_and_shift):
        # channel: [height, padded_width], shift: scalar.
        channel, shift = channel_and_shift
        # The unshifted image starts at column `left` of the padded image.
        # Starting the crop `shift` columns earlier moves the content right.
        start = left - shift
        return channel[:, start:start + width]

    def shift_image(image_and_shifts):
        # image: [height, padded_width, n_shifts], image_shifts: [n_shifts].
        image, image_shifts = image_and_shifts
        # map_fn iterates over axis 0, so bring the channels to the front.
        channels_first = tf.transpose(image, perm=[2, 0, 1])
        shifted = tf.map_fn(
            shift_channel,
            (channels_first, image_shifts),
            dtype=images.dtype,
        )
        # map_fn stacks its results on axis 0 ([n_shifts, height, width]);
        # restore the documented channels-last layout.
        return tf.transpose(shifted, perm=[1, 2, 0])

    return tf.map_fn(
        shift_image,
        (padded_images, shifts),
        dtype=images.dtype,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.