Skip to content

Instantly share code, notes, and snippets.

@ebenolson
Created January 4, 2016 14:15
Show Gist options
  • Save ebenolson/072712792c46aa192797 to your computer and use it in GitHub Desktop.
Save ebenolson/072712792c46aa192797 to your computer and use it in GitHub Desktop.
# Context manager to generate batches in the background via a process pool
# Usage:
#
# def batch(seed):
# .... # generate minibatch
# return minibatch
#
# with BatchGenCM(batch) as bg:
# minibatch = next(bg)
# .... # do something with minibatch
import uuid
import os
import pickle
import hashlib
import numpy as np
from multiprocessing import Process, Queue
class BatchGenCM:
def __init__(self, batch_fn, seed=None, num_workers=8):
self.batch_fn = batch_fn
self.num_workers = num_workers
if seed is None:
seed = np.random.randint(4294967295)
self.seed = str(seed)
self.id = uuid.uuid4()
def __enter__(self):
self.jobq = Queue(maxsize=self.num_workers)
self.doneq = Queue()
self.processes = []
self.current_batch = 0
self.finished_batches = []
def produce():
while True:
n = self.jobq.get()
if n is None:
break
seed = hashlib.md5(self.seed + str(n)).hexdigest()
seed = int(seed, 16) % 4294967295
batch = self.batch_fn(seed)
with open('/run/shm/{}-{}'.format(self.id, n), 'w') as ofile:
pickle.dump(batch, ofile, protocol=pickle.HIGHEST_PROTOCOL)
self.doneq.put(n)
for i in range(self.num_workers):
self.jobq.put(i)
p = Process(target=produce)
self.processes.append(p)
p.start()
return self
def __iter__(self):
return self
def next(self):
n = self.current_batch
while n not in self.finished_batches:
i = self.doneq.get()
self.finished_batches.append(i)
fn = '/run/shm/{}-{}'.format(self.id, n)
batch = pickle.load(open(fn))
os.system('rm {}'.format(fn))
self.jobq.put(n + self.num_workers)
self.current_batch += 1
return batch
def __exit__(self, exc_type, exc_value, traceback):
for _ in range(self.num_workers):
self.jobq.put(None)
for process in self.processes:
process.join()
while not self.doneq.empty():
_ = next(self)
@thenomemac
Copy link

Made changes to add python3 support, feel free to use:

# Modified 2016-06-30 by Josiah Olson to add python3 support
# Context manager to generate batches in the background via a process pool
# Usage:
#
# def batch(seed):
#    .... # generate minibatch
#    return minibatch
#
# with BatchGenCM(batch) as bg:
#    minibatch = next(bg)
#    .... # do something with minibatch

import uuid
import os
import pickle
import hashlib
import numpy as np
from multiprocessing import Process, Queue


class BatchGenCM:
    def __init__(self, batch_fn, seed=None, num_workers=8):
        self.batch_fn = batch_fn
        self.num_workers = num_workers
        if seed is None:
            seed = np.random.randint(4294967295)
        self.seed = str(seed)
        self.id = uuid.uuid4()

    def __enter__(self):
        self.jobq = Queue(maxsize=self.num_workers)
        self.doneq = Queue()
        self.processes = []
        self.current_batch = 0
        self.finished_batches = []

        def produce():
            while True:
                n = self.jobq.get()
                if n is None:
                    break
                seed = hashlib.md5((self.seed + str(n)).encode('utf-8')).hexdigest()
                seed = int(seed, 16) % 4294967295
                batch = self.batch_fn(seed)
                with open('/run/shm/{}-{}'.format(self.id, n), 'wb') as ofile:
                    pickle.dump(batch, ofile, protocol=pickle.HIGHEST_PROTOCOL)
                self.doneq.put(n)

        for i in range(self.num_workers):
            self.jobq.put(i)

            p = Process(target=produce)
            self.processes.append(p)
            p.start()

        return self

    def __iter__(self):
        return self

    def __next__(self):
        n = self.current_batch
        while n not in self.finished_batches:
            i = self.doneq.get()
            self.finished_batches.append(i)

        fn = '/run/shm/{}-{}'.format(self.id, n)
        batch = pickle.load(open(fn, 'rb'))
        os.system('rm {}'.format(fn))

        self.jobq.put(n + self.num_workers)
        self.current_batch += 1
        return batch

    def __exit__(self, exc_type, exc_value, traceback):
        for _ in range(self.num_workers):
            self.jobq.put(None)
        for process in self.processes:
            process.join()
        while not self.doneq.empty():
            _ = self.__next__()

@zhaobozb
Copy link

Hi Ebenolson,

I have the same question as @baumgach, how can batch(seed) iterate all the samples in a round without duplicate choose?

@rslprpr
Copy link

rslprpr commented Oct 6, 2016

For large dataset on hdf5 how do you use this batch_generator without loading your data on memory?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment