Skip to content

Instantly share code, notes, and snippets.

@alik604
Last active December 25, 2019 21:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save alik604/a571b8560f759800b1710204dd1fae06 to your computer and use it in GitHub Desktop.
Save alik604/a571b8560f759800b1710204dd1fae06 to your computer and use it in GitHub Desktop.
EEGNet in Keras
def EEGNet(nb_classes, Chans=64, Samples=128,
           dropoutRate=0.5, kernLength=64, F1=8,
           D=2, F2=16, norm_rate=0.25, dropoutType='Dropout'):
    """Build the EEGNet architecture as a Keras Sequential model.

    Reference: http://iopscience.iop.org/article/10.1088/1741-2552/aace8c/meta

    Inputs:
      nb_classes      : int, number of classes to classify
      Chans, Samples  : number of channels and time points in the EEG data
      dropoutRate     : dropout fraction
      kernLength      : length of temporal convolution in first layer. We
                        found that setting this to be half the sampling rate
                        worked well in practice. For the SMR dataset in
                        particular since the data was high-passed at 4Hz we
                        used a kernel length of 32.
      F1, F2          : number of temporal filters (F1) and number of
                        pointwise filters (F2) to learn.
                        Default: F1 = 8, F2 = F1 * D.
      D               : number of spatial filters to learn within each
                        temporal convolution. Default: D = 2
      norm_rate       : max-norm limit applied to the kernel of the final
                        dense layer.
      dropoutType     : Either SpatialDropout2D or Dropout, passed as a
                        string.

    Returns:
      An uncompiled keras Sequential model ending in a softmax activation
      over nb_classes outputs.

    Raises:
      ValueError : if dropoutType is not 'SpatialDropout2D' or 'Dropout'.

    NOTE(review): the input shape (1, Chans, Samples) together with
    BatchNormalization(axis=1) presumes the Keras image data format is
    'channels_first' -- confirm the backend configuration before use.
    """
    # Reject a bad argument before paying for the (slow) keras import.
    if dropoutType not in ('SpatialDropout2D', 'Dropout'):
        raise ValueError('dropoutType must be one of SpatialDropout2D '
                         'or Dropout, passed as a string.')

    # Explicit imports: a wildcard import is not legal inside a function in
    # Python 3, and Sequential / max_norm do not live in keras.layers anyway.
    from keras.constraints import max_norm
    from keras.layers import (Activation, AveragePooling2D,
                              BatchNormalization, Conv2D, Dense,
                              DepthwiseConv2D, Dropout, Flatten,
                              SeparableConv2D, SpatialDropout2D)
    from keras.models import Sequential

    dropout_layer = (SpatialDropout2D if dropoutType == 'SpatialDropout2D'
                     else Dropout)

    model = Sequential()

    # Block 1: temporal convolution, then a depthwise convolution whose
    # kernel spans all Chans rows.
    model.add(Conv2D(F1, (1, kernLength), padding='same',
                     input_shape=(1, Chans, Samples), use_bias=False))
    model.add(BatchNormalization(axis=1))
    model.add(DepthwiseConv2D((Chans, 1), use_bias=False,
                              depth_multiplier=D,
                              depthwise_constraint=max_norm(1.)))
    model.add(BatchNormalization(axis=1))
    model.add(Activation('elu'))
    model.add(AveragePooling2D((1, 4)))
    model.add(dropout_layer(dropoutRate))

    # Block 2: separable convolution producing F2 feature maps.
    model.add(SeparableConv2D(F2, (1, 16), use_bias=False, padding='same'))
    model.add(BatchNormalization(axis=1))
    model.add(Activation('elu'))
    model.add(AveragePooling2D((1, 8)))
    model.add(dropout_layer(dropoutRate))

    # Classifier head: norm-constrained dense layer followed by softmax.
    model.add(Flatten(name='flatten'))
    model.add(Dense(nb_classes, name='dense',
                    kernel_constraint=max_norm(norm_rate)))
    model.add(Activation('softmax', name='softmax'))

    return model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment