Created
July 10, 2015 16:46
-
-
Save ih4cku/f5b97b02d8ede71a9aa7 to your computer and use it in GitHub Desktop.
Create rnnlib nc dataset from a label file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import netcdf_helpers | |
import numpy as np | |
import os | |
import cv2 | |
from os import path | |
import sys | |
import cPickle | |
from glob import glob | |
from digitmodel import DigitNet | |
import gflags | |
# Command-line flags; `--mean_std` is read in NcCreator.save via FLAGS.mean_std.
FLAGS = gflags.FLAGS
gflags.DEFINE_string(
    'mean_std',
    None,
    'pickle file containing mean and std.',
)
def parse_args():
    """
    Parse `sys.argv` with gflags.

    return:
        The list returned by `gflags.FLAGS(sys.argv)` with its first element
        dropped (presumably the program name -- confirm against gflags).

    On a `gflags.FlagsError` the error and the flag help are printed and the
    process exits with status -1.
    """
    try:
        remaining = gflags.FLAGS(sys.argv)
    except gflags.FlagsError as errmsg:
        print(errmsg)
        print('\nFlags:')
        print(gflags.FLAGS)
        sys.exit(-1)
    positional = remaining[1:]
    return positional
def getRawFeatures(imlist):
    """
    Load every image in `imlist` as grayscale and flatten it to a vector.

    imlist:
        Iterable of image file paths.
    return:
        `features`, an array built from one flattened image per path; its
        shape is (N, D) when all N images flatten to the same length D.
    raises:
        IOError if an image cannot be read.
    """
    rows = []
    for fn in imlist:
        # Flag 0 asks OpenCV for a single-channel (grayscale) image.
        img = cv2.imread(fn, 0)
        if img is None:
            # cv2.imread reports failure by returning None rather than
            # raising, which would otherwise surface as an opaque
            # AttributeError on `.flatten()`.
            raise IOError('Cannot read image "%s"' % fn)
        rows.append(img.flatten())
    features = np.asarray(rows)
    return features
def getNumericLabels(label_fn):
    """
    Read a label file and return the labelling of every sequence.

    label file format:
        Each line has a directory path containing frame images followed by
        the labelling of the sequence separated by whitespace. Blank lines
        are ignored.
    return:
        A dict whose key is frame directory name, value is space separated
        labellings (one label per character of the labelling field). An
        empty file yields an empty dict. If a directory appears more than
        once, the last line wins.
    raises:
        ValueError if a non-blank line does not have exactly two fields.
    """
    labellings = {}
    with open(label_fn) as f:
        for lineno, line in enumerate(f, 1):
            fields = line.split()
            if not fields:
                # Blank line: nothing to record.
                continue
            if len(fields) != 2:
                raise ValueError(
                    '%s:%d: expected "<frame_dir> <labelling>", got %r'
                    % (label_fn, lineno, line.rstrip('\n')))
            frame_dir, labelling = fields
            # "123" -> "1 2 3": every character is a separate label.
            labellings[frame_dir] = ' '.join(labelling)
    return labellings
def CnnFeatureFactory():
    """
    Build a feature reader backed by a single `DigitNet` instance.

    return:
        A function `getCnnFeatures(imlist)` that returns the 'ip1' entry of
        `net.getBlobs(imlist, ['ip1'])`. The network is created once here
        and shared by every call of the returned function.
    """
    model = DigitNet()
    blob_name = 'ip1'

    def getCnnFeatures(imlist):
        blobs = model.getBlobs(imlist, [blob_name])
        return blobs[blob_name]

    return getCnnFeatures
class NcCreator(object):
    """
    Collect per-sequence frame features and labellings, then write them to a
    NetCDF file through `netcdf_helpers`.

    labelReader:
        Callable taking the label file name and returning a dict that maps a
        frame directory to its space separated labelling.
    featureReader:
        Callable taking a list of frame image paths and returning an array;
        the length of its first axis is used as the sequence length.
    """

    def __init__(self, labelReader, featureReader):
        self.label_fn = ''
        self.sequences = []       # one feature array per sequence
        self.targetStrings = []   # labelling of each sequence
        self.seqTags = []         # frame directory of each sequence
        self.labels = []          # sorted unique labels over all sequences
        self.getLabels = labelReader
        self.getFeatures = featureReader

    @staticmethod
    def getFrameFiles(d, ext='png'):
        """
        Get frame images from `d`, sort with their filenames.

        The sort key is the integer value of the file name without its
        extension, so '2.png' comes before '10.png'. A non-numeric file name
        makes `int()` raise ValueError.
        """
        imlist = glob(path.join(d, '*.' + ext))
        imlist.sort(key=lambda fn: int(path.splitext(path.basename(fn))[0]))
        return imlist

    @staticmethod
    def labellings2labels(labellings):
        """
        Return the sorted list of distinct labels used in `labellings`,
        a dict whose values are space separated labellings.
        """
        return sorted({label
                       for ll in labellings.values()
                       for label in ll.split()})

    @staticmethod
    def normalize(inputs, mean=None, std=None):
        """
        normalize inputs to 0 mean and std 1.

        inputs:
            Array with one pattern per row.
        mean, std:
            Per-column statistics to use; each is computed from `inputs`
            along axis 0 when omitted.
        return:
            `(inputs - mean) / std`, where zero entries of `std` are replaced
            by 1 so constant columns map to 0 instead of NaN/inf.
        """
        inputs = np.asarray(inputs)
        if mean is None:
            mean = np.mean(inputs, axis=0)
        if std is None:
            std = np.std(inputs, axis=0)
        std = np.asarray(std)
        # A column with zero spread would otherwise divide by zero.
        safe_std = np.where(std == 0, 1.0, std)
        return (inputs - mean) / safe_std

    def load(self, label_fn):
        """
        Read the labellings in `label_fn` and the features of every frame
        directory it lists.

        `self.labels` is replaced; sequences, target strings and tags are
        appended to whatever was loaded before.
        """
        self.label_fn = label_fn

        # read labels
        labellings = self.getLabels(label_fn)
        self.labels = self.labellings2labels(labellings)

        # read frame images
        for d in labellings:
            print('Reading %s' % d)
            frame_list = self.getFrameFiles(d)
            frame_data = self.getFeatures(frame_list)
            self.sequences.append(frame_data)
            self.targetStrings.append(labellings[d])
            self.seqTags.append(d)

    def save(self, ncFilename):
        """
        Normalize the loaded sequences and write them to `ncFilename`.

        When the `--mean_std` flag is unset, mean and std are computed from
        the loaded inputs and pickled to 'mean_std.pickle' in the directory
        of `ncFilename`; otherwise they are loaded from the pickle file the
        flag names.

        raises:
            ValueError if no sequence has been loaded.
        """
        if not self.sequences:
            raise ValueError('No sequences loaded; call load() before save().')

        seqLengths = np.array([seq.shape[0] for seq in self.sequences],
                              dtype='int32')
        seqDims = seqLengths[:, None]
        inputs = np.vstack(self.sequences)

        print('---------------------------------------')
        print('Normalizing...')
        if FLAGS.mean_std is None:
            mean = np.mean(inputs, axis=0)
            std = np.std(inputs, axis=0)
            mean_path = path.join(path.dirname(ncFilename), 'mean_std.pickle')
            with open(mean_path, 'wb') as f:
                cPickle.dump((mean, std), f, -1)
            print('Mean and std are dumped to "%s".' % mean_path)
        else:
            print('Loading mean and std from "%s"' % FLAGS.mean_std)
            # NOTE(security): unpickling can execute arbitrary code; only
            # pass --mean_std files that come from a trusted source.
            with open(FLAGS.mean_std, 'rb') as f:
                mean, std = cPickle.load(f)
        inputs = self.normalize(inputs, mean, std)
        print('done.')

        # create a new .nc file
        ncfile = netcdf_helpers.NetCDFFile(ncFilename, 'w')
        try:
            # create the dimensions
            netcdf_helpers.createNcDim(ncfile, 'numSeqs', len(seqLengths))
            netcdf_helpers.createNcDim(ncfile, 'numTimesteps', len(inputs))
            netcdf_helpers.createNcDim(ncfile, 'inputPattSize', len(inputs[0]))
            netcdf_helpers.createNcDim(ncfile, 'numDims', 1)
            netcdf_helpers.createNcDim(ncfile, 'numLabels', len(self.labels))

            # create the variables
            netcdf_helpers.createNcStrings(
                ncfile, 'seqTags', self.seqTags,
                ('numSeqs', 'maxSeqTagLength'), 'sequence tags')
            netcdf_helpers.createNcStrings(
                ncfile, 'labels', self.labels,
                ('numLabels', 'maxLabelLength'), 'labels')
            netcdf_helpers.createNcStrings(
                ncfile, 'targetStrings', self.targetStrings,
                ('numSeqs', 'maxTargStringLength'), 'target strings')
            netcdf_helpers.createNcVar(
                ncfile, 'seqLengths', seqLengths, 'i',
                ('numSeqs',), 'sequence lengths')
            netcdf_helpers.createNcVar(
                ncfile, 'seqDims', seqDims, 'i',
                ('numSeqs', 'numDims'), 'sequence dimensions')
            netcdf_helpers.createNcVar(
                ncfile, 'inputs', inputs, 'f',
                ('numTimesteps', 'inputPattSize'), 'input patterns')
        finally:
            # write the data to disk; also releases the handle on failure
            print('closing file %s' % ncFilename)
            ncfile.close()
if __name__ == '__main__':
    # Expect exactly two positional arguments after flag parsing:
    # the label file to read and the NetCDF file to write.
    args = parse_args()
    if len(args) != 2:
        print('Usage: %s label_file nc_file' % sys.argv[0])
        print(gflags.FLAGS)
        sys.exit(-1)

    label_file, nc_file = args
    ncc = NcCreator(getNumericLabels, CnnFeatureFactory())
    ncc.load(label_file)
    ncc.save(nc_file)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment