Created
September 9, 2015 17:20
-
-
Save napsternxg/824c25188627d7266f32 to your computer and use it in GitHub Desktop.
Haiku Char-RNN
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.models import Sequential
from keras.layers.core import Dense, Activation, Dropout
from keras.layers.recurrent import LSTM
from keras.preprocessing.sequence import pad_sequences
import numpy as np
import random, sys
'''
Example script to generate text from Nietzsche's writings.
At least 20 epochs are required before the generated text
starts sounding coherent.
It is recommended to run this script on GPU, as recurrent
networks are quite computationally intensive.
If you try this script on new data, make sure your corpus
has at least ~100k characters. ~1M is better.
'''
path = "haiku_all.txt"
# Read the whole corpus; lowercasing shrinks the character vocabulary.
with open(path) as f:
    text = f.read().lower()
print('corpus length:', len(text))

# Character vocabulary, sorted so the char<->index mapping is reproducible
# across runs (a bare set() iterates in hash order).
chars = sorted(set(text))
print('total chars:', len(chars))
char_indices = dict((c, i) for i, c in enumerate(chars))
indices_char = dict((i, c) for i, c in enumerate(chars))

# cut the text in semi-redundant sequences of maxlen characters,
# windowing WITHIN each haiku line so sequences never span line breaks
maxlen = 20
step = 1
sentences = []
next_chars = []
for line in text.splitlines():
    for i in range(0, len(line) - maxlen, step):
        # BUG FIX: the original sliced the full `text` here instead of the
        # current line, so every line emitted identical corpus-prefix windows.
        sentences.append(line[i : i + maxlen])
        next_chars.append(line[i + maxlen])
print('nb sequences:', len(sentences))

print('Vectorization...')
# One-hot encode: X[i, t, c] = 1 iff char c is at position t of sentence i,
# y[i, c] = 1 iff char c follows sentence i.
# `bool` replaces `np.bool`, an alias removed in numpy >= 1.24.
X = np.zeros((len(sentences), maxlen, len(chars)), dtype=bool)
y = np.zeros((len(sentences), len(chars)), dtype=bool)
for i, sentence in enumerate(sentences):
    for t, char in enumerate(sentence):
        X[i, t, char_indices[char]] = 1
    y[i, char_indices[next_chars[i]]] = 1
# build the model: 2 stacked LSTMs with dropout, softmax over the vocabulary.
# NOTE(review): these layer constructors use the ancient Keras 0.x signature
# (input_dim, output_dim); modern Keras expects e.g. LSTM(512,
# input_shape=(maxlen, len(chars))). Left as-is to match the era of the file —
# confirm the installed Keras version before running.
print('Build model...')
model = Sequential()
model.add(LSTM(len(chars), 512, return_sequences=True))
model.add(Dropout(0.2))
model.add(LSTM(512, 512, return_sequences=False))
model.add(Dropout(0.2))
model.add(Dense(512, len(chars)))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
# helper function to sample an index from a probability array | |
# helper function to sample an index from a probability array
def sample(a, temperature=1.0):
    """Sample an index from probability vector ``a`` after temperature scaling.

    temperature < 1 sharpens the distribution (more conservative picks),
    temperature > 1 flattens it (more diverse picks).
    """
    # Work in float64 and renormalize explicitly: float32 model outputs whose
    # sum drifts above 1.0 make np.random.multinomial raise
    # "ValueError: sum(pvals[:-1]) > 1.0".
    a = np.asarray(a, dtype=np.float64)
    a = np.log(a) / temperature
    a = np.exp(a) / np.sum(np.exp(a))
    return np.argmax(np.random.multinomial(1, a, 1))
# train the model, output generated text after each iteration | |
# train the model, output generated text after each iteration
def generate_from_model(model):
    """Print short samples from `model` at several softmax temperatures.

    Picks one random maxlen-character slice of the corpus as the seed (the
    same seed is reused for every temperature) and emits characters until
    more than 3 newlines or more than 120 characters have been produced.
    """
    seed_start = random.randint(0, len(text) - maxlen - 1)
    for diversity in [0.2, 0.5, 1.0, 1.2]:
        print()
        print('----- diversity:', diversity)
        sentence = text[seed_start : seed_start + maxlen]
        generated = '' + sentence
        print('----- Generating with seed: "' + sentence + '"')
        sys.stdout.write(generated)
        lines_emitted, chars_emitted = 0, 0
        while lines_emitted <= 3 and chars_emitted <= 120:
            # One-hot encode the rolling window of the last maxlen characters.
            x = np.zeros((1, maxlen, len(chars)))
            for pos, ch in enumerate(sentence):
                x[0, pos, char_indices[ch]] = 1.
            preds = model.predict(x, verbose=0)[0]
            next_char = indices_char[sample(preds, diversity)]
            chars_emitted += 1
            generated += next_char
            if next_char == '\n':
                lines_emitted += 1
            # Slide the window forward by one character.
            sentence = sentence[1:] + next_char
            sys.stdout.write(next_char)
            sys.stdout.flush()
        print()
# Fit on the one-hot sequences, then print samples once at the end.
# NOTE(review): despite the comment above generate_from_model, samples are only
# produced after ALL epochs finish; wrap fit/generate in a loop for
# per-iteration output. `nb_epoch` is the old Keras 0.x/1.x spelling (modern
# Keras uses `epochs`), and batch_size=10000 is unusually large — confirm both
# against the installed Keras version and available memory.
history = model.fit(X, y, batch_size=10000, nb_epoch=70)
generate_from_model(model)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
('----- diversity:', 0.2)
----- Generating with seed: " cane
the train's wh"
 cane
the train's wh bcattat saobnahhsaaelmemm t ealaeeteueuiirrrrethu uriibiboboo tat stttalalllaer tmettummm p ooor rthhuie rrres aa
('----- diversity:', 0.5)
----- Generating with seed: " cane
the train's wh"
 cane
the train's wh bcattat saobnchhsaaelmammeeeeallemteueuih r i
ii oi ibibss nuo soo ho ssaabbhheaee ae tehueaalaa irirh tmuieu t
('----- diversity:', 1.0)
----- Generating with seed: " cane
the train's wh"
 cane
the train's wh bcataatisaoenahhemaelm e bbhekce tt rri reeeslamipyimmonnnoin
iih rotinna
aaa na oss llneceeeee cerh 9nnnniiiiir
('----- diversity:', 1.2)
----- Generating with seed: " cane
the train's wh"
 cane
the train's wh bca tat saohnchhe mulbeeaithealaamttihhbdbih saru me emlsaslsait iinhrrihih u hii t o es sssst sao e e i hi hh hh e
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment