-
-
Save onepagecoding/fbdd54ad937247f40832ac84b1b7b225 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# multivariate mlp example
from numpy import array
from numpy import hstack
from keras.models import Sequential
from keras.layers import Dense
# split a multivariate sequence into samples | |
def split_sequences(sequences, n_steps):
    """Split a multivariate sequence into supervised-learning samples.

    Each sample pairs a window of ``n_steps`` consecutive rows of the input
    columns (every column except the last) with the value of the last column
    at the final row of that window.

    Args:
        sequences: 2-D array indexed as ``[row, column]``; the last column is
            treated as the output series, all preceding columns as inputs.
        n_steps: Number of consecutive rows per window; must be >= 1.

    Returns:
        A tuple ``(X, y)`` of arrays. With at least one complete window,
        ``X`` has shape ``(n_samples, n_steps, n_columns - 1)`` and ``y`` has
        shape ``(n_samples,)``. If ``sequences`` has fewer than ``n_steps``
        rows, both arrays are empty.

    Raises:
        ValueError: If ``n_steps`` is smaller than 1.
    """
    # A non-positive window would silently pair empty inputs with targets
    # picked through a negative (wrapped-around) row index.
    if n_steps < 1:
        raise ValueError("n_steps must be a positive integer")
    X, y = [], []
    # Iterate only over starts that leave room for a complete window;
    # the range is empty when the dataset is shorter than n_steps.
    for start in range(len(sequences) - n_steps + 1):
        end_ix = start + n_steps
        # inputs: all but the last column over the window
        X.append(sequences[start:end_ix, :-1])
        # output: last column at the window's final row
        y.append(sequences[end_ix - 1, -1])
    return array(X), array(y)
# define the two parallel input sequences
in_seq1 = array([10, 20, 30, 40, 50, 60, 70, 80, 90])
in_seq2 = array([15, 25, 35, 45, 55, 65, 75, 85, 95])
# the output series is the element-wise sum of the two inputs
out_seq = in_seq1 + in_seq2
# convert each series to a [rows, 1] column so they can be stacked side by side
in_seq1 = in_seq1.reshape((len(in_seq1), 1))
in_seq2 = in_seq2.reshape((len(in_seq2), 1))
out_seq = out_seq.reshape((len(out_seq), 1))
# horizontally stack columns: inputs first, output last
dataset = hstack((in_seq1, in_seq2, out_seq))
# choose a number of time steps
n_steps = 3
# convert into input/output samples
X, y = split_sequences(dataset, n_steps)
# flatten each (time steps, features) window into a single vector per sample
n_input = X.shape[1] * X.shape[2]
X = X.reshape((X.shape[0], n_input))
# define model: a 100-unit ReLU hidden layer followed by a single-unit output
model = Sequential()
model.add(Dense(100, activation='relu', input_dim=n_input))
model.add(Dense(1))
# train with the Adam optimizer against mean squared error
model.compile(optimizer='adam', loss='mse')
# fit model on all samples for 2000 epochs with progress output disabled
model.fit(X, y, epochs=2000, verbose=0)
# demonstrate prediction on one new window of three time steps
x_input = array([[80, 85], [90, 95], [100, 105]])
# flatten to a single row, matching the layout of the training samples
x_input = x_input.reshape((1, n_input))
yhat = model.predict(x_input, verbose=0)
print(yhat)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment