Skip to content

Instantly share code, notes, and snippets.

@Jiaming-Liu
Created July 5, 2017 19:03
Show Gist options
  • Save Jiaming-Liu/e64330e556be56e1ad5e7fa6b4749a1c to your computer and use it in GitHub Desktop.
Save Jiaming-Liu/e64330e556be56e1ad5e7fa6b4749a1c to your computer and use it in GitHub Desktop.
Keras iterator for Caffe-Style LMDB
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="TestRunnerService">
<option name="projectConfiguration" value="Nosetests" />
<option name="PROJECT_TEST_RUNNER" value="Nosetests" />
</component>
</module>
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="Python 2.7.6 (/usr/bin/python2.7)" project-jdk-type="Python SDK" />
</project>
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/keras_lmdb.iml" filepath="$PROJECT_DIR$/.idea/keras_lmdb.iml" />
</modules>
</component>
</project>
from keras_lmdb_iterator.caffe_lmdb import caffe_lmdb
import itertools

# Open the LMDB without shuffling and inspect its first records.
lmdb_reader = caffe_lmdb.CaffeLMDB('/path/to/ilsvrc12_train_lmdb/', shuffle=False)

# islice stops after exactly 10 records instead of reading an 11th just to
# break; distinct loop names avoid clobbering the `lmdb_reader` object.
for label, data in itertools.islice(lmdb_reader, 10):
    # '%s %s' formatting prints "label shape" identically on Python 2 and 3.
    print('%s %s' % (label, data.shape))
from keras.layers import *
from keras.models import Sequential
from keras_lmdb_iterator import LMDBIterator
# Minimal classifier over 256x256 RGB images. Inputs are NHWC because
# LMDBIterator transposes Caffe's NCHW records before yielding them.
model = Sequential(layers=[
    Flatten(input_shape=(256, 256, 3)),
    # sparse_categorical_crossentropy (with its default from_logits=False)
    # expects class probabilities, so the output layer must apply softmax;
    # a linear output would be clipped/logged as if it were probabilities.
    Dense(1000, activation='softmax'),
])
model.compile(optimizer='adam', loss=['sparse_categorical_crossentropy'])
# LMDBIterator loops over the database indefinitely, so steps_per_epoch is
# what defines an epoch boundary for Keras.
model.fit_generator(
    iter(LMDBIterator('/path/to/ilsvrc12_train_lmdb/', batch_size=64)),
    steps_per_epoch=10000,
    epochs=100)
from keras_lmdb_iterator import LMDBIterator
import caffe_python_env
import caffe
import lmdb
import random
import warnings
class CaffeLMDB(object):
    """Iterable reader over a Caffe-style LMDB of serialized ``Datum`` records.

    Args:
        path: Location of the LMDB, forwarded to ``lmdb.open``.
        shuffle: If truthy, records are visited in a random key order that is
            re-drawn on every call to ``__iter__``; otherwise in cursor order.
        subdir, readonly, lock: Forwarded unchanged to ``lmdb.open``.
    """

    def __init__(self, path, shuffle=False, subdir=True, readonly=True, lock=False):
        self.lmdb_env = lmdb.open(path, readonly=readonly, lock=lock, subdir=subdir)
        self.shuffle = shuffle
        # Record keys, populated only when shuffling; None otherwise.
        self.keys = None
        if shuffle:
            warnings.warn('Shuffle with python lmdb is incredibly slow!')
            with self.lmdb_env.begin() as txn:
                with txn.cursor() as cur:
                    # Ask the cursor for keys only: iterating (key, value)
                    # pairs would copy every record's payload just to drop it.
                    self.keys = list(cur.iternext(keys=True, values=False))
            random.shuffle(self.keys)

    def __len__(self):
        """Return the number of entries reported by the LMDB environment."""
        return int(self.lmdb_env.stat()['entries'])

    def __iter__(self):
        """
        Iterator of caffe lmdb, returning (label, data) with integer label
        and C*H*W np.ndarray data.

        When shuffling, each call uses a freshly shuffled order, so repeated
        passes over the database (epochs) see different permutations.
        """
        # Same truthiness test as __init__, so e.g. shuffle=1 also shuffles.
        if not self.shuffle:
            return _sequential_iter(self.lmdb_env)
        # Shuffle a copy so an iterator returned by an earlier call that is
        # still being consumed keeps its own order.
        keys = list(self.keys)
        random.shuffle(keys)
        return _random_iter(self.lmdb_env, keys)
def _sequential_iter(lmdb_env):
    """Yield ``(label, data)`` for every record of ``lmdb_env`` in cursor order.

    Each value is parsed as a Caffe ``Datum``; ``data`` comes from
    ``caffe.io.datum_to_array``.
    """
    # One Datum is reused; ParseFromString replaces its contents per record.
    datum = caffe.proto.caffe_pb2.Datum()
    # Context managers release the cursor and read transaction even when the
    # consumer stops early and the generator is closed or garbage-collected.
    with lmdb_env.begin() as txn:
        with txn.cursor() as cursor:
            for _, value in cursor:
                datum.ParseFromString(value)
                yield datum.label, caffe.io.datum_to_array(datum)
def _random_iter(lmdb_env, keys):
    """Yield ``(label, data)`` for the records of ``lmdb_env`` named by ``keys``.

    Records are fetched by point lookups in the order given by ``keys``.

    Raises:
        KeyError: if a key is not present in the database.
    """
    # One Datum is reused; ParseFromString replaces its contents per record.
    datum = caffe.proto.caffe_pb2.Datum()
    # The context manager releases the read transaction even when the
    # consumer stops early and the generator is closed or garbage-collected.
    with lmdb_env.begin() as txn:
        for key in keys:
            value = txn.get(key)
            if value is None:
                # txn.get returns None for a missing key; fail with a clear
                # error instead of inside the protobuf parser.
                raise KeyError('key %r not found in LMDB' % (key,))
            datum.ParseFromString(value)
            yield datum.label, caffe.io.datum_to_array(datum)
import sys
import os

# Directory containing Caffe's Python bindings. Set the PYCAFFE_PATH
# environment variable to override the default without editing this file.
PYCAFFE_PATH = os.environ.get('PYCAFFE_PATH', '/path/to/caffe/python/')

# Importing this module makes `import caffe` resolvable; skip the append if
# the path is already present (e.g. via PYTHONPATH) to avoid duplicates.
if PYCAFFE_PATH not in sys.path:
    sys.path.append(PYCAFFE_PATH)
import numpy as np
from keras.preprocessing.image import Iterator
from caffe_lmdb import caffe_lmdb
class LMDBIterator(Iterator):
    """Keras ``Iterator`` producing endless ``(data, labels)`` batches from a Caffe LMDB.

    ``data`` is stacked and transposed from NCHW to NHWC; ``labels`` is the
    stacked record labels.

    Args:
        path: Location of the Caffe-style LMDB.
        batch_size: Number of records per batch.
        shuffle: Forwarded to ``CaffeLMDB`` (record order) and to the base class.
        seed: Unsupported; must be ``None``.

    Raises:
        NotImplementedError: if ``seed`` is not ``None``.
    """

    def __init__(self, path, batch_size, shuffle=False, seed=None):
        # Validate before opening the database so a rejected argument does
        # not leave an LMDB environment open.
        if seed is not None:
            raise NotImplementedError('LMDBIterator does not support a seed')
        _caffe_lmdb = caffe_lmdb.CaffeLMDB(path, shuffle=shuffle)
        # Endless stream of (label, data) records, restarting at each pass.
        self.lmdb_loop_iter = _lmdb_loop_iter(_caffe_lmdb)
        # Set by reset(); nothing in this class reads it back.
        self.reset_flag = False
        super(LMDBIterator, self).__init__(len(_caffe_lmdb), batch_size, shuffle, seed)

    def reset(self):
        # Only records the request: the record stream is not rewound.
        self.reset_flag = True

    def next(self, *args, **kwargs):
        """Return the next ``(data, labels)`` batch of ``batch_size`` records."""
        # A generator cannot be advanced from two threads at once, so reads
        # are serialized with the lock the Keras base Iterator provides.
        with self.lock:
            records = [next(self.lmdb_loop_iter) for _ in range(self.batch_size)]
        # Decoding/stacking happens outside the lock.
        labels, data = zip(*records)
        data = np.stack(data).transpose([0, 2, 3, 1])  # NCHW => NHWC
        labels = np.stack(labels)
        return data, labels

    def __iter__(self):
        return self

    def __next__(self, *args, **kwargs):
        # Python 3 iterator protocol; delegates to the Python 2 style next().
        return self.next(*args, **kwargs)
def _lmdb_loop_iter(_caffe_lmdb):
while 1:
for d in iter(_caffe_lmdb):
yield d
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment