As a point of comparison with the default Nietzsche example from the Keras repo, this little experiment swaps out the dataset for forum comments from the My Little Pony subreddit.
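If you want to sanity-check the swapped-in corpus before training, a quick look at the CSV helps. A minimal sketch, assuming the raw-file URL below is reachable and that the comment text lives in a column labeled "0", as the script's df["0"] indexing implies:

import pandas as pd

url = "https://github.com/linanqiu/reddit-dataset/blob/master/television_mylittlepony.csv?raw=true"
df = pd.read_csv(url)
print(df.shape)        # how many comments and columns we actually got
print(df["0"].head())  # first few comment bodies, exactly as the script reads them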
'''Example script to generate text character by character with an LSTM.
Adapted from the Keras Nietzsche example; the corpus is swapped out below
for My Little Pony subreddit comments.
At least 20 epochs are required before the generated text
starts sounding coherent.
It is recommended to run this script on a GPU, as recurrent
networks are quite computationally intensive.
If you try this script on new data, make sure your corpus
has at least ~100k characters. ~1M is better.
'''
from __future__ import print_function
from keras.callbacks import LambdaCallback
from keras.models import Sequential
from keras.layers import Dense, Activation
from keras.layers import LSTM
from keras.optimizers import RMSprop
from keras.utils.data_utils import get_file
import pandas as pd
import numpy as np
import random
import sys
import io
# path = get_file('nietzsche.txt', origin='https://s3.amazonaws.com/text-datasets/nietzsche.txt')
# with io.open(path, encoding='utf-8') as f:
#     text = f.read().lower()
# print('corpus length:', len(text))
## Using my little pony dataset instead
file = "https://github.com/linanqiu/reddit-dataset/blob/master/television_mylittlepony.csv?raw=true"
df = pd.read_csv(file)
# join the first 20,000 comments into a single training corpus
text = " ".join(df["0"].astype(str).head(20000).tolist())
print('corpus length:', len(text))
chars = sorted(list(set(text)))
print('total chars:', len(chars))
char_indices = dict((c, i) for i, c in enumerate(chars))
indices_char = dict((i, c) for i, c in enumerate(chars))
# cut the text in semi-redundant sequences of maxlen characters
maxlen = 40
step = 3
sentences = []
next_chars = []
for i in range(0, len(text) - maxlen, step):
    sentences.append(text[i: i + maxlen])
    next_chars.append(text[i + maxlen])
print('nb sequences:', len(sentences))
print('Vectorization...')
# one-hot encode each character position; plain bool avoids the deprecated np.bool alias
x = np.zeros((len(sentences), maxlen, len(chars)), dtype=bool)
y = np.zeros((len(sentences), len(chars)), dtype=bool)
for i, sentence in enumerate(sentences):
    for t, char in enumerate(sentence):
        x[i, t, char_indices[char]] = 1
    y[i, char_indices[next_chars[i]]] = 1
# build the model: a single LSTM
print('Build model...')
model = Sequential()
model.add(LSTM(128, input_shape=(maxlen, len(chars))))
model.add(Dense(len(chars)))
model.add(Activation('softmax'))

optimizer = RMSprop(lr=0.01)
model.compile(loss='categorical_crossentropy', optimizer=optimizer)
def sample(preds, temperature=1.0):
    # helper function to sample an index from a probability array;
    # low temperature sharpens the distribution, high temperature flattens it
    preds = np.asarray(preds).astype('float64')
    preds = np.log(preds) / temperature
    exp_preds = np.exp(preds)
    preds = exp_preds / np.sum(exp_preds)
    probas = np.random.multinomial(1, preds, 1)
    return np.argmax(probas)
def on_epoch_end(epoch, logs):
    # Function invoked at end of each epoch. Prints generated text.
    print()
    print('----- Generating text after Epoch: %d' % epoch)
    start_index = random.randint(0, len(text) - maxlen - 1)
    for diversity in [0.2, 0.5, 1.0, 1.2]:
        print('----- diversity:', diversity)
        generated = ''
        sentence = text[start_index: start_index + maxlen]
        generated += sentence
        print('----- Generating with seed: "' + sentence + '"')
        sys.stdout.write(generated)
        for i in range(400):
            x_pred = np.zeros((1, maxlen, len(chars)))
            for t, char in enumerate(sentence):
                x_pred[0, t, char_indices[char]] = 1.
            preds = model.predict(x_pred, verbose=0)[0]
            next_index = sample(preds, diversity)
            next_char = indices_char[next_index]
            generated += next_char
            sentence = sentence[1:] + next_char
            sys.stdout.write(next_char)
            sys.stdout.flush()
        print()
print_callback = LambdaCallback(on_epoch_end=on_epoch_end)

model.fit(x, y,
          batch_size=128,
          epochs=60,
          callbacks=[print_callback])
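Not part of the original gist, but a natural follow-up: once training finishes you can persist the network and generate text on demand instead of only at epoch boundaries. A minimal sketch, assuming the names defined above (model, sample, chars, char_indices, indices_char, maxlen, text) are still in scope; the filename is arbitrary:

model.save('pony_lstm.h5')  # reload later with keras.models.load_model('pony_lstm.h5')

def generate(seed, length=200, diversity=0.5):
    # seed must be at least maxlen characters, all drawn from the training alphabet
    sentence = seed[-maxlen:]
    out = sentence
    for _ in range(length):
        x_pred = np.zeros((1, maxlen, len(chars)))
        for t, char in enumerate(sentence):
            x_pred[0, t, char_indices[char]] = 1.
        preds = model.predict(x_pred, verbose=0)[0]
        next_char = indices_char[sample(preds, diversity)]
        out += next_char
        sentence = sentence[1:] + next_char
    return out

print(generate(text[:maxlen]))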