Created July 5, 2017 19:03
-
-
Save Jiaming-Liu/e64330e556be56e1ad5e7fa6b4749a1c to your computer and use it in GitHub Desktop.
Keras iterator for Caffe-Style LMDB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
  <component name="NewModuleRootManager">
    <content url="file://$MODULE_DIR$" />
    <orderEntry type="inheritedJdk" />
    <orderEntry type="sourceFolder" forTests="false" />
  </component>
  <component name="TestRunnerService">
    <option name="projectConfiguration" value="Nosetests" />
    <option name="PROJECT_TEST_RUNNER" value="Nosetests" />
  </component>
</module>
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectRootManager" version="2" project-jdk-name="Python 2.7.6 (/usr/bin/python2.7)" project-jdk-type="Python SDK" />
</project>
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectModuleManager">
    <modules>
      <module fileurl="file://$PROJECT_DIR$/.idea/keras_lmdb.iml" filepath="$PROJECT_DIR$/.idea/keras_lmdb.iml" />
    </modules>
  </component>
</project>
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools

from keras_lmdb_iterator.caffe_lmdb import caffe_lmdb

# Sanity check: print the label and array shape of the first ten records of a
# Caffe-style LMDB, read sequentially (shuffle=False).
db = caffe_lmdb.CaffeLMDB('/path/to/ilsvrc12_train_lmdb/', shuffle=False)
# islice stops after ten records without reading an eleventh one, and keeps the
# database object and the per-record label under distinct names.
for label, data in itertools.islice(db, 10):
    # A single-argument print() call behaves identically on Python 2 and 3.
    print('%s %s' % (label, data.shape))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.layers import Dense, Flatten
from keras.models import Sequential
from keras_lmdb_iterator import LMDBIterator

# Minimal end-to-end example: train a linear 1000-way classifier on batches
# streamed from a Caffe-style LMDB.
# NOTE(review): input_shape assumes the LMDB stores 256x256 3-channel images
# (LMDBIterator emits NHWC batches) -- adjust to match your data.
model = Sequential(layers=[
    Flatten(input_shape=(256, 256, 3)),
    # sparse_categorical_crossentropy expects class probabilities, so the
    # output layer needs a softmax; raw linear outputs make the loss invalid.
    Dense(1000, activation='softmax'),
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
model.fit_generator(
    # LMDBIterator.__iter__ returns self, so iter() is just the iterator itself.
    iter(LMDBIterator('/path/to/ilsvrc12_train_lmdb/', batch_size=64)),
    steps_per_epoch=10000,
    epochs=100)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras_lmdb_iterator import LMDBIterator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from caffe_lmdb import CaffeLMDB |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import caffe_python_env | |
import caffe | |
import lmdb | |
import random | |
import warnings | |
class CaffeLMDB(object):
    """Iterable view over a Caffe-style LMDB of serialized ``Datum`` records.

    Iterating yields ``(label, data)`` pairs (see ``__iter__``).

    Args:
        path: Path of the LMDB environment.
        shuffle: If truthy, iterate records in random order. All keys are
            loaded into memory once at construction, and every new pass
            (each ``iter()`` call) uses a fresh random permutation of them.
        subdir, readonly, lock: Passed straight through to ``lmdb.open``.
    """

    def __init__(self, path, shuffle=False, subdir=True, readonly=True, lock=False):
        self.lmdb_env = lmdb.open(path, readonly=readonly, lock=lock, subdir=subdir)
        self.shuffle = shuffle
        # All database keys in random order; only populated when shuffling.
        self.keys = None
        if shuffle:
            warnings.warn('Shuffle with python lmdb is incredibly slow!')
            with self.lmdb_env.begin() as txn:
                with txn.cursor() as cur:
                    # Only the keys are needed; skipping values avoids copying
                    # every serialized record out of the database to discard it.
                    self.keys = list(cur.iternext(keys=True, values=False))
            random.shuffle(self.keys)

    def __len__(self):
        """Return the number of records in the database."""
        return int(self.lmdb_env.stat()['entries'])

    def __iter__(self):
        """
        Iterator of caffe lmdb, returning (label, data) with integer label
        and C*H*W np.ndarray data.
        """
        # Use the same truthiness test as __init__, so that keys are loaded
        # exactly when they are used.
        if not self.shuffle:
            return _sequential_iter(self.lmdb_env)
        # Shuffle a private copy so each pass (epoch) gets a new order without
        # disturbing iterators that are still walking an earlier permutation.
        keys = list(self.keys)
        random.shuffle(keys)
        return _random_iter(self.lmdb_env, keys)
def _sequential_iter(lmdb_env):
    """Yield ``(label, data)`` for every record of ``lmdb_env`` in cursor order.

    ``data`` is produced by ``caffe.io.datum_to_array``. A single read
    transaction and cursor are held for the whole pass; ``with`` ensures they
    are released even when the consumer stops early (generator ``close()``)
    or parsing raises.
    """
    datum = caffe.proto.caffe_pb2.Datum()  # parse buffer reused for every record
    with lmdb_env.begin() as txn:
        with txn.cursor() as cursor:
            for _, value in cursor:
                datum.ParseFromString(value)
                yield datum.label, caffe.io.datum_to_array(datum)
def _random_iter(lmdb_env, keys):
    """Yield ``(label, data)`` for each key in ``keys``, in the order given.

    All lookups share one read transaction, which ``with`` releases even if
    the consumer stops early or parsing raises.

    Raises:
        KeyError: if a key is not present in the database.
    """
    datum = caffe.proto.caffe_pb2.Datum()  # parse buffer reused for every record
    with lmdb_env.begin() as txn:
        for key in keys:
            value = txn.get(key)
            if value is None:
                # txn.get returns None for a missing key; fail clearly here
                # rather than inside ParseFromString.
                raise KeyError(key)
            datum.ParseFromString(value)
            yield datum.label, caffe.io.datum_to_array(datum)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import sys

# Make pycaffe importable by adding its directory to sys.path. The location
# can be set with the PYCAFFE_PATH environment variable; otherwise the
# placeholder below is used and must be edited for your installation.
PYCAFFE_PATH = os.environ.get('PYCAFFE_PATH', '/path/to/caffe/python/')
if PYCAFFE_PATH not in sys.path:  # don't add a duplicate entry
    sys.path.append(PYCAFFE_PATH)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from keras.preprocessing.image import Iterator | |
from caffe_lmdb import caffe_lmdb | |
class LMDBIterator(Iterator):
    """Keras ``Iterator`` producing ``(data, labels)`` batches from a Caffe LMDB.

    Records are read through ``caffe_lmdb.CaffeLMDB`` and the database is
    looped over endlessly, so a batch may span the end of one pass and the
    start of the next. Images are transposed from NCHW to NHWC.

    Args:
        path: LMDB environment path.
        batch_size: Number of records per batch.
        shuffle: Read records in random order (see ``CaffeLMDB``).
        seed: Not supported; must be None.

    Raises:
        NotImplementedError: if ``seed`` is not None.
    """

    def __init__(self, path, batch_size, shuffle=False, seed=None):
        import threading

        # Validate before touching the database: opening it (and, with
        # shuffle=True, reading every key) is expensive.
        if seed is not None:
            raise NotImplementedError
        _caffe_lmdb = caffe_lmdb.CaffeLMDB(path, shuffle=shuffle)
        self.lmdb_loop_iter = _lmdb_loop_iter(_caffe_lmdb)
        self.reset_flag = False
        # A generator cannot be advanced from two threads at once ("generator
        # already executing"), so batch assembly is serialized to let several
        # Keras worker threads consume this iterator.
        self._batch_lock = threading.Lock()
        super(LMDBIterator, self).__init__(len(_caffe_lmdb), batch_size, shuffle, seed)

    def reset(self):
        # Only records the request; the underlying LMDB stream is not rewound.
        self.reset_flag = True

    def next(self, *args, **kwargs):
        """Return the next batch as ``(data, labels)``.

        ``data`` has shape ``(batch_size, H, W, C)``; ``labels`` is a 1-D
        array of the records' integer labels.
        """
        with self._batch_lock:
            records = [next(self.lmdb_loop_iter) for _ in range(self.batch_size)]
        labels, data = zip(*records)
        data = np.stack(data).transpose([0, 2, 3, 1])  # NCHW => NHWC
        labels = np.stack(labels)
        return data, labels

    def __iter__(self):
        return self

    def __next__(self, *args, **kwargs):
        # Python 3 iterator protocol; delegates to the Python 2-style next().
        return self.next(*args, **kwargs)
def _lmdb_loop_iter(_caffe_lmdb): | |
while 1: | |
for d in iter(_caffe_lmdb): | |
yield d |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.