Skip to content

Instantly share code, notes, and snippets.

@daquang
Created August 21, 2018 05:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save daquang/f2ffcc8d95e4c092c58a5d3036144185 to your computer and use it in GitHub Desktop.
Save daquang/f2ffcc8d95e4c092c58a5d3036144185 to your computer and use it in GitHub Desktop.
For reproducing a Keras `fit_generator` bug.
from keras.utils import Sequence
from keras.models import Sequential
from keras.layers import Dense
from keras.utils import to_categorical
import pybedtools as pbt
import numpy as np
import random as rn
import keras
import tensorflow
import os
class DummySequence(Sequence):
    """Keras ``Sequence`` that serves consecutive fixed-size slices of in-memory data.

    ``on_epoch_end`` runs a small pybedtools job and throws the result away.
    NOTE(review): presumably this is the trigger for the ``fit_generator`` bug
    this script reproduces — confirm before "cleaning it up".
    """

    def __init__(self, x_set, y_set, batch_size):
        # Inputs and targets are stored as given; slicing happens per batch.
        self.batch_size = batch_size
        self.x = x_set
        self.y = y_set

    def __len__(self):
        """Return the number of batches; a trailing partial batch counts as one."""
        n_samples = len(self.x)
        return int(np.ceil(n_samples / float(self.batch_size)))

    def __getitem__(self, idx):
        """Return batch ``idx`` as a ``(inputs, targets)`` pair of NumPy arrays."""
        start = idx * self.batch_size
        window = slice(start, start + self.batch_size)
        return np.array(self.x[window]), np.array(self.y[window])

    def on_epoch_end(self):
        """Concatenate the pybedtools example files ``a.bed`` and ``b.bed``.

        The resulting ``BedTool`` is never used; only the side effect of
        running pybedtools at the end of each epoch matters here.
        """
        first = pbt.example_bedtool('a.bed')
        second = pbt.example_bedtool('b.bed')
        _combined = pbt.BedTool(first.cat(second))
if __name__ == '__main__':
    # Script entry point: build random training/validation sequences, pickle
    # the training one, then train a tiny dense classifier with multiprocessing
    # fit_generator workers plus TensorBoard and checkpoint callbacks.

    # Seed NumPy, Python's ``random`` and TensorFlow so runs are repeatable.
    seed = 1337
    np.random.seed(seed)
    rn.seed(seed)
    tensorflow.set_random_seed(seed)

    # 100 samples of 3 random features each, with one-hot encoded binary labels.
    x = np.random.random((100, 3))
    y = to_categorical(np.random.random(100) > .5).astype(int)
    x2 = np.random.random((100, 3))
    y2 = to_categorical(np.random.random(100) > .5).astype(int)
    seq = DummySequence(x, y, 10)
    seq2 = DummySequence(x2, y2, 10)

    import pickle
    # NOTE(review): presumably this checks that the sequence pickles cleanly,
    # since use_multiprocessing=True below ships it to worker processes.
    # The ``with`` block closes (and flushes) the file before any workers start,
    # so no half-written buffer is left open across the worker fork.
    with open('temp.pkl', 'wb') as f:
        pickle.dump(seq, f)

    model = Sequential()
    model.add(Dense(32, input_dim=3))
    model.add(Dense(2, activation='softmax'))
    model.compile(optimizer='rmsprop',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    output_dir = 'temp'
    callbacks = [
        keras.callbacks.TensorBoard(log_dir=output_dir,
                                    histogram_freq=0, write_graph=True, write_images=False),
        # Let os.path.join supply the separator instead of appending '/' by hand.
        keras.callbacks.ModelCheckpoint(os.path.join(output_dir, "weights.h5"),
                                        verbose=0, save_weights_only=False, monitor='val_loss'),
    ]
    model.fit_generator(generator=seq, validation_data=seq2, workers=3,
                        use_multiprocessing=True, epochs=5, callbacks=callbacks)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment