Last active: April 19, 2018 04:39
Save ricsonc/dcab9dc9bbd0cebda9f830798c902afb to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python2 | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from keras.models import Model | |
from keras.layers import Input, LSTM, Dense, SimpleRNN, GRU | |
from keras.optimizers import Adam, RMSprop | |
import random | |
import pandas as pd | |
N = 10000 | |
num_repeats = 5 | |
num_epochs = 10 | |
task_num = 0 | |
if task_num == 1: | |
rnn_units = 8 | |
lens = [5, 20] | |
elif task_num == 0: | |
rnn_units = 1 | |
lens = [5, 100] | |
# sequence length options | |
#lens = [5, 10, 15, 20, 20, 30, 40, 50, 75, 100] | |
#lens = [5, 50, 100] | |
from keras import backend as K


def _make_dataset(task_num, n, seq_len, noise_scale=0.1, rng=np.random):
    """Generate a noisy synthetic sequence-classification dataset.

    Every sequence is zero except for +/-1 "signal" entries, and Gaussian
    noise with standard deviation ``noise_scale`` is added to every step.

    task_num 0: a random sign is placed at t=0; +1 -> class 1, -1 -> class 0.
    task_num 1: signs are placed at t=0 and t=seq_len//2; the label is 1 when
        the two signs agree and 0 when they differ.

    Returns ``(xs, ys)``: xs has shape (n, seq_len, 1) and ys has shape (n,)
    with float 0.0/1.0 class labels.

    Raises ValueError for an unknown ``task_num``.
    """
    labels = rng.uniform(size=n) > 0.5
    xs = np.zeros((n, seq_len))
    if task_num == 0:
        xs[:, 0] = np.where(labels, 1.0, -1.0)
    elif task_num == 1:
        first = np.where(rng.uniform(size=n) > 0.5, 1.0, -1.0)
        xs[:, 0] = first
        # "//" keeps the index an int under both Python 2 and Python 3.
        xs[:, seq_len // 2] = np.where(labels, first, -first)
    else:
        raise ValueError('unknown task_num: %r' % (task_num,))
    noise = rng.normal(loc=0, scale=noise_scale, size=(n, seq_len))
    return (xs + noise).reshape(n, seq_len, 1), labels.astype(float)


def _final_val_acc(history):
    """Return the last-epoch validation accuracy recorded in a Keras History.

    Older Keras releases log the 'accuracy' metric as 'val_acc'; Keras 2.3+
    (and tf.keras) log it as 'val_accuracy'. Accept either.
    """
    logs = history.history
    metric_key = 'val_acc' if 'val_acc' in logs else 'val_accuracy'
    return logs[metric_key][-1]


# res[cell_name][seq_len] -> list of final validation accuracies, one per repeat.
res = {}
for key, RNN_CELL in [('lstm', LSTM), ('gru', GRU), ('srnn', SimpleRNN)]:
    res[key] = {}
    print(key)
    for seq_len in lens:
        print(seq_len)
        train_xs, train_ys = _make_dataset(task_num, N, seq_len)
        accs = []
        # Repeat each experiment to average out training variance.
        for _ in range(num_repeats):
            # Reset backend state so models from earlier runs do not keep
            # accumulating (e.g. in the TF1 default graph) and slowing things down.
            K.clear_session()
            inputs = Input(shape=(None, 1), name='input')
            rnn = RNN_CELL(rnn_units, name='rnn')(inputs)
            out = Dense(2, activation='softmax', name='output')(rnn)
            model = Model(inputs, out)
            # NOTE(review): newer Keras releases rename `lr` to `learning_rate`.
            optimizer = RMSprop(lr=1E-1, epsilon=1E-4, clipnorm=1.0)
            model.compile(optimizer=optimizer,
                          loss='sparse_categorical_crossentropy',
                          metrics=['accuracy'])
            hist = model.fit(train_xs, train_ys, epochs=num_epochs, shuffle=True,
                             validation_split=0.2, batch_size=128, verbose=1)
            accs.append(_final_val_acc(hist))
        res[key][seq_len] = accs
# Plot the mean final validation accuracy (averaged over repeats) against
# sequence length, one line per recurrent cell type.
fig = plt.figure()
ax = fig.add_subplot(111)
for key in ['lstm', 'gru', 'srnn']:
    # res[key] maps seq_len -> list of per-repeat accuracies, so the frame's
    # columns are sequence lengths and .mean() averages over repeats. Sort by
    # length because dict/column order is not guaranteed (arbitrary on Python 2).
    means = pd.DataFrame.from_dict(res[key]).mean().sort_index()
    # Pass x explicitly so the axis shows the actual sequence lengths rather
    # than relying on how matplotlib treats a pandas Series.
    ax.plot(means.index, means.values, marker='o', label=key)
ax.set_xlabel('sequence length')
ax.set_ylabel('mean final validation accuracy')
ax.legend()
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.