Created September 18, 2018 12:58
Save chmodsss/7c1c95f495feaf40bb254a8309182444 to your computer and use it in GitHub Desktop.
Text generation 2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import collections
import pandas as pd
import numpy as np

# Tokenise the raw corpus on whitespace. Despite its name, `sents` holds
# individual word tokens, not sentences.
# NOTE(review): `cdata` is not defined in this snippet -- presumably a str
# with the training corpus, loaded earlier (TODO confirm).
sents = cdata.split()  # str.split() already returns a fresh list

# Vocabulary of distinct tokens. Sorting fixes the order, so the
# word <-> index mappings below are reproducible across runs.
vocab = sorted(set(sents))
vocab2idx = {word: idx for idx, word in enumerate(vocab)}
idx2vocab = dict(enumerate(vocab))
# Build (context window, next word) training pairs. Sample i is the
# `seq_len` consecutive tokens starting at position i, and its target is the
# token immediately after that window. If the corpus has `seq_len` tokens or
# fewer, no samples are produced.
seq_len = 10
sequences = [
    sents[start:start + seq_len]
    for start in range(len(sents) - seq_len)
]
# The target of window `start` is sents[start + seq_len]. Across all windows,
# that is simply the token stream shifted by `seq_len`.
nextword = sents[seq_len:]
seq = pd.DataFrame({'sequence': sequences, 'target': nextword})
# One-hot encode the samples:
#   sequence_arr[i, t, v] is True iff token t of window i is vocab[v];
#   target_arr[i, v]      is True iff the target of window i is vocab[v].
# NOTE(review): a NumPy bool takes 1 byte, so sequence_arr needs
# len(seq) * seq_len * len(vocab) bytes. This grows quickly with the corpus.
_n_samples = len(seq)
sequence_arr = np.zeros((_n_samples, seq_len, len(vocab)), dtype='bool')
target_arr = np.zeros((_n_samples, len(vocab)), dtype='bool')

# Look up every token's vocabulary index once, then set all one-hot bits with
# a single fancy-indexing assignment per array instead of a Python loop per
# token. Every window holds exactly `seq_len` tokens. reshape() keeps the
# index array 2-D even when there are no samples, so broadcasting still works.
_window_idx = np.array(
    [vocab2idx[word] for window in seq['sequence'] for word in window],
    dtype=np.intp,
).reshape(_n_samples, seq_len)
_target_idx = np.array(
    [vocab2idx[word] for word in seq['target']],
    dtype=np.intp,
)
_rows = np.arange(_n_samples)
target_arr[_rows, _target_idx] = True
# (_n_samples, 1), (seq_len,) and (_n_samples, seq_len) broadcast to one
# (sample, position, vocab-index) triple per token.
sequence_arr[_rows[:, None], np.arange(seq_len), _window_idx] = True
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.