Skip to content

Instantly share code, notes, and snippets.

@aaronpolhamus
Last active March 26, 2016 22:31
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Star You must be signed in to star a gist
Save aaronpolhamus/39c7a71151b8560d02dd to your computer and use it in GitHub Desktop.
Adaptation of the VGG-like convnet from http://keras.io/examples/ to custom data
import json
import os
import sys

from keras.layers import Convolution2D, MaxPooling2D
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.models import Sequential
from keras.optimizers import SGD
from keras.utils import np_utils
from numpy import loadtxt, asarray
from pandas import read_csv
from scipy.ndimage import imread

import model_control
# Train a small VGG-style convnet on a folder of greyscale JPEGs, then save the
# trained weights and the architecture description to the working directory.

# --- Load labels and images ---------------------------------------------------
# Integer class labels, one per training image. ndmin=1 keeps a single-label
# file from collapsing into a 0-d array.
Y_train = loadtxt(model_control.y_train_file, delimiter=',', dtype=int, ndmin=1)

# os.listdir() returns names in arbitrary order, so sort to make the image
# order deterministic from run to run.
# NOTE(review): assumes the rows of the label file follow sorted-filename
# order -- confirm against however the label file was produced.
train_files = sorted(
    os.path.join(model_control.train_img_path, x)
    for x in os.listdir(model_control.train_img_path)
    if 'jpg' in x
)
X_train = asarray([imread(x) for x in train_files])

# Fail early with a readable message instead of a shape error deep in Keras.
if X_train.ndim != 3:
    raise ValueError(
        'expected a non-empty set of equally sized single-channel images, '
        'got an array of shape %s' % (X_train.shape,))
if X_train.shape[0] != Y_train.shape[0]:
    raise ValueError(
        'found %d images but %d labels'
        % (X_train.shape[0], Y_train.shape[0]))

# At this point X_train is (n_samples, rows, cols), e.g. (8144, 128, 256), and
# Y_train is (n_samples,), e.g. (8144,).
n_samples, img_rows, img_cols = X_train.shape

# The first conv layer is declared with a leading channel axis, so the data
# needs an explicit channel dimension: (n_samples, 1, rows, cols).
# Pixels are scaled to [0, 1] (assumes 8-bit images); apply the same scaling
# to any data later fed to the saved model.
X_train = X_train.reshape(n_samples, 1, img_rows, img_cols).astype('float32') / 255

# categorical_crossentropy expects one-hot targets, not integer labels.
# Sizing by the largest label keeps 1-based label files working (column 0 is
# then simply never used).
nb_classes = int(Y_train.max()) + 1
Y_onehot = np_utils.to_categorical(Y_train, nb_classes)

# --- Build the network --------------------------------------------------------
model = Sequential()

# Conv block 1: two 5x5 convolutions with 32 filters, then 2x2 max-pooling.
model.add(Convolution2D(32, 5, 5, border_mode='valid',
                        input_shape=(1, img_rows, img_cols)))
model.add(Activation('relu'))
model.add(Convolution2D(32, 5, 5))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

# Conv block 2: two 5x5 convolutions with 64 filters, then 2x2 max-pooling.
model.add(Convolution2D(64, 5, 5, border_mode='valid'))
model.add(Activation('relu'))
model.add(Convolution2D(64, 5, 5))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

# Classifier head: one hidden dense layer, then a softmax over the classes.
model.add(Flatten())
model.add(Dense(256))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(nb_classes))
model.add(Activation('softmax'))

# --- Train ----------------------------------------------------------------------
sgd = SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd)
model.fit(X_train, Y_onehot, batch_size=32, nb_epoch=1, verbose=1)

# --- Persist --------------------------------------------------------------------
model.save_weights('keras_net_weights.h5')
json_string = model.to_json()
# Text mode: json.dump() writes str, which a binary-mode file rejects on
# Python 3. The output file name is kept exactly as before so existing readers
# still find it. to_json() already returns a JSON string, so this file holds
# that string JSON-encoded once more; read it back with json.load() before
# handing the result to model_from_json().
with open('keta_net_structure.json', 'w') as outfile:
    json.dump(json_string, outfile)
# Data locations, read by the training script as model_control.<name>.

# Directory that is scanned for the training images.
train_img_path = '/path/to/keras_ex_data'

# Comma-delimited text file holding the integer class labels.
y_train_file = '/path/to/keras_ex_data/train_labels.txt'
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment