Skip to content

Instantly share code, notes, and snippets.

@sampepose
Created June 7, 2017 04:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sampepose/1244694a546ed173b2f38d1bb3e6a433 to your computer and use it in GitHub Desktop.
Save sampepose/1244694a546ed173b2f38d1bb3e6a433 to your computer and use it in GitHub Desktop.
def CrossCorrelation(feature_map_a, feature_map_b, maximum_displacement, striding):
    """Compute a local cross correlation between two NHWC feature maps.

    For every location x1 in ``feature_map_a`` the channel-wise dot product is
    taken with the locations of ``feature_map_b`` inside the square window
    [-maximum_displacement, +maximum_displacement] centred on x1. To save
    memory only every ``striding``-th offset along each axis of that window is
    evaluated. Window positions that fall outside ``feature_map_b`` contribute
    zero (SAME zero padding of the patch extraction).

    Args:
        feature_map_a: 4-D tensor [batch, height, width, channels].
        feature_map_b: 4-D tensor with the same shape as ``feature_map_a``.
        maximum_displacement: Non-negative int, half-width of the search window.
        striding: Positive int, step between evaluated offsets in the window.

    Returns:
        Tensor [batch, height, width, K * K] with the dtype of the inputs, where
        K = ceil((2 * maximum_displacement + 1) / striding). The last axis is the
        (dy, dx) grid of sampled offsets flattened row-major; along each axis the
        offsets are -maximum_displacement + i * striding for i in range(K), so
        offset 0 is only sampled when maximum_displacement is a multiple of
        striding.

    Raises:
        ValueError: if ``maximum_displacement`` is negative or ``striding`` is
            smaller than 1.
    """
    if maximum_displacement < 0:
        raise ValueError(
            'maximum_displacement must be >= 0, got %r' % (maximum_displacement,))
    if striding < 1:
        raise ValueError('striding must be >= 1, got %r' % (striding,))

    window = 2 * maximum_displacement + 1
    # Offsets sampled per axis are 0, s, 2s, ... < window -> ceil(window / s).
    samples = (window + striding - 1) // striding

    channels = feature_map_a.shape.as_list()[3]
    if channels is None:
        # Channel count unknown at graph-build time: use the dynamic value.
        channels = tf.shape(feature_map_a)[3]
    height = tf.shape(feature_map_b)[1]
    width = tf.shape(feature_map_b)[2]

    # For each location, the window x window x channels neighbourhood of B,
    # flattened in (row, col, channel) order. SAME padding centres the window
    # on the location and zero-fills beyond the border.
    patches_from_b = tf.extract_image_patches(feature_map_b,
                                              ksizes=[1, window, window, 1],
                                              strides=[1, 1, 1, 1],
                                              rates=[1, 1, 1, 1],
                                              padding='SAME')
    patches_from_b = tf.reshape(
        patches_from_b, [-1, height, width, window, window, channels])
    # Keep only every `striding`-th offset per axis. This matches a 1x1
    # convolution with stride `striding` and SAME padding over the window,
    # which needs no padding and samples indices 0, s, 2s, ...
    patches_from_b = patches_from_b[:, :, :, ::striding, ::striding, :]

    # Broadcast each feature vector of A against all sampled offsets of B and
    # reduce over channels: one dot product per (location, offset) pair.
    kernels_from_a = feature_map_a[:, :, :, tf.newaxis, tf.newaxis, :]
    correlation = tf.reduce_sum(kernels_from_a * patches_from_b, axis=-1)

    return tf.reshape(correlation, [-1, height, width, samples * samples])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment