Skip to content

Instantly share code, notes, and snippets.

@hristo-vrigazov
Last active September 1, 2019 07:40
Show Gist options
  • Save hristo-vrigazov/30676fbd769013b34d73444befcba82f to your computer and use it in GitHub Desktop.
Save hristo-vrigazov/30676fbd769013b34d73444befcba82f to your computer and use it in GitHub Desktop.
Linspace external
import tensorflow as tf
def tf_repeat(tensor, repeats, axis):
    """Tile ``tensor`` ``repeats`` times along a single axis.

    Builds a ``multiples`` vector that is 1 everywhere except at ``axis``,
    where it is ``repeats``, and hands it to ``tf.tile``. For an axis of
    size 1 this is equivalent to repeating the single slice.

    Args:
        tensor: Tensor (or value convertible to one) with a statically
            known rank.
        repeats: Number of copies along ``axis``.
        axis: Axis to tile. Negative Python ints count from the last axis.

    Returns:
        The tiled tensor.

    Raises:
        ValueError: If the rank of ``tensor`` is not statically known.
    """
    tensor = tf.convert_to_tensor(tensor)
    # TensorShape.ndims is available on both TF1 and TF2, unlike the
    # TF1-only ``Dimension.value`` attribute.
    rank = tensor.shape.ndims
    if rank is None:
        raise ValueError(
            'tf_repeat requires a tensor with a statically known rank')
    # Without this, a negative axis never matches any position below and
    # the call would silently return the tensor untiled.
    if isinstance(axis, int) and axis < 0:
        axis += rank

    positions = tf.range(rank)
    on_axis = tf.equal(positions, tf.fill(tf.shape(positions), axis))
    multiples = tf.where(
        on_axis,
        tf.fill(tf.shape(positions), repeats),
        tf.ones_like(positions),
    )
    return tf.tile(tensor, multiples)
def tf_linspace(start, stop, num, axis=0):
    """Return ``num`` evenly spaced samples between ``start`` and ``stop``.

    ``start`` and ``stop`` may be tensors; a new axis of length ``num`` is
    inserted at position ``axis`` and the samples are laid out along it.
    The first and last slices are ``start`` and ``stop`` themselves.

    Args:
        start: Tensor (or convertible value) holding the first sample(s).
            NOTE(review): the step is computed with a float division and
            the index range uses ``start.dtype``, so floating-point inputs
            are presumably expected -- confirm against callers.
        stop: Tensor (or convertible value) holding the last sample(s).
        num: Python number of samples to generate; must be >= 1.
        axis: Position of the new sample axis in the result. Negative
            values count from the end, as in ``tf.expand_dims``.

    Returns:
        ``start`` with a length-1 axis inserted when ``num == 1``;
        otherwise a tensor with a length-``num`` axis at ``axis``.

    Raises:
        ValueError: If ``num`` is not positive, or if the rank of
            ``start`` is not statically known.
    """
    # Validate before building any ops so bad input fails fast.
    if num <= 0:
        raise ValueError('Num has to be >= 1')

    start = tf.convert_to_tensor(start)
    stop = tf.convert_to_tensor(stop)

    expanded_start = tf.expand_dims(start, axis=axis)
    if num == 1:
        return expanded_start
    expanded_stop = tf.expand_dims(stop, axis=axis)
    delta = (expanded_stop - expanded_start) / (num - 1.)

    # TensorShape.ndims works on both TF1 and TF2, unlike the TF1-only
    # ``Dimension.value`` attribute.
    rank = expanded_start.shape.ndims
    if rank is None:
        raise ValueError(
            'tf_linspace requires start to have a statically known rank')
    # tf.expand_dims above already accepted ``axis``, so a negative value
    # is in range; convert it to the matching non-negative index.
    if axis < 0:
        axis += rank

    # Interior step indices 1 .. num-2, shaped [1, ..., num-2, ..., 1] so
    # they broadcast against start/delta along the new axis only.
    index_shape = [1] * rank
    index_shape[axis] = num - 2
    steps = tf.reshape(
        tf.range(1, num - 1, dtype=start.dtype), index_shape)

    interior = expanded_start + delta * steps
    # Concatenate the exact endpoints rather than recomputing them, so the
    # last sample is exactly ``stop``.
    return tf.concat((expanded_start, interior, expanded_stop), axis=axis)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment