Skip to content

Instantly share code, notes, and snippets.

@ih4cku
Created July 10, 2015 16:46
Show Gist options
  • Save ih4cku/f5b97b02d8ede71a9aa7 to your computer and use it in GitHub Desktop.
Save ih4cku/f5b97b02d8ede71a9aa7 to your computer and use it in GitHub Desktop.
Create rnnlib nc dataset from a label file.
#!/usr/bin/env python
import netcdf_helpers
import numpy as np
import os
import cv2
from os import path
import sys
import cPickle
from glob import glob
from digitmodel import DigitNet
import gflags
# Shared gflags registry; flag values are read from it once the command line
# has been parsed.
FLAGS = gflags.FLAGS
# --mean_std: optional pickle file holding a (mean, std) pair. When it is left
# unset, NcCreator.save computes both from the loaded inputs and dumps them to
# "mean_std.pickle" next to the output file instead.
gflags.DEFINE_string('mean_std', None, 'pickle file containing mean and std.')
def parse_args():
    """Parse command-line flags and return the remaining arguments.

    ``gflags.FLAGS`` is called on ``sys.argv``; the first element of what it
    returns is dropped, so callers only see the arguments that follow it.

    Returns:
        list: the arguments left after flag parsing, without the first one.

    Exits:
        On ``gflags.FlagsError`` the error and the flag listing are printed
        and the process terminates through ``sys.exit(-1)``.
    """
    try:
        args = gflags.FLAGS(sys.argv)
    # "except X as e" is accepted by Python 2.6+ and Python 3, unlike the
    # comma form which only parses on Python 2.
    except gflags.FlagsError as errmsg:
        print(errmsg)
        print('\nFlags:')
        print(gflags.FLAGS)
        sys.exit(-1)
    return args[1:]
def getRawFeatures(imlist):
    """Load every image as grayscale and flatten it into one feature row.

    Args:
        imlist: iterable of image file paths.

    Returns:
        numpy array built from the flattened images, one entry per file.
        `features` shape: (N, D) when all images have the same size.

    Raises:
        IOError: if OpenCV cannot read one of the files.
    """
    rows = []
    for fn in imlist:
        # Flag 0 requests a single-channel (grayscale) image.
        image = cv2.imread(fn, 0)
        if image is None:
            # imread reports failure by returning None rather than raising,
            # which would otherwise surface as an AttributeError on flatten().
            raise IOError('cannot read image "%s"' % fn)
        rows.append(image.flatten())
    return np.asarray(rows)
def getNumericLabels(label_fn):
    """Read a label file into a mapping from frame directory to labelling.

    label file format:
        Each line has a directory path containing frame images followed by
        the labelling of the sequence separated by whitespace. Blank lines
        are ignored.

    Args:
        label_fn: path of the label file.

    Returns:
        A dict whose key is frame directory name, value is space separated
        labellings: every character of the label string becomes one label,
        e.g. "123" -> "1 2 3". An empty file gives an empty dict.

    Raises:
        ValueError: if a non-blank line does not consist of exactly two
            whitespace-separated fields.
    """
    labellings = {}
    with open(label_fn) as f:
        for lineno, line in enumerate(f, 1):
            fields = line.split()
            if not fields:
                # Tolerate blank lines (including a trailing one).
                continue
            if len(fields) != 2:
                raise ValueError(
                    '%s:%d: expected "<frame_dir> <label>", got %r'
                    % (label_fn, lineno, line.rstrip('\r\n')))
            frame_dir, label = fields
            # Joining a string iterates its characters: "123" -> "1 2 3".
            labellings[frame_dir] = ' '.join(label)
    return labellings
def CnnFeatureFactory():
    """Build a feature reader backed by a single shared DigitNet instance.

    The network is constructed once, here, and captured by the returned
    closure so that every call reuses it.

    Returns:
        A function taking a list of image paths and returning the 'ip1' blob
        that ``DigitNet.getBlobs`` yields for them.
    """
    blob_name = 'ip1'
    cnn = DigitNet()

    def getCnnFeatures(imlist):
        blobs = cnn.getBlobs(imlist, [blob_name])
        return blobs[blob_name]

    return getCnnFeatures
class NcCreator(object):
    """Collect labelled frame sequences and write them to a NetCDF file.

    Typical use: construct with a label reader and a feature reader, call
    ``load(label_file)`` and then ``save(nc_file)``.
    """

    def __init__(self, labelReader, featureReader):
        """
        Args:
            labelReader: callable taking the label file path and returning a
                dict that maps a frame directory to its space separated
                labelling.
            featureReader: callable taking a list of frame image paths and
                returning an array with one row per frame.
        """
        self.label_fn = ''
        self.sequences = []      # one feature array per sequence
        self.targetStrings = []  # labelling of each sequence
        self.seqTags = []        # frame directory of each sequence
        self.labels = []         # sorted unique labels over all sequences
        self.getLabels = labelReader
        self.getFeatures = featureReader

    @staticmethod
    def getFrameFiles(d, ext='png'):
        """
        Get frame images from `d`, sorted by the integer value of their
        filenames (without extension), so "10.png" comes after "2.png".

        Raises:
            ValueError: if a matching file name is not an integer.
        """
        imlist = glob(path.join(d, '*.' + ext))
        imlist.sort(key=lambda fn: int(path.splitext(path.basename(fn))[0]))
        return imlist

    @staticmethod
    def labellings2labels(labellings):
        """Return the sorted list of distinct labels used in `labellings`.

        Args:
            labellings: dict whose values are space separated labellings.
        """
        labels = set()
        for ll in labellings.values():
            labels.update(ll.split())
        return sorted(labels)

    @staticmethod
    def normalize(inputs, mean=None, std=None):
        """
        normalize inputs to 0 mean and std 1.

        Args:
            inputs: array of shape (N, D).
            mean: per-feature mean; computed over axis 0 when None.
            std: per-feature standard deviation; computed over axis 0 when
                None.

        Returns:
            The standardised array. Features whose std is 0 are only
            centred, not scaled, so constant features do not turn into
            nan/inf.
        """
        if mean is None:
            mean = np.mean(inputs, axis=0)
        if std is None:
            std = np.std(inputs, axis=0)
        # A zero std would make the division produce nan/inf; divide by 1.
        safe_std = np.where(np.asarray(std) == 0, 1.0, std)
        return (inputs - mean) / safe_std

    def load(self, label_fn):
        """Read the label file and the features of every listed sequence.

        Results are appended to ``sequences``, ``targetStrings`` and
        ``seqTags``; ``labels`` is replaced by the labels of this file.

        Raises:
            ValueError: if a frame directory contains no frame images.
        """
        self.label_fn = label_fn
        # read labels
        labellings = self.getLabels(label_fn)
        self.labels = self.labellings2labels(labellings)
        # read frame images
        for d in labellings:
            print('Reading %s' % d)
            frame_list = self.getFrameFiles(d)
            if not frame_list:
                # A mistyped directory globs to nothing; fail here with the
                # offending path instead of deep inside the feature reader.
                raise ValueError('no frame images found in "%s"' % d)
            frame_data = self.getFeatures(frame_list)
            self.sequences.append(frame_data)
            self.targetStrings.append(labellings[d])
            self.seqTags.append(d)

    def save(self, ncFilename):
        """Normalize the loaded sequences and write them to `ncFilename`.

        When the ``mean_std`` flag is unset, mean and std are computed from
        the inputs and pickled to "mean_std.pickle" in the output directory;
        otherwise they are loaded from the flagged pickle file.

        Raises:
            ValueError: if no sequences have been loaded.
        """
        if not self.sequences:
            raise ValueError('no sequences loaded; call load() before save()')

        seqLengths = np.array([seq.shape[0] for seq in self.sequences],
                              dtype='int32')
        seqDims = seqLengths[:, None]
        inputs = np.vstack(self.sequences)

        print('---------------------------------------')
        print('Normalizing...')
        if FLAGS.mean_std is None:
            mean_path = path.join(path.dirname(ncFilename), 'mean_std.pickle')
            mean = np.mean(inputs, axis=0)
            std = np.std(inputs, axis=0)
            with open(mean_path, 'wb') as stats_file:
                cPickle.dump((mean, std), stats_file, -1)
            print('Mean and std are dumped to "%s".' % mean_path)
        else:
            print('Loading mean and std from "%s"' % FLAGS.mean_std)
            # NOTE: unpickling can execute arbitrary code; only pass files
            # from a trusted source.
            with open(FLAGS.mean_std, 'rb') as stats_file:
                mean, std = cPickle.load(stats_file)
        inputs = self.normalize(inputs, mean, std)
        print('done.')

        # create a new .nc file
        nc = netcdf_helpers.NetCDFFile(ncFilename, 'w')
        try:
            # create the dimensions
            netcdf_helpers.createNcDim(nc, 'numSeqs', len(seqLengths))
            netcdf_helpers.createNcDim(nc, 'numTimesteps', len(inputs))
            netcdf_helpers.createNcDim(nc, 'inputPattSize', len(inputs[0]))
            netcdf_helpers.createNcDim(nc, 'numDims', 1)
            netcdf_helpers.createNcDim(nc, 'numLabels', len(self.labels))

            # create the variables
            netcdf_helpers.createNcStrings(
                nc, 'seqTags', self.seqTags,
                ('numSeqs', 'maxSeqTagLength'), 'sequence tags')
            netcdf_helpers.createNcStrings(
                nc, 'labels', self.labels,
                ('numLabels', 'maxLabelLength'), 'labels')
            netcdf_helpers.createNcStrings(
                nc, 'targetStrings', self.targetStrings,
                ('numSeqs', 'maxTargStringLength'), 'target strings')
            netcdf_helpers.createNcVar(
                nc, 'seqLengths', seqLengths, 'i',
                ('numSeqs',), 'sequence lengths')
            netcdf_helpers.createNcVar(
                nc, 'seqDims', seqDims, 'i',
                ('numSeqs', 'numDims'), 'sequence dimensions')
            netcdf_helpers.createNcVar(
                nc, 'inputs', inputs, 'f',
                ('numTimesteps', 'inputPattSize'), 'input patterns')
        finally:
            # write the data to disk; also releases the handle if a write
            # above failed
            print('closing file %s' % ncFilename)
            nc.close()
def main():
    """Command-line entry point: build an .nc file from a label file.

    Expects exactly two positional arguments, the label file and the output
    nc file; otherwise prints the usage and flag listing and exits with -1.
    """
    positional = parse_args()
    if len(positional) != 2:
        print('Usage: %s label_file nc_file' % sys.argv[0])
        print(gflags.FLAGS)
        sys.exit(-1)

    label_file, nc_file = positional
    creator = NcCreator(getNumericLabels, CnnFeatureFactory())
    creator.load(label_file)
    creator.save(nc_file)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment