Created
April 3, 2017 11:05
-
-
Save GINK03/6a015233b81a6413bc90d974665fb4d3 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Epoch counter shared (as module-global state) with the LR scheduler below.
counter = 0


def build_model(mode=None, maxlen=None, output_dim=None):
    """Build the LSTM classifier and its learning-rate scheduler callback.

    Args:
        mode: optimizer selector; "rms" (RMSprop, lr=0.01) or "adam".
        maxlen: number of input time steps (each step is a 512-dim vector).
        output_dim: number of output classes (width of the softmax layer).

    Returns:
        (model, change_lr): the compiled Sequential model and a
        LearningRateScheduler callback wrapping ``_scheduler``.

    Raises:
        ValueError: if ``mode`` is not "rms" or "adam".  (Bug fixed: the
        original left ``optimizer`` unbound for any other mode and crashed
        later with UnboundLocalError at ``model.compile``.)
    """
    # Fail fast on a bad mode, before any model construction.
    if mode not in ("rms", "adam"):
        raise ValueError("mode must be 'rms' or 'adam', got %r" % (mode,))
    print('Build model...')

    def _scheduler(epoch):
        # Steps through the externally supplied learning_rates table once
        # per epoch; ``counter`` is module-global so it persists across calls.
        global counter
        counter += 1
        rate = learning_rates[counter]  # NOTE(review): may IndexError on long runs — confirm table length
        #model.lr.set_value(rate)
        print(model.optimizer.lr)
        #print(counter, rate)
        return model.optimizer.lr

    change_lr = LRS(_scheduler)

    model = Sequential()
    # Single LSTM layer over (maxlen, 512) inputs; 128*15 = 1920 units.
    model.add(LSTM(128 * 15, return_sequences=False, input_shape=(maxlen, 512)))
    model.add(Dropout(0.5))
    model.add(Dense(output_dim))
    model.add(Activation('softmax'))

    optimizer = RMSprop(lr=0.01) if mode == "rms" else Adam()
    model.compile(loss='categorical_crossentropy', optimizer=optimizer)
    return model, change_lr
def dynast(preds, temperature=1.0):
    """Sample a class index from a distribution, rescaled by temperature.

    Lower temperature sharpens the distribution (more greedy), higher
    temperature flattens it (more diverse output).

    Bug fixed: the original returned ``np.argmax(preds)`` directly, which is
    invariant under the monotonic log/exp/normalize rescaling — so the
    ``temperature`` parameter had no effect at all.  Per the standard Keras
    text-generation recipe, draw from a multinomial instead.

    Args:
        preds: 1-D array-like of non-negative scores (e.g. softmax output).
        temperature: softness of the sampling distribution; must be > 0.

    Returns:
        int: the sampled class index.
    """
    preds = np.asarray(preds).astype('float64')
    preds = np.log(preds) / temperature
    exp_preds = np.exp(preds)
    preds = exp_preds / np.sum(exp_preds)
    # One multinomial draw yields a one-hot vector; its argmax is the index.
    probas = np.random.multinomial(1, preds, 1)
    return int(np.argmax(probas))
def train(mode):
    """Train the emoji-conditioned character LSTM and periodically sample it.

    Side effects: reads ./dataset/*, ./char_index.pkl and ./term_vec.pkl;
    writes emojis_vec.pkl and ./models.<mode>/<iteration>.model checkpoints;
    streams generated text to stdout.

    Args:
        mode: optimizer selector forwarded to build_model ("rms" or "adam").
    """
    # --- load the pickled training examples --------------------------------
    # NOTE(review): pickle.load on these files assumes they are trusted input.
    datasets = []
    for ni, name in enumerate(glob.glob('./dataset/*')[:20000]):
        if ni % 1000 == 0:
            print("now loading pkl iter %d" % ni)
        try:
            # Fixed: context manager closes the handle (was open().read()).
            with open(name, 'rb') as f:
                data = pickle.load(f)
        except pickle.UnpicklingError:
            continue  # skip corrupt files, keep the rest
        datasets.append(data)
    #datasets = datasets*100
    random.shuffle(datasets)
    max_texts = len(datasets)

    # Char <-> index lookup tables and the emoji term-vector table.
    with open('./char_index.pkl', 'rb') as f:
        char_index = pickle.load(f)
    index_char = {index: char for char, index in char_index.items()}
    with open('./term_vec.pkl', 'rb') as f:
        term_vec = pickle.load(f)

    # --- map examples to numpy tensors -------------------------------------
    X = np.zeros((max_texts, 10, 512), dtype=np.float64)
    Y = np.zeros((max_texts, len(char_index)), dtype=np.float64)
    emojis_vec = []
    for i, dataset in enumerate(datasets):
        if i % 1000 == 0:
            print("now mapping to numpy iter %d" % i)
        x, emoji, emoji_each_vec, y = dataset
        # Sum the 256-dim term vectors of the distinct emoji characters.
        emoji_vec = np.zeros((256), dtype=np.float64)
        for e in set(list(emoji)):
            try:
                emoji_vec += np.array(term_vec[e])
            except KeyError as e:
                print("error %s" % e)  # emoji without a term vector
        emojis_vec.append([emoji, emoji_vec])
        Y[i, char_index[y]] = 1.
        # Conditioning: first 256 dims are char-vec * emoji-vec, last 256
        # dims are char-vec + emoji-vec — presumably a cheap fusion scheme.
        for t, vec in enumerate(x):
            X[i, t, :256] = vec * emoji_vec
            X[i, t, 256:] = vec + emoji_vec
    # Fixed: context manager so the output handle is flushed and closed.
    with open('emojis_vec.pkl', 'wb') as f:
        f.write(pickle.dumps(emojis_vec))

    # --- train / checkpoint / sample loop ----------------------------------
    model, scheduler = build_model(mode=mode, maxlen=10, output_dim=len(char_index))
    for iteration in range(1, 10000):
        print()
        print('-' * 50)
        print('Iteration', iteration)
        model.fit(X, Y, batch_size=64, nb_epoch=1)#, callbacks=[scheduler])
        MODEL_NAME = "./models.%s/%09d.model" % (mode, iteration)
        model.save(MODEL_NAME)
        if iteration % 1 == 0:  # sample after every iteration
            for diversity in [1.0]:
                print()
                print('----- diversity:', diversity)
                sent = ["*"] * 10
                # Condition generation on a random emoji seen in training.
                emoji_data = emojis_vec[random.randint(0, len(emojis_vec) - 1)]
                emoji, _emoji_vec = emoji_data
                print('----- Generating with seed: "' + emoji + " " + "".join(sent) + '"')
                sys.stdout.write("".join(sent))
                for i in range(200):
                    x = np.zeros((1, 10, 512))
                    try:
                        for t, char in enumerate(sent):
                            x[0, t, :256] = term_vec[char] * _emoji_vec
                            x[0, t, 256:] = term_vec[char] + _emoji_vec
                        preds = model.predict(x, verbose=0)[0]
                        next_index = dynast(preds, diversity)
                        next_char = index_char[next_index]
                        sent.append(next_char)
                        sent = sent[1:]  # slide the 10-char window
                        sys.stdout.write(next_char)
                        sys.stdout.flush()
                    except KeyError:
                        # Generated a char without a term vector: stop sampling.
                        break
                print()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment