@bicepjai
Created November 15, 2017 18:07
Checking a DCGAN implementation and training loop (Keras with a TensorFlow 1.x backend).
import numpy as np
import tensorflow as tf

from keras import backend as K
from keras.models import Model, Sequential
from keras.layers import (Input, Dense, Flatten, Dropout, Activation,
                          Conv2D, UpSampling2D, BatchNormalization, LeakyReLU)
from keras.optimizers import Adam
from keras.utils import generic_utils


def train_dcgan(generator_func, discriminator_func, data_in, model_weights_name_prefix,
                load_model=False, epochs=10, batch_size=32, lr_opt=1e-3, lr_d_opt=1e-4,
                plot_epoch=None):
    # get image shape
    image_shape = data_in[0, :, :, :].shape
    print("image shape:", image_shape)
    K.clear_session()

    # use the generator and discriminator functions to build the stacked GAN model
    gan_input = Input(shape=image_shape)
    generator_model = generator_func(image_shape)
    discriminator_model = discriminator_func(image_shape)
    discriminator_input = generator_model(gan_input)
    gan_output = discriminator_model(discriminator_input)
    gan_model = Model(inputs=gan_input, outputs=gan_output)

    # used for batching
    n = data_in.shape[0]

    # model compilation: Keras captures the trainable flag at compile time, so
    # compile the discriminator on its own first, then freeze it before compiling
    # the combined model so that gan_model only updates the generator
    optimizer = Adam(lr=lr_opt)      # for the generator and the combined GAN
    d_optimizer = Adam(lr=lr_d_opt)  # just for the discriminator
    generator_model.compile(loss='binary_crossentropy', optimizer=optimizer)
    discriminator_model.compile(loss='binary_crossentropy', optimizer=d_optimizer)
    discriminator_model.trainable = False
    gan_model.compile(loss='binary_crossentropy', optimizer=optimizer)

    # number of batches per epoch
    batch_count = data_in.shape[0] // batch_size
    noise_batch_shape = tuple([batch_size] + list(data_in.shape[1:]))

    # as suggested in the paper, section 4 on page 3: "No pre-processing was applied
    # to training images besides scaling to the range of the tanh activation
    # function [-1, 1]."
    # https://stackoverflow.com/questions/5294955/how-to-scale-down-a-range-of-numbers-with-a-known-min-and-max-value
    min_intensity, max_intensity = np.min(data_in), np.max(data_in)
    data_in = -1 + 2.0 * (data_in - min_intensity) / (max_intensity - min_intensity)
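    # illustrative check of the rescaling (assumes 8-bit inputs, so min=0 and
    # max=255): 0 maps to -1.0, 127.5 maps to 0.0, and 255 maps to 1.0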
    # track losses
    losses = {'d': [], 'g': []}

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # option to load existing weights
        if load_model:
            try:
                generator_model.load_weights(model_weights_name_prefix + "_generator.hdf5")
                discriminator_model.load_weights(model_weights_name_prefix + "_discriminator.hdf5")
            except IOError:
                print("No weights found with name prefix " + model_weights_name_prefix)

        for epoch in range(epochs):
            d_losses = []
            g_losses = []
            progbar = generic_utils.Progbar(n)

            for batch_index in range(batch_count):
                # noise input drawn from the noise prior for the generator
                noise_prior_input = np.random.uniform(-1, 1, size=noise_batch_shape)

                # sample batch_size random images from data_in;
                # these are the real images fed to the discriminator
                real_image_batch = data_in[np.random.randint(0, data_in.shape[0], size=batch_size)]

                # fake images predicted by the generator
                generator_predictions = generator_model.predict(noise_prior_input, batch_size=batch_size)

                # the discriminator takes the generated fake images and the real images
                X = np.concatenate([generator_predictions, real_image_batch])

                # labels (in the same order as X) for the discriminator,
                # with one-sided label smoothing on the real class
                fake_y = [0] * batch_size
                real_y = list(np.ones(batch_size) - np.random.random_sample(batch_size) * 0.2)
                y_discriminator = fake_y + real_y
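                # e.g. with batch_size=3 the label vector might look like
                # [0, 0, 0, 0.93, 0.81, 0.88]: fakes first, then smoothed
                # reals drawn from (0.8, 1.0] (the values here are made up)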
                # train the discriminator to distinguish real from fake images;
                # it was compiled with trainable=True above, so this updates its weights
                discriminator_model.trainable = True
                d_loss = discriminator_model.train_on_batch(X, y_discriminator)
                d_losses += [d_loss]

                # train the generator-discriminator stack: the generator tries to
                # fool the discriminator by generating real-looking images, so it
                # is trained on input noise against the "real" output class
                noise_prior_input = np.random.uniform(-1, 1, size=noise_batch_shape)
                y_generator = [1] * batch_size
                # the discriminator is already frozen inside gan_model (the flag was
                # captured at compile time), so only the generator weights are updated
                discriminator_model.trainable = False
                g_loss = gan_model.train_on_batch(noise_prior_input, y_generator)
                g_losses += [g_loss]

                epoch_header = "Epoch:%d d_loss" % (epoch)
                progbar.add(batch_size, values=[(epoch_header, np.mean(d_losses)), ("g_loss", np.mean(g_losses))])

                # save weights every batch
                generator_model.save_weights(model_weights_name_prefix + "_generator.hdf5")
                discriminator_model.save_weights(model_weights_name_prefix + "_discriminator.hdf5")

            # update losses
            losses['d'] += d_losses
            losses['g'] += g_losses

            if plot_epoch is not None and epoch % plot_epoch == 0:
                # plot_output_nosess is assumed to be defined elsewhere
                plot_output_nosess(generator_model, model_weights_name_prefix + "_generator.hdf5")

    return losses

def get_generator_model(input_shape):
    w, h, c = input_shape
    model = Sequential()
    model.add(Conv2D(256, (4, 4), input_shape=input_shape, padding='same'))
    model.add(Activation('relu'))
    model.add(BatchNormalization())
    model.add(UpSampling2D(size=(2, 2)))
    model.add(Conv2D(128, (8, 8), strides=2, padding='same'))
    model.add(Activation('relu'))
    model.add(BatchNormalization())
    model.add(UpSampling2D(size=(2, 2)))
    model.add(Conv2D(64, (8, 8), strides=2, padding='same'))
    model.add(Activation('relu'))
    model.add(BatchNormalization())
    model.add(UpSampling2D(size=(2, 2)))
    model.add(Conv2D(c, (4, 4), strides=2, padding='same'))
    model.add(Activation('tanh'))
    return model
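
For reference, each UpSampling2D doubles the spatial dimensions and the following stride-2 convolution halves them again, so the generator maps an image-shaped noise tensor to an output of the same shape. A quick sanity check (the 32x32x3 shape is just an example value):

m = get_generator_model((32, 32, 3))
print(m.output_shape)  # expected: (None, 32, 32, 3)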

def get_discriminator_model(input_shape):
    model = Sequential()
    model.add(Conv2D(128, (4, 4), strides=(2, 2), padding='same', input_shape=input_shape))
    model.add(LeakyReLU(0.2))
    model.add(Conv2D(64, (8, 8), strides=(2, 2), padding='same'))
    model.add(LeakyReLU(0.2))
    model.add(Dropout(0.2))
    model.add(Conv2D(128, (4, 4), strides=(2, 2), padding='same'))
    model.add(LeakyReLU(0.2))
    model.add(Dropout(0.2))
    model.add(Flatten())
    model.add(Dense(64))
    model.add(LeakyReLU(0.2))
    model.add(Dense(1))
    model.add(Activation('sigmoid'))
    return model
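
A minimal usage sketch, assuming CIFAR-10 as the dataset; the weight-file prefix "dcgan_cifar10" is arbitrary, and plot_epoch is left at its default of None since plot_output_nosess is not defined in this gist:

from keras.datasets import cifar10

# (50000, 32, 32, 3) uint8 images; train_dcgan rescales them to [-1, 1] internally
(x_train, _), (_, _) = cifar10.load_data()

losses = train_dcgan(get_generator_model, get_discriminator_model,
                     x_train.astype('float32'), "dcgan_cifar10",
                     epochs=10, batch_size=32)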