# GitHub Gist by @TheCherry (gist id: bc71d882991587744b16b2f0e7c648bb).
# Created November 3, 2017 18:53.
from __future__ import absolute_import, division, print_function
import os
import pickle
from six.moves import urllib
import tflearn
from tflearn.data_utils import *
import random
# path = "shakespeare_input.txt"
# Pickle file in which the character dictionary is cached between runs.
char_idx_file = 'numbers.pickle'
#
# if not os.path.isfile(path):
# urllib.request.urlretrieve("https://raw.githubusercontent.com/tflearn/tflearn.github.io/master/resources/shakespeare_input.txt", path)
#
# Length, in characters, of each input sequence (used as seq_maxlen below).
maxlen = 24
#
# Reuse the dictionary saved by an earlier run when the cache file exists;
# otherwise leave it as None.
dictionary = None
if os.path.isfile(char_idx_file):
    print('Loading previous char_idx')
    # NOTE: unpickling can execute arbitrary code -- only load a cache file
    # that this script wrote itself.
    # The context manager closes the handle even if unpickling fails.
    with open(char_idx_file, 'rb') as cache_file:
        dictionary = pickle.load(cache_file)
#
# X, Y, char_idx = \
# textfile_to_semi_redundant_sequences(path, seq_maxlen=maxlen, redun_step=3,
# pre_defined_char_idx=char_idx)
#
def gen_data(start, end):
    """Return the counting text for the integers starting at ``start``.

    Walks ``range(start, end, 2)`` and emits every ``i`` together with
    ``i + 1`` as ``"{:05}-{:05}-"``, producing a run of zero-padded,
    dash-terminated numbers.  For example ``gen_data(0, 4)`` returns
    ``"00000-00001-00002-00003-"``.

    Because ``i + 1`` is always written, the last number in the output is
    ``end`` itself when ``end - start`` is odd.

    Args:
        start: First integer to emit.
        end: Exclusive upper bound for the first number of each pair.

    Returns:
        The concatenated text, or ``""`` when ``start >= end``.
    """
    # join() assembles the text in one pass instead of growing a string
    # with ``+=`` on every iteration.
    return "".join(
        "{:05}-{:05}-".format(i, i + 1) for i in range(start, end, 2)
    )
# Bounds of the integer range the training text is generated from; the
# training loop further down draws its random seeds from the same range.
start = 0
end = 9999
# Build the corpus from the named bounds rather than repeating the literals.
txt = gen_data(start, end)

# Show the beginning of the corpus and the sequence length in use.
print("------------------")
print(txt[:100])
print(maxlen)
print("------------------")

# Cut the corpus into sequences of maxlen characters (X) with their targets
# (Y).  `dictionary` is passed in as char_idx -- it is None unless a cached
# one was loaded above -- and is rebound to the dictionary the call returns.
X, Y, dictionary = string_to_semi_redundant_sequences(
    txt, seq_maxlen=maxlen, char_idx=dictionary, redun_step=1)

# Cache the dictionary for the next run; the context manager guarantees the
# file is flushed and closed.
with open(char_idx_file, 'wb') as cache_file:
    pickle.dump(dictionary, cache_file)
# Network definition: an input of shape [None, maxlen, len(dictionary)], three
# 512-unit LSTM layers each followed by dropout(g, 0.5), and a softmax layer
# with one output per dictionary entry.
vocab_size = len(dictionary)
g = tflearn.input_data([None, maxlen, vocab_size])
# Only the last LSTM layer is built with return_seq=False (the library
# default); the first two pass return_seq=True.
for keep_sequence in (True, True, False):
    g = tflearn.lstm(g, 512, return_seq=keep_sequence)
    g = tflearn.dropout(g, 0.5)
g = tflearn.fully_connected(g, vocab_size, activation='softmax')
g = tflearn.regression(
    g,
    optimizer='adam',
    loss='categorical_crossentropy',
    learning_rate=0.0005
)
# Wrap the graph in a SequenceGenerator that shares the dictionary and
# sequence length used to build X and Y.
m = tflearn.SequenceGenerator(
    g,
    dictionary=dictionary,
    seq_maxlen=maxlen,
    clip_gradients=5.0,
    checkpoint_path='model_shakespeare'
)
def txtOut(txt):
    """Split ``txt`` on ``";"`` and return the resulting pieces.

    Previously the split result was discarded, so the function always
    returned ``None``; it now returns the list.

    Args:
        txt: The string to split.

    Returns:
        A list of the substrings of ``txt`` separated by ``";"``.
    """
    return txt.split(";")
# Number of epochs each m.fit() call in the training loop runs.
n_epoch = 3
# Print six empty lines before the load/training output begins.
for _ in range(6):
    print("")
# Restore the model from the same path the training loop saves to.
# NOTE(review): presumably this requires a checkpoint from an earlier run to
# exist -- confirm before running from scratch.
m.load("lstm_count.tflearn")
# Run 1500 rounds; each round picks a fresh random seed, trains for n_epoch
# epochs and prints two generated samples.
for _ in range(1500):
    # Random first number for the seed text, drawn from the training range.
    r = random.randrange(start, (end - maxlen) - 5)
    print(r, r + maxlen + 5)
    seed = gen_data(r, r + maxlen + 5)
    # The seed starts at the first dash and is exactly maxlen characters
    # long; the dash position is looked up once and reused.
    dash = seed.index("-")
    seed = seed[dash:dash + maxlen]
    print(seed)
    print(len(seed))
    m.fit(X, Y, validation_set=0.1, batch_size=128,
          n_epoch=n_epoch, run_id='shakespeare', snapshot_epoch=False)
    print("-- TESTING...")
    print("-- Test with temperature of 1.0 --")
    print(m.generate(600, temperature=1.0, seq_seed=seed))
    print("-- Test with temperature of 0.5 --")
    print(m.generate(600, temperature=0.5, seq_seed=seed))
    # NOTE(review): the original indentation was lost; saving is assumed to
    # happen once per round so progress survives an interrupted run.
    m.save("lstm_count.tflearn")
#
# (GitHub page footer omitted.)