Skip to content

Instantly share code, notes, and snippets.

@JonnoFTW
Created August 31, 2018 04:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save JonnoFTW/f94f8d97e57f6796da83b834ce66aa45 to your computer and use it in GitHub Desktop.
Predicting Eye Open/Closed State Using Keras LSTM
import pandas as pd
from keras import Sequential
from keras.layers import LSTM, Dense, Dropout, Activation
import numpy as np
from matplotlib import pyplot
def get_data(path='eeg_data.csv', max_deviation=4):
    """Load the EEG CSV and drop every row that contains an outlier.

    A row is kept only if, in every column, its absolute distance from that
    column's median is strictly less than ``max_deviation`` times the
    column's standard deviation.

    :param path: CSV file to read (defaults to ``eeg_data.csv``, resolved
        against the current working directory).
    :param max_deviation: outlier threshold, in column standard deviations.
    :return: a new, independent DataFrame holding only the inlier rows.
    """
    data = pd.read_csv(path)
    spread = data.std()
    # A constant column has std 0, so |x - median| / std would be 0/0 = NaN,
    # NaN < threshold is False, and every row would be discarded. A constant
    # column has no outliers, so divide its (all-zero) deviations by 1.
    spread = spread.replace(0, 1)
    deviation = (data - data.median()).abs() / spread
    inliers = (deviation < max_deviation).all(axis=1)
    # .copy() so callers can add columns (do_model does) without pandas
    # flagging a write into a possible view of `data`.
    return data[inliers].copy()
def show(df: pd.DataFrame):
    """Draw each column of ``df`` in its own vertically stacked subplot, then display the figure."""
    series_count = len(df.columns)
    # squeeze=False guarantees a 2-D (series_count x 1) grid of axes, even
    # when there is only a single column.
    figure, axes_grid = pyplot.subplots(series_count, squeeze=False)
    for axes_row, column in zip(axes_grid, df.columns):
        ax = axes_row[0]
        # Label each trace with its column name, nudged past the right edge.
        ax.set_title(column, loc='right', x=1.05)
        ax.plot(df[column])
    figure.subplots_adjust(bottom=0.2)
    pyplot.show()
def _build_model(batch_size, num_fields):
    """Return a compiled two-layer stateful LSTM with a single sigmoid output."""
    model = Sequential()
    model.add(LSTM(128,
                   stateful=True,
                   return_sequences=True,
                   batch_input_shape=(batch_size, 1, num_fields)))
    model.add(Dropout(0.1))
    model.add(LSTM(128, stateful=True))
    model.add(Dropout(0.1))
    model.add(Dense(1, activation='sigmoid'))
    # A sigmoid output trained against a 0/1 target is binary classification,
    # so binary cross-entropy is the matching loss (MSE was used before).
    model.compile(loss='binary_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
    return model


def do_model(data, epochs=5, validation_split=0.25, y_col='eyeDetection'):
    """Train a stateful LSTM on ``data`` and store its predictions in place.

    Every column except ``y_col`` is a feature. Each row is fed to the model
    as one sample containing a single timestep, in file order. The last
    ``validation_split`` fraction of rows is held out for Keras validation.

    Side effect: adds a ``'Predicted EyeState'`` column to ``data`` holding
    the model output for every row. This includes the training rows, so
    those predictions are in-sample.

    :param data: DataFrame with numeric feature columns plus ``y_col``.
    :param epochs: number of passes over the training rows.
    :param validation_split: fraction of trailing rows used for validation.
    :param y_col: name of the target column (assumed to hold 0/1 labels).
    :raises ValueError: if ``data`` has no rows.
    """
    if len(data) == 0:
        raise ValueError("do_model() needs at least one row of data")

    # The stateful layers are built with a fixed batch size. A batch of 1
    # means any number of rows splits evenly into batches.
    batch_size = 1
    features = data.drop(y_col, axis=1)
    num_rows, num_fields = features.shape
    # One timestep per sample: shape (rows, 1, fields).
    xs = features.values.astype(np.float32).reshape(num_rows, 1, num_fields)
    ys = data[y_col].values.astype(np.float32)

    model = _build_model(batch_size, num_fields)
    # Keras does not clear stateful LSTM state between epochs on its own.
    # Train one epoch at a time and reset, so each pass starts from a clean
    # state instead of inheriting the end of the previous pass.
    for _ in range(epochs):
        model.fit(xs, ys, batch_size=batch_size, validation_split=validation_split,
                  epochs=1, shuffle=False)
        model.reset_states()

    # predict() returns shape (rows, 1); flatten it to fit a single column.
    data['Predicted EyeState'] = model.predict(xs, batch_size=batch_size).ravel()
if __name__ == "__main__":
    # Load the recordings with outlier rows removed, fit the model (this adds
    # a 'Predicted EyeState' column to the same frame), then plot every column.
    eeg = get_data()
    do_model(eeg)
    show(eeg)
@smasish
Copy link

smasish commented Aug 4, 2021

What's the window size in your model?

@JonnoFTW
Copy link
Author

JonnoFTW commented Aug 5, 2021

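I think the window size is 1, since the batch shape is (1, 1, 14), i.e. each batch is a single time step with 14 fields. Because the LSTM layers are stateful, though, their hidden state is carried over from one batch to the next, so the model can still draw on earlier time steps.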

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment