@napsternxg
Created September 9, 2015 17:20
Haiku Char-RNN
from keras.models import Sequential
from keras.layers.core import Dense, Activation, Dropout
from keras.layers.recurrent import LSTM
from keras.preprocessing.sequence import pad_sequences
import numpy as np
import random, sys
'''
Character-level RNN for generating haiku, adapted from the Keras
example script that generates text from Nietzsche's writings.
At least 20 epochs are required before the generated text
starts sounding coherent.
It is recommended to run this script on GPU, as recurrent
networks are quite computationally intensive.
If you try this script on new data, make sure your corpus
has at least ~100k characters. ~1M is better.
'''
path = "haiku_all.txt"
text = open(path).read().lower()
print('corpus length:', len(text))
chars = set(text)
print('total chars:', len(chars))
char_indices = dict((c, i) for i, c in enumerate(chars))
indices_char = dict((i, c) for i, c in enumerate(chars))
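# the paired dicts give lookups in both directions: char -> index for
# one-hot encoding inputs, and index -> char for decoding sampled predictions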
# cut the text in semi-redundant sequences of maxlen characters
maxlen = 20
step = 1
sentences = []
next_chars = []
for line in text.splitlines():
    for i in range(0, len(line) - maxlen, step):
        sentences.append(line[i : i + maxlen])
        next_chars.append(line[i + maxlen])
print('nb sequences:', len(sentences))
print('Vectorization...')
X = np.zeros((len(sentences), maxlen, len(chars)), dtype=np.bool)
y = np.zeros((len(sentences), len(chars)), dtype=np.bool)
for i, sentence in enumerate(sentences):
    for t, char in enumerate(sentence):
        X[i, t, char_indices[char]] = 1
    y[i, char_indices[next_chars[i]]] = 1
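# X now holds one boolean one-hot window per training example, with shape
# (nb sequences, maxlen, vocab size); y holds the one-hot next character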
# build the model: 2 stacked LSTM layers
print('Build model...')
model = Sequential()
model.add(LSTM(len(chars), 512, return_sequences=True))
model.add(Dropout(0.2))
model.add(LSTM(512, 512, return_sequences=False))
model.add(Dropout(0.2))
model.add(Dense(512, len(chars)))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
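# note: the positional (input_dim, output_dim) layer arguments above follow the
# Keras 0.x API this gist was written against; newer Keras versions instead take
# a single layer size plus an input_shape argument on the first layer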
# helper function to sample an index from a probability array
def sample(a, temperature=1.0):
    a = np.log(a) / temperature
    a = np.exp(a) / np.sum(np.exp(a))
    return np.argmax(np.random.multinomial(1, a, 1))
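# A quick illustration with hypothetical probabilities: low temperature makes
# sampling near-greedy, high temperature flattens the distribution.
#   sample(np.array([0.1, 0.2, 0.7]), temperature=0.2)  # almost always returns 2
#   sample(np.array([0.1, 0.2, 0.7]), temperature=1.2)  # returns 0 and 1 more often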
# train the model, output generated text after each iteration
def generate_from_model(model):
    start_index = random.randint(0, len(text) - maxlen - 1)
    for diversity in [0.2, 0.5, 1.0, 1.2]:
        print()
        print('----- diversity:', diversity)
        generated = ''
        sentence = text[start_index : start_index + maxlen]
        generated += sentence
        print('----- Generating with seed: "' + sentence + '"')
        sys.stdout.write(generated)
        tot_lines = 0
        tot_chars = 0
        while True:
            # stop once we have a few lines or enough characters for a haiku
            if tot_lines > 3 or tot_chars > 120:
                break
            x = np.zeros((1, maxlen, len(chars)))
            for t, char in enumerate(sentence):
                x[0, t, char_indices[char]] = 1.
            preds = model.predict(x, verbose=0)[0]
            next_index = sample(preds, diversity)
            next_char = indices_char[next_index]
            tot_chars += 1
            generated += next_char
            if next_char == '\n':
                tot_lines += 1
            # slide the seed window forward by one character
            sentence = sentence[1:] + next_char
            sys.stdout.write(next_char)
            sys.stdout.flush()
        print()
history = model.fit(X, y, batch_size=10000, nb_epoch=70)
generate_from_model(model)
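Sample output after training (the '----- diversity:' lines appear as tuples because the script was run under Python 2):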
('----- diversity:', 0.2)
----- Generating with seed: " cane
the train's wh"
cane
the train's wh bcattat saobnahhsaaelmemm t ealaeeteueuiirrrrethu uriibiboboo tat stttalalllaer tmettummm p ooor rthhuie rrres aa
('----- diversity:', 0.5)
----- Generating with seed: " cane
the train's wh"
cane
the train's wh bcattat saobnchhsaaelmammeeeeallemteueuih r i
ii oi ibibss nuo soo ho ssaabbhheaee ae tehueaalaa irirh tmuieu t
('----- diversity:', 1.0)
----- Generating with seed: " cane
the train's wh"
cane
the train's wh bcataatisaoenahhemaelm e bbhekce tt rri reeeslamipyimmonnnoin
iih rotinna
aaa na oss llneceeeee cerh 9nnnniiiiir
('----- diversity:', 1.2)
----- Generating with seed: " cane
the train's wh"
cane
the train's wh bca tat saohnchhe mulbeeaithealaamttihhbdbih saru me emlsaslsait iinhrrihih u hii t o es sssst sao e e i hi hh hh e