Skip to content

Instantly share code, notes, and snippets.

@bobchennan
Last active March 15, 2017 14:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bobchennan/d414da80f2f5939e1660898aab42a415 to your computer and use it in GitHub Desktop.
Save bobchennan/d414da80f2f5939e1660898aab42a415 to your computer and use it in GitHub Desktop.
keras sharedmem
from multiprocessing import Pipe, Process, Manager
from time import sleep
import sharedmem
# Shared dict served by a multiprocessing Manager: Manager().dict() returns a
# proxy to a dict living in a separate manager server process, so entries are
# visible to child processes. train_process reads and then deletes entries
# keyed by the integer ids it receives over its pipe; the producer side is
# presumably Training.work.
# NOTE(review): values put into a Manager dict are pickled to the manager
# process, so storing sharedmem arrays here likely still copies their data.
# NOTE(review): the Manager (and its server process) is created as a side
# effect of importing this module.
batches = Manager().dict()
def train_process(model, num_batches, save_path, q, store=None):
    """Worker loop: train ``model`` on batches announced through a pipe.

    Repeatedly receives ``(x_key, y_key)`` pairs from ``q``, trains on
    ``store[x_key]`` / ``store[y_key]``, and removes those entries. A
    ``(None, None)`` pair ends the loop, after which the weights are saved.

    Args:
        model: object providing ``train_on_batch(x, y)`` and
            ``save_weights(path)`` (presumably a compiled Keras model).
        num_batches: number of batches the producer intends to send.
            Currently unused; kept for signature compatibility.
        save_path: file the weights are written to once the stream ends.
        q: connection end whose ``recv()`` yields ``(x_key, y_key)`` pairs.
        store: mapping holding the arrays under those keys. Defaults to the
            module-level shared ``batches`` dict.
    """
    if store is None:
        store = batches
    while True:
        x_key, y_key = q.recv()
        # Compare against None explicitly: key 0 is a valid (falsy) key.
        if x_key is None:
            break
        model.train_on_batch(store[x_key], store[y_key])
        # Drop consumed entries so the shared dict does not grow unbounded.
        del store[x_key], store[y_key]
    # Keras names this method ``save_weights``; ``save_weight`` does not exist.
    model.save_weights(save_path)
class Training(object):
    """Train a model in a worker process fed by a preprocessing producer.

    The parent pulls raw ``(x, y)`` batches from ``generator``, applies
    ``post_process`` to every sample of ``x``, and stores the results in the
    module-level shared ``batches`` dict. It then sends their keys through a
    pipe to a ``train_process`` worker, which trains on them.

    NOTE(review): the worker receives ``model`` as a Process argument, which
    assumes the 'fork' start method (Keras models are generally not
    picklable under 'spawn') — confirm on the target platform.
    """

    def __init__(self, build_network, post_process, generator, max_limit):
        """Store the training configuration.

        Args:
            build_network: zero-argument callable returning the model. The
                model must provide ``train_on_batch``, ``save_weights`` and
                ``load_weights`` (presumably a compiled Keras model).
            post_process: callable applied to each sample ``x[i]``. Its result
                is assigned into a float array of ``x``'s shape.
            generator: object supporting ``next()`` (yielding ``(x, y)``
                arrays) and ``size()`` (number of batches per pass).
            max_limit: maximum number of entries allowed in ``batches`` before
                the producer waits. Each batch occupies two entries.
        """
        self.model = build_network
        self.post = post_process
        self.gen = generator
        self.maxl = max_limit

    def work(self, num_itr, save_path='model.hdf5'):
        """Run ``num_itr`` passes over the generator and return the model.

        Each pass starts a fresh worker process. The worker trains its own
        copy of the model, so the parent reloads the weights the worker saved
        to ``save_path``. Without that reload, every pass would restart from
        the initial weights.

        Raises:
            RuntimeError: if a worker process exits unsuccessfully.
        """
        model = self.model()
        parent_conn, child_conn = Pipe()
        for _ in range(num_itr):
            num_batches = self.gen.size()
            worker = Process(target=train_process,
                             args=(model, num_batches, save_path, child_conn))
            worker.start()
            try:
                self._produce_epoch(num_batches, parent_conn, worker)
            finally:
                # Always deliver the sentinel so a live worker leaves its
                # recv() loop; otherwise join() would block forever.
                parent_conn.send((None, None))
                worker.join()
            if worker.exitcode != 0:
                raise RuntimeError(
                    'training process failed with exit code %r'
                    % worker.exitcode)
            # Pull the trained weights back before forking the next worker.
            model.load_weights(save_path)
        return model

    def _produce_epoch(self, num_batches, conn, worker):
        """Preprocess ``num_batches`` batches into ``batches`` and announce them on ``conn``."""
        for idx in range(num_batches):
            # Back-pressure: wait for the worker to consume entries, but stop
            # waiting if it has died (the dict would then never shrink).
            while len(batches) >= self.maxl:
                if not worker.is_alive():
                    raise RuntimeError('training process exited early')
                sleep(0.05)
            x, y = next(self.gen)
            retx = sharedmem.zeros(x.shape)
            rety = sharedmem.zeros(y.shape)
            rety[:] = y[:]
            for sample_idx in range(x.shape[0]):
                retx[sample_idx] = self.post(x[sample_idx])
            # Same key scheme as before: batch n uses keys (2n, 2n + 1).
            x_key, y_key = 2 * idx, 2 * idx + 1
            batches[x_key] = retx
            batches[y_key] = rety
            conn.send((x_key, y_key))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment