Skip to content

Instantly share code, notes, and snippets.

@VXU1230
Last active March 19, 2019 18:08
Show Gist options
  • Save VXU1230/af7654ef7d1499d22847449b5355388d to your computer and use it in GitHub Desktop.
Save VXU1230/af7654ef7d1499d22847449b5355388d to your computer and use it in GitHub Desktop.
Create Data Pipeline
def next_batch(train_input, training=True, batch_size=None, num_epochs=None):
    """Yield (target, context, label) mini-batches from ``train_input``.

    Columns 0, 1 and 2 of ``train_input`` are each flattened with
    ``np.hstack`` and cast to ``float32``. Only full batches are served; a
    trailing remainder smaller than ``batch_size`` is dropped.

    NOTE(review): assumes ``train_input`` is a 2-D numpy array (likely
    ``dtype=object``) whose cells in columns 0-2 are 1-D arrays of equal
    total length — confirm against the caller that builds it.

    Args:
        train_input: 2-D array whose first three columns hold the target,
            context and label data.
        training: If True, stop after ``num_epochs`` full passes, printing
            a message at the end of each pass. If False, cycle over the
            data forever.
        batch_size: Number of items per batch. Defaults to the module-level
            ``BATCH_SIZE``.
        num_epochs: Number of passes in training mode. Defaults to the
            module-level ``NUM_EPOCH``. At least one pass is always made.

    Yields:
        Tuple ``(t_batch, c_batch, l_batch)`` of ``float32`` arrays, each of
        length ``batch_size``.

    Raises:
        ValueError: If ``batch_size`` is not positive, or if there is not
            enough data for a single full batch. Without this check the
            generator would yield empty batches forever.
    """
    # Resolve lazily so existing callers that rely on the globals still work.
    if batch_size is None:
        batch_size = BATCH_SIZE
    if num_epochs is None:
        num_epochs = NUM_EPOCH
    if batch_size <= 0:
        raise ValueError("batch_size must be positive, got {}".format(batch_size))

    target_data = np.hstack(train_input[:, 0]).astype(np.float32)
    context_data = np.hstack(train_input[:, 1]).astype(np.float32)
    label_data = np.hstack(train_input[:, 2]).astype(np.float32)

    # Round down to a whole number of batches; the tail remainder is dropped.
    word_size = target_data.size // batch_size * batch_size
    if word_size == 0:
        raise ValueError(
            "not enough data for one batch: {} items < batch_size {}".format(
                target_data.size, batch_size))

    epoch = 1
    while True:
        # Serve every full batch of this pass, including the last one.
        for start in range(0, word_size, batch_size):
            stop = start + batch_size
            yield (target_data[start:stop],
                   context_data[start:stop],
                   label_data[start:stop])
        if not training:
            # Evaluation/inference mode: wrap around and keep cycling.
            continue
        print("\n epoch {} training finished".format(epoch))
        if epoch >= num_epochs:
            return
        epoch += 1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment