Skip to content

Instantly share code, notes, and snippets.

@soupault
Created March 13, 2017 08:46
Show Gist options
  • Save soupault/8d32e0fa7c24f869c533eebd4ff75c23 to your computer and use it in GitHub Desktop.
Save soupault/8d32e0fa7c24f869c533eebd4ff75c23 to your computer and use it in GitHub Desktop.
class ThreadedDataGenerator(object):
    """Prefetch one batch at a time from ``get_patches`` on a daemon thread.

    Every positional and keyword argument given to the constructor is stored
    and forwarded unchanged to each ``get_patches`` call.  The worker fills a
    single slot (``batch``); as soon as the consumer takes it via
    ``get_batch`` the worker starts producing the next one.

    Attributes:
        gp_args: positional arguments forwarded to ``get_patches``.
        gp_kwargs: keyword arguments forwarded to ``get_patches``.
        batch: the most recently produced batch (``None`` before the first).
        batch_ready: ``True`` while an unconsumed result (a batch, or an
            error raised by the worker) is waiting for ``get_batch``.
        batch_retrieved: ``True`` from the moment ``get_batch`` hands out a
            result until the worker publishes the next one.
    """

    def __init__(self, *args, **kwargs):
        """Store the ``get_patches`` arguments and start the worker thread."""
        self.gp_args = args
        self.gp_kwargs = kwargs
        # Initialised up front so the attribute exists before the first
        # batch has been produced.
        self.batch = None
        self.batch_retrieved = False
        self.batch_ready = False
        # Exception raised by the last get_patches call, if any; it is
        # re-raised in the consumer by get_batch().
        self._error = None
        # Guards every flag/slot update and lets both sides sleep until the
        # other one changes state instead of polling.
        self._cond = threading.Condition()
        thread = threading.Thread(target=self.run, args=())
        thread.daemon = True
        thread.start()

    def run(self):
        """Worker loop: produce a batch whenever the slot is empty.

        Runs forever; it is meant to be the target of the daemon thread
        started by ``__init__``.
        """
        while True:
            with self._cond:
                # Refill only after the previous result was consumed.  The
                # refill decision is keyed on ``batch_ready`` so that a new
                # batch is produced after every retrieval and an unconsumed
                # batch is never overwritten.
                while self.batch_ready:
                    self._cond.wait()

            # Generate outside the lock so the consumer is never blocked on
            # the lock while a batch is being built.
            batch, error = None, None
            try:
                batch = get_patches(*self.gp_args, **self.gp_kwargs)
            except Exception as exc:  # thread boundary: hand over to consumer
                error = exc

            with self._cond:
                if error is None:
                    self.batch = batch
                self._error = error
                self.batch_retrieved = False
                self.batch_ready = True
                self._cond.notify_all()

    def get_batch(self):
        """Return the prefetched batch and let the worker build the next one.

        Blocks until a result is available, so callers that poll
        ``batch_ready`` first never wait here.

        Returns:
            Whatever ``get_patches`` returned for this batch.

        Raises:
            Exception: whatever ``get_patches`` raised in the worker thread
                while producing this batch.
        """
        with self._cond:
            while not self.batch_ready:
                self._cond.wait()
            result = self.batch
            error, self._error = self._error, None
            self.batch_retrieved = True
            self.batch_ready = False
            # Wake the worker so it starts on the next batch immediately.
            self._cond.notify_all()
        if error is not None:
            raise error
        return result
# Usage: build the generator, poll until the first batch is flagged ready,
# then unpack it.
train_batch_gen = ThreadedDataGenerator(
    sat_images,
    band,
    amount=2048,
    ptype='train',
    aug=True,
    norm_type='stretch_meanstd',
)
while True:
    if train_batch_gen.batch_ready:
        break
    print('Waiting for batch')
    time.sleep(5)
x_trn, y_trn, coords_trn = train_batch_gen.get_batch()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment