Created
October 6, 2017 08:32
-
-
Save mrjazz/930b616768f423850da3dfb04dfc2b6a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import time | |
import warnings | |
import math | |
import numpy as np | |
from numpy import newaxis | |
from keras.layers.core import Dense, Activation, Dropout | |
from keras.layers.recurrent import LSTM | |
from keras.models import Sequential | |
import matplotlib.pyplot as plt | |
# Quieten the libraries before any model code runs: level '3' hides messy
# TensorFlow warnings, and the blanket "ignore" filter hides messy NumPy ones.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
warnings.filterwarnings(action="ignore")

# Script-level settings: epoch count and the number of time steps per
# input window handed to load_data().
epochs, seq_len = 1, 50
def load_data(seq_len):
    """Build sliding-window train/test sets from a synthetic sine wave.

    The series is ``sin(x / 5) + 1`` sampled at ``x = 0 .. 999``.  Each
    window holds ``seq_len + 1`` consecutive samples: the first ``seq_len``
    values are the model input and the final value is the target to predict.

    Args:
        seq_len: Number of time steps in each input sequence.

    Returns:
        A list ``[x_train, y_train, x_test, y_test]``.  The ``x_*`` arrays
        have shape ``(samples, seq_len, 1)`` and the ``y_*`` arrays have
        shape ``(samples,)``.  The first 90% of the windows (rounded) are
        shuffled and used for training; the remaining windows form the test
        set in their original order.
    """
    # Divide by 5.0 so the series is the same under Python 2 and Python 3.
    data = [math.sin(x / 5.0) + 1 for x in range(1000)]

    window = seq_len + 1
    result = np.array(
        [data[start:start + window] for start in range(len(data) - window)]
    )

    # round() yields a float on Python 2; convert once so the split point is
    # always usable as an index.
    split = int(round(0.9 * result.shape[0]))

    # Shuffle whole windows (rows) so each input stays paired with its target.
    train = result[:split, :]
    np.random.shuffle(train)

    x_train = train[:, :-1]
    # The target is the sample that follows the input window, matching how
    # y_test is built below (previously this was a constant array of ones).
    y_train = train[:, -1]
    x_test = result[split:, :-1]
    y_test = result[split:, -1]

    # Add a trailing axis of size 1: (samples, seq_len) -> (samples, seq_len, 1).
    x_train = np.reshape(x_train, (x_train.shape[0], x_train.shape[1], 1))
    x_test = np.reshape(x_test, (x_test.shape[0], x_test.shape[1], 1))

    return [x_train, y_train, x_test, y_test]
def build_model(layers):
    """Assemble and compile a single-LSTM regression network.

    Args:
        layers: Indexable of three sizes.  ``layers[0]`` is the second entry
            of the LSTM ``input_shape``; ``layers[1]`` is both the first
            entry of ``input_shape`` and the LSTM output dimension;
            ``layers[2]`` is the output dimension of the final Dense layer.

    Returns:
        The compiled ``Sequential`` model (MSE loss, RMSprop optimizer).
    """
    feature_dim = layers[0]
    hidden_dim = layers[1]
    target_dim = layers[2]

    stack = (
        # Only the last LSTM output is kept (return_sequences=False).
        LSTM(
            input_shape=(hidden_dim, feature_dim),
            output_dim=hidden_dim,
            return_sequences=False,
        ),
        Dropout(0.2),
        Dense(output_dim=target_dim),
        Activation("linear"),
    )

    network = Sequential()
    for layer in stack:
        network.add(layer)

    network.compile(loss="mse", optimizer="rmsprop")
    return network
# Build the sine-wave windows, then train the network on them.
X_train, y_train, X_test, y_test = load_data(seq_len)

# The middle entry is the sequence length the model expects, so derive it
# from seq_len instead of repeating the literal 50; otherwise changing
# seq_len would make the model's input shape disagree with the data.
model = build_model([1, seq_len, 1])

# NOTE(review): the module-level `epochs` setting is not used here; training
# runs for 5 epochs as before. Confirm which value is intended.
model.fit(
    X_train,
    y_train,
    batch_size=512,
    nb_epoch=5,
    validation_split=0.05)

# Uncomment to plot the predictions instead of only printing them.
# plt.plot(model.predict(X_test))
# plt.show()
print(model.predict(X_test))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment