Created
July 18, 2015 16:13
-
-
Save BradNeuberg/c47a07cce0230cfc2f5d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _generate_leveldb(self, file_path, pairs, target, single_data):
    """
    Write paired face images into a LevelDB database for Caffe.

    Caffe uses the LevelDB format to efficiently load its training and validation data; this
    method writes the paired faces out into that format, committing them in batches.

    Any existing database at ``file_path`` is removed first.

    Args:
        file_path: Directory path of the LevelDB database to (re)create.
        pairs: Sequence of index pairs; ``pairs[i][0]`` and ``pairs[i][1]`` index into
            ``single_data``.
        target: Sequence of labels, where ``target[i]`` is the label for ``pairs[i]``.
        single_data: Indexable collection of individual images referenced by ``pairs``.
    """
    print("\tGenerating LevelDB file at %s..." % file_path)
    shutil.rmtree(file_path, ignore_errors=True)
    db = leveldb.LevelDB(file_path)
    batch = leveldb.WriteBatch()
    commit_every = 250000
    start_time = int(round(time.time() * 1000))
    for idx, pair in enumerate(pairs):
        # Each image pair is a top level key with a keyname like 00059999, in increasing
        # order starting from 00000000.
        key = siamese_utils.get_key(idx)

        # Actually expand our images now, taking the index reference and turning it into real
        # image pairs; we delay doing this until now for efficiency reasons, as we will probably
        # have more pairs of images than actual computer memory.
        image_1 = single_data[pair[0]]
        image_2 = single_data[pair[1]]
        paired_image = np.concatenate([image_1, image_2])

        # Do things like mean normalize, etc. that happen across both testing and validation.
        paired_image = self._preprocess_data(paired_image)

        # Each entry in the leveldb is a Caffe protobuffer "Datum" object containing details.
        datum = Datum()
        datum.channels = 2  # One channel for each image in the pair.
        datum.height = constants.HEIGHT
        datum.width = constants.WIDTH
        datum.data = paired_image.tobytes()
        datum.label = target[idx]
        value = datum.SerializeToString()

        # Queue the record on the batch (not directly on the db) so that the periodic
        # db.Write() below actually commits the accumulated entries in one go.
        batch.Put(key, value)

        if idx % commit_every == 0:
            db.Write(batch, sync=True)
            # Drop the committed batch before allocating the next one to release its memory.
            del batch
            batch = leveldb.WriteBatch()
            end_time = int(round(time.time() * 1000))
            total_time = end_time - start_time
            print("Writing batch, key: %s, time for batch: %d" % (key, total_time))
            start_time = int(round(time.time() * 1000))

    # Commit whatever is left over from the final, partially filled batch.
    db.Write(batch, sync=True)
    # NOTE(review): confirm the installed leveldb binding exposes close() on LevelDB objects.
    db.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment