Skip to content

Instantly share code, notes, and snippets.

@thenomemac
Created July 7, 2016 20:14
Show Gist options
  • Save thenomemac/a42b518fde3d35939a394ccb9cf313cf to your computer and use it in GitHub Desktop.
batchgen implemented with sharedmem
# Modified by Josiah Olson to add python3 support
# Context manager to generate batches in the background via a process pool
# Returns minibatch of fixed shape as shared memory to avoid de-pickling compute time
# Usage:
#
# def batch(seed, arr):
#     ...  # generate minibatch
#     arr[:] = minibatch.ravel()
#     return minibatch.shape
#
# with BatchGenCM(batch, maxBatchDimTuple) as bg:
# minibatch = next(bg)
# .... # do something with minibatch
import uuid
# import hashlib
import random
import time
import sharedmem
import numpy as np
from multiprocessing import Process, Queue
# ToDo: come up with way to fix order based on generation time and seeds
class BatchGenCM:
    """Context manager that produces minibatches in background worker
    processes, returning each batch through a preallocated shared-memory
    buffer to avoid de-pickling large arrays.

    Usage::

        with BatchGenCM(batch_fn, max_dim) as bg:
            minibatch = next(bg)

    ``batch_fn(seed, arr)`` must fill ``arr`` (a flat float32 shared
    array) with the minibatch data and return the minibatch's shape.
    """

    def __init__(self, batch_fn, max_dim, seed=None, num_workers=8):
        # batch_fn: callable(seed, shared_array) -> shape tuple
        self.batch_fn = batch_fn
        self.num_workers = num_workers
        if seed is None:
            seed = random.randint(0, 4294967295)
        # NOTE(review): self.seed is stored but not yet used by workers —
        # deterministic per-batch seeding is still an open ToDo (see the
        # module-level comment); batch_fn currently receives seed=None.
        self.seed = str(seed)
        self.id = uuid.uuid4()
        # Maximum shape any minibatch may take; fixes the size of each
        # shared-memory buffer.
        self.max_dim = max_dim

    def __enter__(self):
        # jobq carries indices of free shared buffers to the workers;
        # doneq carries (index, shape) pairs back to the consumer.
        # Putting index and shape as one tuple keeps them paired even
        # when several workers finish concurrently (two separate queues
        # could interleave and mismatch index with shape).
        self.jobq = Queue(maxsize=self.num_workers)
        self.doneq = Queue()
        self.retq = Queue()  # kept for backward compatibility; unused
        self.processes = []
        self.current_batch = 0
        # One flat float32 shared buffer per worker, sized for the
        # largest possible minibatch.  np.prod replaces np.product,
        # which was removed in NumPy 2.0.
        self.arrList = [sharedmem.empty(int(np.prod(self.max_dim)),
                                        dtype=np.float32)
                        for _ in range(self.num_workers)]

        def produce():
            while True:
                i = self.jobq.get()
                if i is None:  # sentinel: shut down cleanly
                    break
                shape = self.batch_fn(None, self.arrList[i])
                self.doneq.put((i, shape))

        for i in range(self.num_workers):
            self.jobq.put(i)
            p = Process(target=produce)
            self.processes.append(p)
            p.start()
        return self

    def __iter__(self):
        return self

    def __next__(self):
        """Block until a worker finishes and return its minibatch.

        The returned array is a reshaped view into shared memory whose
        buffer is immediately recycled to the workers — copy it if you
        need to hold it past the next ``next(bg)`` call.
        """
        i, shape = self.doneq.get()
        batch = self.arrList[i].reshape(shape)
        self.jobq.put(i)  # recycle the buffer for the next job
        self.current_batch += 1
        return batch

    def __exit__(self, exc_type, exc_value, traceback):
        # Discard pending job indices so the None sentinels fit into the
        # bounded jobq, then ask each worker to exit via its sentinel.
        while not self.jobq.empty():
            self.jobq.get()
        for _ in self.processes:
            self.jobq.put(None)
        # Drain result queues so worker feeder threads cannot block exit.
        while not self.retq.empty():
            self.retq.get()
        while not self.doneq.empty():
            self.doneq.get()
        for process in self.processes:
            # Prefer a clean join; fall back to terminate for stragglers.
            process.join(timeout=1.0)
            if process.is_alive():
                process.terminate()
                process.join()
d = {i: np.random.random(1000*1000*10).reshape((1000, 1000, 10)).astype(np.float32) for i in range(2)}
def getBatch(seed, arr):
    """Demo batch function: seed the RNG, pick one of the two cached
    arrays at random, copy it into the shared buffer, and return its
    shape."""
    np.random.seed(seed)
    choice = np.random.randint(0, 2, 1)[0]
    # time.sleep(2)
    source = d[choice]
    arr[:] = source.ravel()
    return (1000, 1000, 10)
# Drive the generator: pull 20 minibatches and report per-fetch timing.
out = []
with BatchGenCM(getBatch, (1000, 1000, 10), seed=333, num_workers=4) as bg:
    t_start = time.time()
    for iteration in range(20):
        # time.sleep(1)
        t_fetch = time.time()
        minibatch = next(bg)
        print('Time to get:', time.time() - t_fetch)
        print('Iter Nbr:', iteration)
        print('First item:', minibatch[0, 0, 0])
        print('Shape:', minibatch.shape)
        out.append(minibatch)
    print('Time to run all batches:', time.time() - t_start)
print('Len output:', len(out))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment