@milani
Created July 30, 2017 05:30
Tests different repeat-elements implementations in tensorflow for Keras
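
For context (a note added here, not part of the original gist description): repeat_elements(x, rep, axis) follows the semantics of Keras's K.repeat_elements, i.e. each slice along axis is repeated rep times in place, exactly like np.repeat:

import numpy as np

a = np.array([[1, 2],
              [3, 4]])
print(np.repeat(a, 2, axis=1))
# [[1 1 2 2]
#  [3 3 4 4]]
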
import tensorflow as tf
import numpy as np
import timeit


def concatenate(tensors, axis=-1):
    return tf.concat([x for x in tensors], axis)


def repeat_elements_original(x, rep, axis):
    # requires a statically known size along the repeat axis
    x_shape = x.get_shape().as_list()
    # slices along the repeat axis
    splits = tf.split(value=x, num_or_size_splits=x_shape[axis], axis=axis)
    # repeat each slice the given number of reps
    x_rep = [s for s in splits for _ in range(rep)]
    return concatenate(x_rep, axis)


def repeat_elements_dynamic(x, rep, axis):
    # Repeating: insert an auxiliary axis after `axis` and tile along it
    auxiliary_axis = axis + 1
    x_shape = tf.shape(x)
    x_rep = tf.expand_dims(x, axis=auxiliary_axis)
    reps = np.ones(len(x.get_shape()) + 1)
    reps[auxiliary_axis] = rep
    x_rep = tf.tile(x_rep, reps)

    # Merging: collapse the auxiliary axis back into `axis`
    reps = np.delete(reps, auxiliary_axis)
    reps[axis] = rep
    reps = tf.constant(reps, dtype='int32')
    x_shape = x_shape * reps
    x_rep = tf.reshape(x_rep, x_shape)

    # Fix shape representation
    x_shape = x.get_shape().as_list()
    if x_shape[axis] is not None:
        x_shape[axis] *= rep
    x_rep.set_shape(x_shape)
    x_rep._keras_shape = tuple(x_shape)
    return x_rep


matrix = np.random.rand(5, 1000, 1000)
rep = 2
axis = 1

with tf.Session() as sess:
    tensor = tf.constant(matrix)
    tensor_repeated_original = repeat_elements_original(tensor, rep, axis)
    tensor_repeated_dynamic = repeat_elements_dynamic(tensor, rep, axis)

    start = timeit.default_timer()
    for i in range(1000):
        tensor_repeated_original.eval()
    end = timeit.default_timer()
    print("original", end - start)

    start = timeit.default_timer()
    for i in range(1000):
        tensor_repeated_dynamic.eval()
    end = timeit.default_timer()
    print("dynamic", end - start)
milani commented Jul 30, 2017

Running with TensorFlow 1.2.1 on a GTX 1070, the results are as follows:

('original', 8.336309909820557)
('dynamic', 7.09973406791687)

Running with CUDA_VISIBLE_DEVICES="" (i.e. on the CPU) on the same system:

('original', 26.153334856033325)
('dynamic', 60.485647201538086)
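
One practical difference worth noting (an observation added here, not from the original thread): repeat_elements_original reads the split count from the static shape, so it cannot handle a dimension that is unknown at graph-construction time, while the dynamic variant only needs tf.shape. A minimal sketch under the same TF 1.x API as above (the placeholder is purely illustrative):

x = tf.placeholder(tf.float64, shape=(None, None, 4))  # size along axis 1 unknown
y = repeat_elements_dynamic(x, 2, 1)  # fine: only tf.shape(x) is needed
# repeat_elements_original(x, 2, 1) would fail here, because tf.split
# needs a concrete num_or_size_splits and x.get_shape()[1] is None.
with tf.Session() as sess:
    out = sess.run(y, feed_dict={x: np.zeros((2, 3, 4))})
    print(out.shape)  # (2, 6, 4)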
