|
# Modified by Josiah Olson to add python3 support |
|
# Context manager to generate batches in the background via a process pool |
|
# Returns minibatch of fixed shape as shared memory to avoid de-pickling compute time |
|
# Usage: |
|
# |
|
# def batch(seed, arr): |
|
# .... # generate minibatch |
|
# arr[:] = minibatch.ravel() |
|
# return minibatch.shape
|
# |
|
# with BatchGenCM(batch, maxBatchDimTuple) as bg: |
|
# minibatch = next(bg) |
|
# .... # do something with minibatch |
|
|
|
import uuid |
|
# import hashlib |
|
import random |
|
import time |
|
import sharedmem |
|
import numpy as np |
|
from multiprocessing import Process, Queue |
|
|
|
|
|
# ToDo: come up with way to fix order based on generation time and seeds |
|
class BatchGenCM:
    """Context manager that generates minibatches in background processes.

    Each worker process fills a preallocated shared-memory float32 buffer
    (one buffer per worker) and reports only the batch *shape* through a
    queue, so the consumer never pays pickling cost for the batch data.

    Parameters
    ----------
    batch_fn : callable(seed, arr) -> tuple
        Fills the flat shared array ``arr`` with a minibatch and returns the
        shape the consumer should reshape it to.  NOTE: seeding is currently
        disabled (the workers pass ``None``) -- see the ToDo above about
        fixing batch order/seeds.
    max_dim : tuple
        Maximum batch shape; fixes the size of each shared buffer.
    seed : int, optional
        Base seed, stored as a string; a random 32-bit value when omitted.
        Currently unused by the workers (see above).
    num_workers : int
        Number of worker processes and shared buffers.
    """

    def __init__(self, batch_fn, max_dim, seed=None, num_workers=8):
        self.batch_fn = batch_fn
        self.num_workers = num_workers
        if seed is None:
            seed = random.randint(0, 4294967295)
        self.seed = str(seed)
        self.id = uuid.uuid4()
        self.max_dim = max_dim

    def __enter__(self):
        # jobq carries free buffer indices to the workers; doneq carries
        # (buffer_index, batch_shape) pairs back.  Using a single result
        # queue keeps index and shape paired atomically -- the previous
        # two-queue scheme (index on one queue, shape on another) could
        # mismatch them when several workers finished at the same time.
        self.jobq = Queue(maxsize=self.num_workers)
        self.doneq = Queue()
        self.processes = []
        self.current_batch = 0
        # One flat float32 shared buffer per worker.  np.prod replaces
        # np.product, which was removed in NumPy 2.0.
        self.arrList = [sharedmem.empty(int(np.prod(self.max_dim)),
                                        dtype=np.float32)
                        for _ in range(self.num_workers)]

        def produce():
            while True:
                i = self.jobq.get()
                if i is None:  # sentinel: shut down cleanly
                    break
                # ToDo: derive a per-batch seed here (e.g. hash of
                # self.seed + batch counter) instead of passing None.
                shape = self.batch_fn(None, self.arrList[i])
                self.doneq.put((i, shape))

        # Pre-load one job per buffer and start one worker per buffer.
        for i in range(self.num_workers):
            self.jobq.put(i)
            p = Process(target=produce)
            self.processes.append(p)
            p.start()

        return self

    def __iter__(self):
        return self

    def __next__(self):
        # Block until some worker reports a finished buffer.
        i, shape = self.doneq.get()
        batch = self.arrList[i].reshape(shape)
        # NOTE(review): the buffer index is recycled before the caller is
        # done reading `batch`, so a worker may overwrite it while the
        # consumer still holds the view -- confirm callers copy promptly.
        self.jobq.put(i)
        self.current_batch += 1
        return batch

    def __exit__(self, exc_type, exc_value, traceback):
        # Drain the queues so terminated children are not left blocked on
        # a full pipe, then tear the workers down.
        while not self.jobq.empty():
            self.jobq.get()
        while not self.doneq.empty():
            self.doneq.get()

        for process in self.processes:
            process.terminate()
            process.join()
|
|
|
|
|
# Two pre-generated 1000x1000x10 random float32 tensors, used as canned
# minibatch data by getBatch below.
d = {
    idx: np.random.random((1000, 1000, 10)).astype(np.float32)
    for idx in range(2)
}
|
|
|
|
|
def getBatch(seed, arr):
    """Example batch function: copy one canned tensor from `d` into `arr`.

    Parameters
    ----------
    seed : int or None
        Seed for numpy's global RNG; determines which canned tensor is
        picked (None gives a nondeterministic pick).
    arr : flat writable float32 array (shared memory)
        Destination buffer; overwritten with the raveled minibatch.

    Returns
    -------
    tuple
        Shape the caller should reshape `arr` to.
    """
    np.random.seed(seed)
    randint = np.random.randint(0, 2, 1)[0]
    # time.sleep(2)  # uncomment to simulate a slow batch generator
    tmpdata = d[randint]
    arr[:] = tmpdata.ravel()
    # Return the actual shape instead of re-hardcoding (1000, 1000, 10),
    # so this stays correct if the canned data in `d` ever changes.
    return tmpdata.shape
|
|
|
# Demo: pull 20 minibatches through the background generator and time it.
out = []
with BatchGenCM(getBatch, (1000, 1000, 10), seed=333, num_workers=4) as bg:
    t_start = time.time()
    for step in range(20):
        t_iter = time.time()
        mb = next(bg)
        print('Time to get:', time.time() - t_iter)
        print('Iter Nbr:', step)
        print('First item:', mb[0, 0, 0])
        print('Shape:', mb.shape)
        out.append(mb)
    print('Time to run all batches:', time.time() - t_start)

print('Len output:', len(out))