Feed a TF session with PyTorch utils for data loading

TL;DR: use PyTorch's DataLoader abstraction to feed data to your TF model.

WIP. The code should work; take a look at the benchmark below for a particular use case.

  • CSV file loader, using "image,target" as the header.
  • simple transformer
  • loop over the entire dataset
  • make a graph
  • call the session (a minimal end-to-end sketch follows this list)
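To make the last two items concrete, here is a minimal end-to-end sketch under stated assumptions: the file name sample.csv, the placeholder shape, and the toy one-op graph are placeholders for illustration (the real gist runs an SSD detection graph instead), and ImageFromCSV, img_transform, and my_collate are the helpers defined in the code further down.

import tensorflow as tf
from torch.utils.data import DataLoader

# Assumes ImageFromCSV, img_transform and my_collate from the gist code below.
dataset = ImageFromCSV('sample.csv', transform=img_transform)
loader = DataLoader(dataset, batch_size=1, num_workers=4, collate_fn=my_collate)

# Toy graph standing in for a real model (TF 1.x API, as in 2017).
images = tf.placeholder(tf.uint8, shape=[None, None, None, 3])
out_tensor = tf.reduce_mean(tf.cast(images, tf.float32))

with tf.Session() as sess:
    for image, target in loader:
        # `image` is already a numpy array (thanks to my_collate), so it can
        # be fed straight into the graph.
        output = sess.run(out_tensor, feed_dict={images: image})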

Did you like it?

Give it a star, cite my research, or buy me a beer/coffee 😉

Benchmark

Below is a comparison using SSD with InceptionV2 as the base network. In short, data loading is about 158x faster with 8 workers (the average Data-in time drops from 0.158 s to 0.001 s per batch).

PyTorch-ified feeder

2017-10-04 12:07:50,453 INFO batch_detection_naive.py: Warm-up period is over
2017-10-04 12:07:58,740 INFO batch_detection_naive.py: [125/538025]     Batch 0.073 (0.071)     Data-in 0.001 (0.001)
2017-10-04 12:08:07,611 INFO batch_detection_naive.py: [250/538025]     Batch 0.062 (0.071)     Data-in 0.001 (0.001)
2017-10-04 12:08:16,604 INFO batch_detection_naive.py: [375/538025]     Batch 0.068 (0.071)     Data-in 0.001 (0.001)
2017-10-04 12:08:25,568 INFO batch_detection_naive.py: [500/538025]     Batch 0.071 (0.071)     Data-in 0.001 (0.001)
2017-10-04 12:08:25,568 DEBUG batch_detection_naive.py: Stopping...
2017-10-04 12:08:25,585 INFO batch_detection_naive.py: Batch 0.071      Data-in 0.001

Single-threaded for loop in Python

2017-10-04 12:13:27,589 INFO batch_detection_naive.py: Warm-up period is over
2017-10-04 12:13:55,106 INFO batch_detection_naive.py: [125/538025]     Batch 0.067 (0.067)     Data-in 0.139 (0.174)   
2017-10-04 12:14:22,875 INFO batch_detection_naive.py: [250/538025]     Batch 0.091 (0.071)     Data-in 0.157 (0.160)   
2017-10-04 12:14:51,455 INFO batch_detection_naive.py: [375/538025]     Batch 0.069 (0.072)     Data-in 0.155 (0.158)   
2017-10-04 12:15:20,456 INFO batch_detection_naive.py: [500/538025]     Batch 0.073 (0.072)     Data-in 0.160 (0.158)
2017-10-04 12:15:20,456 DEBUG batch_detection_naive.py: Stopping...
2017-10-04 12:15:20,471 INFO batch_detection_naive.py: Batch 0.072      Data-in 0.158 
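The Batch and Data-in columns report per-iteration times in seconds, with the running average in parentheses. The benchmarking script itself is not part of the gist; the sketch below shows one plausible way such timings could be collected, using a simple running-average helper in the style of the PyTorch ImageNet example. The helper name AverageMeter, the logging interval, and the exact print format are assumptions.

import time

class AverageMeter(object):
    """Track the latest value and the running average of a measurement."""
    def __init__(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

    def update(self, val):
        self.val = val
        self.sum += val
        self.count += 1
        self.avg = self.sum / self.count

# `dataloader` is built as in the gist code below; sess.run(...) stands for the model.
data_time, batch_time = AverageMeter(), AverageMeter()
end = time.time()
for i, (image, target) in enumerate(dataloader):
    data_time.update(time.time() - end)   # Data-in: time spent waiting for the next batch
    tic = time.time()
    # output = sess.run(out_tensors, feed_dict={my_tensor: image})
    batch_time.update(time.time() - tic)  # Batch: time spent in the model
    end = time.time()
    if (i + 1) % 125 == 0:
        print('[{}/{}]\tBatch {:.3f} ({:.3f})\tData-in {:.3f} ({:.3f})'.format(
            i + 1, len(dataloader), batch_time.val, batch_time.avg,
            data_time.val, data_time.avg))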
import argparse
import collections.abc
import csv
import os

import numpy as np
from PIL import Image
from torch.utils.data import DataLoader, Dataset


class ImageFromCSV(Dataset):
    """Load images from a CSV list.

    Iterator to load (optionally, transform/preprocess) images and their
    target recorded in a CSV file.

    Args:
        filename (str): CSV file with the list of images to read.
        root (str, optional): files in `filename` are relative paths with
            respect to this dirname. It reduces the size of the CSV but not
            in memory.
        fields (sequence, optional): sequence with the field names for the
            image paths and targets, respectively. If not provided, the
            first two fields of the first row are used.
        transform (callable, optional): preprocess function applied on a
            single PIL image.
    """

    def __init__(self, filename, root='', fields=None, transform=None):
        self.root = root
        self.filename = filename
        self.fields = fields
        self.transform = transform
        self.imgs = self._make_dataset()

    def _make_dataset(self):
        with open(self.filename, 'r') as fid:
            data = csv.DictReader(fid, fieldnames=self.fields)
            if self.fields is None:
                self.fields = data.fieldnames
            else:
                for i in self.fields:
                    if i not in data.fieldnames:
                        raise ValueError('Missing {} field in {}'
                                         .format(i, self.filename))
            imgs = []
            for row in data:
                img_name = row[self.fields[0]]
                path = os.path.join(self.root, img_name)
                target = 0
                if len(self.fields) > 1:
                    target = row[self.fields[1]]
                imgs.append((path, target))
            return imgs

    def __getitem__(self, index):
        path, target = self.imgs[index]
        target = int(target)
        img = Image.open(path)
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        return len(self.imgs)


def img_transform(image):
    """Turn a PIL image into an HWC uint8 numpy array (TF-friendly layout)."""
    (im_width, im_height) = image.size
    tf_format = (im_height, im_width, 3)
    image = np.array(image.getdata()).reshape(tf_format).astype(np.uint8)
    return image


def my_collate(batch):
    """Put each data field into a numpy array with outer dimension batch size."""
    if type(batch[0]).__module__ == 'numpy':
        elem = batch[0]
        if type(elem).__name__ == 'ndarray':
            return np.stack(batch)
    elif isinstance(batch[0], int):
        return np.array(batch, dtype=np.int64)
    elif isinstance(batch[0], float):
        return np.array(batch, dtype=np.float32)
    elif isinstance(batch[0], collections.abc.Sequence):
        # A batch of (image, target) tuples: transpose and collate each field.
        transposed = zip(*batch)
        return [my_collate(samples) for samples in transposed]
    else:
        raise NotImplementedError('WIP, call 911!')
    # Numpy scalars and other unhandled numpy types fall through here.
    return


if __name__ == '__main__':
    # Hypothetical CLI; the original gist assumes `args` and `num_workers`
    # are defined elsewhere.
    parser = argparse.ArgumentParser()
    parser.add_argument('filename', help='CSV file with an "image,target" header')
    parser.add_argument('--root', default='', help='prefix for the image paths')
    parser.add_argument('--num-workers', default=4, type=int)
    args = parser.parse_args()

    dataset = ImageFromCSV(args.filename, args.root, transform=img_transform)
    dataloader = DataLoader(dataset, batch_size=1, num_workers=args.num_workers,
                            collate_fn=my_collate)
    for i, (image, target) in enumerate(dataloader):
        # TODO: benchmark here
        print(type(image), image.shape)
        print(type(target), target.shape)
        # EDIT ME!
        # output = sess.run(out_tensors,
        #                   feed_dict={my_tensor: image})
        # TODO: benchmark here
        # REMOVE ME!
        break
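Two design choices make this play well with TF: my_collate returns plain numpy arrays instead of the torch Tensors that DataLoader's default collate_fn would produce, so batches can be passed to feed_dict without an extra torch-to-numpy conversion; and num_workers > 0 moves the image decoding and preprocessing into background worker processes, so it overlaps with the sess.run call. That overlap is what shrinks the Data-in time from roughly 0.158 s to 0.001 s per batch in the benchmark above.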