Skip to content

Instantly share code, notes, and snippets.

@l225li
Last active August 7, 2017 08:10
Show Gist options
  • Save l225li/ed7ff0763d13aad8312dbe2c411fd6c6 to your computer and use it in GitHub Desktop.
Save l225li/ed7ff0763d13aad8312dbe2c411fd6c6 to your computer and use it in GitHub Desktop.
This is a helper function that extracts convolutional features from images using a pre-trained VGG16 model. Because images are processed one at a time, it makes feature extraction feasible for a large set of images on a CPU-only machine.
import csv
import os
import shutil

import bcolz
import numpy as np
def save_array(fname, arr):
    """Write the numpy array ``arr`` to disk as a bcolz carray rooted at ``fname``.

    Args:
        fname (str): root directory of the on-disk carray (by convention
            named ``*.dat``); opened with ``mode='w'``.
        arr: array data to store.
    """
    # flush() forces any buffered chunks out to the rootdir before returning
    bcolz.carray(arr, rootdir=fname, mode='w').flush()
def load_array(fname):
    """Load the bcolz carray stored at ``fname`` (a ``.dat`` directory).

    Args:
        fname (str): root directory of a carray written by ``save_array``.

    Returns:
        The full contents of the stored array, read into memory.
    """
    stored = bcolz.open(fname)
    # slicing the whole range materialises every chunk in memory
    return stored[:]
def extract_features(path_in, path_out, model):
    """Extract features for every image under ``path_in`` using ``model``.

    Images are fed through ``model`` one at a time (batch size 1) so that
    only a single image has to be held in memory during prediction. Each
    prediction is written to a temporary ``<i>.dat`` file inside
    ``path_out``. The temporary files are then merged into
    ``features.dat`` and deleted.

    Args:
        path_in (str): directory path of input images, as accepted by
            ``get_batches``.
        path_out (str): directory in which ``labels.dat`` and
            ``features.dat`` are written; created if it does not exist.
            A trailing path separator is optional.
        model (Model): model with convolutional layers whose ``predict``
            output is used as the features.

    Returns:
        tuple: ``(features, labels)``. ``features`` stacks the per-image
        predictions along axis 0 in the order produced by ``get_batches``,
        and ``labels`` is ``batches.classes``. Both are also saved as
        ``.dat`` files in ``path_out``.

    Raises:
        ValueError: if ``get_batches`` reports zero samples for ``path_in``.
    """
    # creates directory `path_out` if it does not already exist
    if not os.path.exists(path_out):
        os.makedirs(path_out)

    # shuffle=False keeps the prediction order aligned with batches.classes
    batches = get_batches(path_in, batch_size=1, shuffle=False)
    labels = batches.classes
    n = batches.samples
    if n == 0:
        # without samples there is no prediction to take the feature shape from
        raise ValueError('no samples found in {!r}'.format(path_in))

    save_array(os.path.join(path_out, 'labels.dat'), labels)

    def sample_path(i):
        # location of the temporary file holding the i-th sample's features
        return os.path.join(path_out, '{}.dat'.format(i))

    # due to the limited RAM, predict the features of the samples one by one
    # and spill each result to disk
    for i in range(n):
        save_array(sample_path(i), model.predict(next(batches)[0]))

    # merge the per-sample files with one concatenation; growing an array
    # inside the loop would re-copy all previous rows on every iteration
    features = np.concatenate(
        [load_array(sample_path(i)) for i in range(n)], axis=0)
    save_array(os.path.join(path_out, 'features.dat'), features)

    # delete the individual per-sample .dat directories
    for i in range(n):
        shutil.rmtree(sample_path(i))
    return features, labels
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment