Skip to content

Instantly share code, notes, and snippets.

@mehdidc
Last active April 12, 2021 22:37
Show Gist options
  • Save mehdidc/bd438f57dd8f50fa39037b28678bc563 to your computer and use it in GitHub Desktop.
Save mehdidc/bd438f57dd8f50fa39037b28678bc563 to your computer and use it in GitHub Desktop.
import os
from imageio import imread
import pandas as pd
import lmdb
from caffe2.proto import caffe2_pb2
# Convert a folder of captioned images into an LMDB database in which every
# record is a serialized Caffe2 `TensorProtos` holding two STRING tensors:
# the raw (still encoded) image bytes and the caption bytes.

# Folder should contain a set of images
image_folder = "flickr30k_images"
# CSV should contain image filenames with corresponding captions
dataframe_path = "flickr30k_images/results.csv"
# Number of records buffered in memory between two LMDB commits
CACHE_MAX_SIZE = 1000


def _flush_cache(env, cache):
    """Write all buffered records to `env` in one transaction and empty `cache`.

    `cache` maps record keys (bytes) to serialized values (bytes). It is
    cleared in place once the transaction has been committed.
    """
    if not cache:
        return
    with env.begin(write=True) as txn:
        for record_key, record_value in cache.items():
            txn.put(record_key, record_value)
    cache.clear()


# Target LMDB path
env = lmdb.open("lmdb", map_size=1000000000000)
try:
    df = pd.read_csv(dataframe_path, sep="|")
    # image_filenames in `image_folder`
    image_filenames = df.image_name.values
    # corresponding caption for each image (the column name really starts
    # with a space in the "|"-separated CSV)
    image_captions = df[" comment"].values
    cache = {}
    # Counts only the records that were actually kept, so keys stay contiguous
    idx = 0
    for image_filename, caption in zip(image_filenames, image_captions):
        # get image raw data; the context manager closes the file right away
        # instead of leaking one handle per image
        with open(os.path.join(image_folder, image_filename), "rb") as image_file:
            data = image_file.read()
        try:
            # decoding is only needed to validate the file and get its shape
            img = imread(data)
        except Exception as ex:
            # some images cannot be read
            print(ex)
            continue
        try:
            height, width, channels = img.shape
        except Exception as ex:
            # some images have weird shape (e.g. not exactly 3 dimensions)
            print(ex)
            continue
        try:
            # make sure caption is a bytestring
            caption = caption.encode()
        except Exception as ex:
            # some captions are not strings and have no `.encode()`
            print(ex)
            continue
        # Key is image id
        key = f"{idx}".encode("ascii")
        # value contain image raw data and caption string
        shape = (height, width, channels)
        tensor_protos = caffe2_pb2.TensorProtos()
        img_tensor = tensor_protos.protos.add()
        # dims describe the decoded image, while the payload stays encoded
        img_tensor.dims.extend(shape)
        img_tensor.data_type = caffe2_pb2.TensorProto.STRING
        img_tensor.string_data.append(data)
        label_tensor = tensor_protos.protos.add()
        label_tensor.data_type = caffe2_pb2.TensorProto.STRING
        label_tensor.string_data.append(caption)
        value = tensor_protos.SerializeToString()
        # Put the mapping in memory, save periodically to LMDB, more efficient
        cache[key] = value
        print(image_filename)
        idx += 1
        if idx % CACHE_MAX_SIZE == 0:
            _flush_cache(env, cache)
    # Append the remaining cache items
    _flush_cache(env, cache)
finally:
    # Close the environment even if the conversion aborts half-way
    env.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment