Skip to content

Instantly share code, notes, and snippets.

@hccho2
Created March 30, 2020 23:21
Show Gist options
  • Save hccho2/e35e1e65ece8ad32eb535e7edf9ac188 to your computer and use it in GitHub Desktop.
Save hccho2/e35e1e65ece8ad32eb535e7edf9ac188 to your computer and use it in GitHub Desktop.
batch_to_seq & seq_to_batch
# remove last step
def strip(var, nenvs, nsteps, flat=False):
    """Drop the last time step of every environment from a flattened tensor.

    ``var`` holds ``nenvs * (nsteps + 1)`` rows laid out the way
    ``batch_to_seq`` unpacks them: all steps of env 0, then all steps of
    env 1, and so on. The result keeps the first ``nsteps`` steps of each
    env, in the same order.

    Args:
        var: Tensor of shape ``[nenvs * (nsteps + 1), last_dim]``, or a 1-D
            tensor of ``nenvs * (nsteps + 1)`` elements when ``flat`` is True.
        nenvs: Size of the leading (env) axis of the sequence view.
        nsteps: Number of steps to keep per env; ``var`` carries one extra.
        flat: If True, treat ``var`` as 1-D.

    Returns:
        Tensor of shape ``[nenvs * nsteps, last_dim]``, or ``[nenvs * nsteps]``
        when ``flat`` is True.
    """
    # In TF1, indexing a TensorShape yields a Dimension whose static size is
    # in ``.value``. Under TF2 shape semantics it is already an int (or None).
    # Accept both instead of assuming ``.value`` exists.
    last = var.get_shape()[-1]
    last_dim = getattr(last, "value", last)
    # List of nsteps + 1 tensors, each (nenvs, last_dim), or (nenvs,) if flat.
    steps = batch_to_seq(var, nenvs, nsteps + 1, flat)
    # Discard the final step and re-flatten in env-major order.
    return seq_to_batch(steps[:-1], last_dim, flat)
def batch_to_seq(h, nbatch, nsteps, flat=False):
    """Split a batch tensor into a list with one tensor per time step.

    ``h`` is reshaped so that its first axis is ``nbatch`` and its second is
    ``nsteps``. Consecutive rows of ``h`` therefore belong to the same batch
    row at successive steps.

    Args:
        h: Tensor (or array-like accepted by ``tf.reshape``) containing
            ``nbatch * nsteps`` rows.
        nbatch: Size of the leading (batch/env) axis after reshaping.
        nsteps: Number of steps; also the length of the returned list.
        flat: If True, each returned tensor has shape ``[nbatch]``. Otherwise
            each has shape ``[nbatch, d]`` with ``d`` inferred by reshape.

    Returns:
        List of ``nsteps`` tensors; element ``t`` holds step ``t`` of every
        batch row.
    """
    target_shape = [nbatch, nsteps] if flat else [nbatch, nsteps, -1]
    grid = tf.reshape(h, target_shape)
    # Each slice along axis 1 keeps a unit step axis; squeeze it away.
    slices = tf.split(value=grid, num_or_size_splits=nsteps, axis=1)
    return [tf.squeeze(step, [1]) for step in slices]
def seq_to_batch(h, last_dim=None, flat=False):
    """Merge a per-step list of tensors back into one batch tensor.

    Inverse of ``batch_to_seq``. Element ``t`` of ``h`` holds step ``t`` for
    every batch row. The output lists all steps of row 0, then all steps of
    row 1, and so on.

    Args:
        h: Non-empty list of tensors, each ``[nbatch, last_dim]``, or
            ``[nbatch]`` when ``flat`` is True.
        last_dim: Size of the feature axis. If None (the default), it is taken
            from the static shape of ``h[0]``, falling back to its dynamic
            shape when the static size is unknown. Ignored when ``flat``.
        flat: If True, return a 1-D tensor.

    Returns:
        Tensor of shape ``[nbatch * len(h), last_dim]``, or
        ``[nbatch * len(h)]`` when ``flat`` is True.

    Raises:
        ValueError: If ``h`` is empty, or if ``flat`` is False and the
            elements of ``h`` are not at least 2-D.
    """
    # len() rather than truthiness so list, tuple and ndarray inputs all work.
    if len(h) == 0:
        raise ValueError("seq_to_batch expects a non-empty list of tensors")
    if flat:
        # Stacking on axis 1 gives (nbatch, nsteps); row-major flatten is
        # then env-major.
        return tf.reshape(tf.stack(values=h, axis=1), [-1])
    shape = h[0].get_shape().as_list()
    if len(shape) < 2:
        raise ValueError(
            "seq_to_batch with flat=False expects tensors of rank >= 2, "
            "got shape {}".format(shape))
    if last_dim is None:
        # Without this, reshape([-1, None]) fails even though the default
        # argument invites calling seq_to_batch(h).
        last_dim = shape[-1] if shape[-1] is not None else tf.shape(h[0])[-1]
    # Concatenating along axis 1 gives (nbatch, nsteps * last_dim). Reshaping
    # to rows of last_dim then yields the env-major order.
    return tf.reshape(tf.concat(axis=1, values=h), [-1, last_dim])
# Demo: run the helpers on small integer arrays (TF1 graph mode).
# The gist never imported numpy/tensorflow. The helpers above resolve ``tf``
# as a module global when called, so importing here is enough for them too.
import numpy as np
import tensorflow as tf

batch_size = 2
seq_length = 4
dim = 3

######### 3D array test
x0 = np.arange(batch_size * seq_length * dim).reshape(batch_size, seq_length, dim)
x = x0.reshape(-1, dim)  # 8 x 3, all steps of env 0 followed by env 1
X = tf.convert_to_tensor(x)
# Treat each env as (seq_length - 1) kept steps plus one extra step to drop.
Y = strip(X, batch_size, seq_length - 1, flat=False)
# expected Y: [[0 1 2] [3 4 5] [6 7 8] [12 13 14] [15 16 17] [18 19 20]]

######### 2D array test
z = np.arange(batch_size * seq_length)  # array([0, 1, 2, 3, 4, 5, 6, 7])
Z = tf.convert_to_tensor(z)
# Feed the tensor Z (previously built but unused; the raw array was passed).
W = batch_to_seq(Z, batch_size, seq_length, flat=True)  # [array([0, 4]), array([1, 5]), array([2, 6]), array([3, 7])]
Z2 = seq_to_batch(W, last_dim=None, flat=True)  # array([0, 1, 2, 3, 4, 5, 6, 7])

# Close the session deterministically and show the results when run as a script.
with tf.Session() as sess:
    print(sess.run(Y))
    print(sess.run([W, Z2]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment