@thenomemac
Created July 7, 2016 20:14
batchgen implemented with multiprocessing.Array
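The gist's core trick is to hand minibatches from worker processes to the consumer through pre-allocated multiprocessing.Array buffers rather than pickling them through a Queue; only the batch shape travels over a queue. Below is a minimal standalone sketch of that handoff, separate from the gist file itself; the function name fill and the 3x4 example array are illustrative only.

import numpy as np
from multiprocessing import Array, Process, Queue

def fill(arr, shapeq):
    # Worker: build a batch, copy it into the shared buffer, send only the shape
    data = np.arange(12, dtype=np.float32).reshape(3, 4)
    arr[:data.size] = data.ravel()
    shapeq.put(data.shape)

if __name__ == '__main__':
    arr = Array('f', 12)          # shared float buffer, 12 elements
    shapeq = Queue()
    p = Process(target=fill, args=(arr, shapeq))
    p.start()
    shape = shapeq.get()          # arrives only after the worker has filled arr
    batch = np.frombuffer(arr.get_obj(), dtype=np.float32)[:np.prod(shape)].reshape(shape)
    p.join()
    print(batch)
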
# Modified 2016-06-30 by Josiah Olson to add python3 support
# Context manager to generate batches in the background via a process pool
# Usage:
#
#     def batch(seed, arr):
#         ...  # fill the shared array `arr` with the minibatch
#         return shape  # shape of the data written into `arr`
#
#     with BatchGenCM(batch) as bg:
#         minibatch = next(bg)
#         ...  # do something with minibatch
import uuid
import os
import pickle
import hashlib
import random
import time
import numpy as np
from multiprocessing import Array, Process, Queue


class BatchGenCM:
    def __init__(self, batch_fn, seed=None, num_workers=8):
        self.batch_fn = batch_fn
        self.num_workers = num_workers
        if seed is None:
            seed = np.random.randint(4294967295)
        self.seed = str(seed)
        self.id = uuid.uuid4()
    def __enter__(self):
        self.jobq = Queue(maxsize=self.num_workers)
        self.doneq = Queue()
        self.retq = Queue()
        self.processes = []
        self.current_batch = 0
        # One pre-allocated shared float buffer per worker; workers write
        # batches here instead of pickling them through a queue.
        self.arrList = [Array('f', 1000*1000*10) for _ in range(self.num_workers)]

        def produce():
            while True:
                i = self.jobq.get()
                if i is None:
                    break
                # NOTE: self.current_batch is copied into the child at fork time
                # and never advances there, so this seed does not change between
                # batches produced by a worker.
                seed = hashlib.md5((self.seed + str(self.current_batch)).encode('utf-8')).hexdigest()
                seed = int(seed, 16) % 4294967295
                # batch = self.batch_fn(seed)
                # batch_fn fills the shared array in place and returns the shape
                shape = self.batch_fn(seed, self.arrList[i])
                self.retq.put(shape)
                #with open('/run/shm/{}-{}'.format(self.id, n), 'wb') as ofile:
                #    pickle.dump(batch, ofile, protocol=pickle.HIGHEST_PROTOCOL)
                self.doneq.put(i)

        for i in range(self.num_workers):
            self.jobq.put(i)
            p = Process(target=produce)
            self.processes.append(p)
            p.start()
        return self
    def __iter__(self):
        return self

    def __next__(self):
        n = self.current_batch
        print('n:', n)
        print('doneq-pre:', self.doneq.qsize())
        #while n not in self.finished_batches:
        #    i = self.doneq.get()
        #    self.finished_batches.append(i)
        i = self.doneq.get()
        print('i:', i)
        print('doneq-post:', self.doneq.qsize())
        #fn = '/run/shm/{}-{}'.format(self.id, n)
        #batch = pickle.load(open(fn, 'rb'))
        #os.system('rm {}'.format(fn))
        #self.retq.put(batch)
        print('retq-pre:', self.retq.qsize())
        # NOTE: i (from doneq) and shape (from retq) come from separate queues,
        # so they are only guaranteed to match when every batch has the same shape.
        shape = self.retq.get()
        # Copy out of the shared buffer before handing the slot back to a worker.
        batch = np.frombuffer(self.arrList[i].get_obj(), dtype=np.float32)
        batch = batch.copy()
        batch = batch.reshape(shape)
        print('retq-post:', self.retq.qsize())
        print('jobq-pre:', self.jobq.qsize())
        self.jobq.put(i)
        print('jobq-post:', self.jobq.qsize())
        self.current_batch += 1
        return batch
    def __exit__(self, exc_type, exc_value, traceback):
        print('jobq-pre:', self.jobq.qsize())
        #for _ in range(self.num_workers):
        #    self.jobq.put(None)
        print('doneq-pre:', self.doneq.qsize())
        print('retq-pre:', self.retq.qsize())
        while not self.jobq.empty():
            _ = self.jobq.get()
        while not self.retq.empty():
            _ = self.retq.get()
        while not self.doneq.empty():
            _ = self.doneq.get()
        for process in self.processes:
            process.terminate()
            process.join()
        print('jobq-post:', self.jobq.qsize())
        print('doneq-post:', self.doneq.qsize())
        print('retq-post:', self.retq.qsize())
        print('toEnd')

# Two fixed 1000x1000x10 float32 arrays used as fake minibatches for the demo
d = {i: np.random.random(1000*1000*10).reshape((1000, 1000, 10)).astype(np.float32) for i in range(2)}


def getBatch(seed, arr):
    #np.random.seed(seed)
    randint = np.random.randint(0, 2, 1)[0]
    # time.sleep(2)
    tmpdata = d[randint]
    # Write the chosen batch into the shared buffer and return its shape
    arr[:] = tmpdata.ravel()
    return (1000, 1000, 10)

# out = []
with BatchGenCM(getBatch, seed=333, num_workers=4) as bg:
    startTime = time.time()
    for counter in range(20):
        # time.sleep(1)
        startTimeSub = time.time()
        minibatch = next(bg)
        print(time.time() - startTimeSub)
        print(counter)
        # print(minibatch[0, 0, 0])
        print(minibatch.shape)
        # out.append(minibatch)
    print(time.time() - startTime)
    # print(len(out))  # only valid if the `out` accumulator above is re-enabled
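For scale: each Array('f', 1000*1000*10) holds 10 million float32 values, about 40 MB of shared memory per worker, so the 4-worker run above reserves roughly 160 MB (the default of 8 workers would reserve about 320 MB). Passing only the shape through retq keeps those 40 MB out of the queues entirely, which appears to be the motivation for replacing the commented-out /run/shm pickle approach with shared arrays.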