@ssnl
Created March 5, 2021 17:48
diff --git a/train_cifar10.py b/train_cifar10.py
index b52de81..0592f67 100644
--- a/train_cifar10.py
+++ b/train_cifar10.py
@@ -65,13 +65,16 @@ def build_discriminator(image_size, latent_code_length):
     y = Conv2D(1024, (3, 3), padding="same")(y)
     y = LeakyReLU()(y)
     y = Flatten()(y)
-    y = Dense(1,activation="sigmoid")(y)
+    y = Dense(1)(y)
     return Model([x, z], [y])
 def build_train_step(generator, encoder, discriminator):
-    g_optimizer = Adam(lr=0.0001, beta_1=0.0, beta_2=0.9)
-    e_optimizer = Adam(lr=0.0001, beta_1=0.0, beta_2=0.9)
-    d_optimizer = Adam(lr=0.0001, beta_1=0.0, beta_2=0.9)
+    g_optimizer = Adam(lr=0.00005, beta_1=0.0, beta_2=0.9)
+    e_optimizer = Adam(lr=0.00005, beta_1=0.0, beta_2=0.9)
+    d_optimizer = Adam(lr=0.00005, beta_1=0.0, beta_2=0.9)
+
+    fake_label = tf.constant(1, dtype=tf.float32)
+    real_label = tf.constant(0, dtype=tf.float32)
     @tf.function
     def train_step(real_image, real_code):
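Note on the hunk above (editorial, not part of the patch): the discriminator head now returns a raw logit instead of a sigmoid probability, the three Adam learning rates are halved from 1e-4 to 5e-5, and two label constants are introduced with the convention `fake_label = 1`, `real_label = 0`. The sketch below, which assumes TF 2.x eager execution and is not taken from train_cifar10.py, only illustrates why a logit head plus `tf.nn.sigmoid_cross_entropy_with_logits` reproduces the old hand-written `-log(sigmoid(...))` terms without the `1e-8` fudge factor.

```python
import tensorflow as tf

logits = tf.constant([-5.0, 0.0, 5.0])

# Old style: squash with sigmoid, then take the log by hand (needs an epsilon).
manual = -tf.math.log(tf.sigmoid(logits) + 1e-8)

# New style: fused sigmoid + cross-entropy on raw logits, numerically stable.
stable = tf.nn.sigmoid_cross_entropy_with_logits(
    logits=logits, labels=tf.ones_like(logits)
)

print(manual.numpy())  # approx [5.0067, 0.6931, 0.0067]
print(stable.numpy())  # approx [5.0067, 0.6931, 0.0067]
```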
@@ -79,16 +82,39 @@ def build_train_step(generator, encoder, discriminator):
         fake_image = generator(real_code)
         fake_code = encoder(real_image)
+        fake_code = fake_code / tf.reshape(tf.norm(tf.reshape(fake_code, [fake_code.shape[0], -1]), axis=-1), [-1, 1, 1, 1])
+        #print(fake_code)
+        #1/0
         d_inputs = [tf.concat([fake_image, real_image], axis=0),
                     tf.concat([real_code, fake_code], axis=0)]
         d_preds = discriminator(d_inputs)
         pred_g, pred_e = tf.split(d_preds,num_or_size_splits=2, axis=0)
-        d_loss = tf.reduce_mean(-tf.math.log(pred_g + 1e-8)) + \
-                 tf.reduce_mean(-tf.math.log(1 - pred_e + 1e-8))
-        g_loss = tf.reduce_mean(-tf.math.log(1 - pred_g + 1e-8))
-        e_loss = tf.reduce_mean(-tf.math.log(pred_e + 1e-8))
+        d_loss = tf.reduce_mean(
+            tf.nn.sigmoid_cross_entropy_with_logits(
+                logits=pred_g, labels=tf.broadcast_to(fake_label, pred_g.shape),
+            )
+        ) + tf.reduce_mean(
+            tf.nn.sigmoid_cross_entropy_with_logits(
+                logits=pred_e, labels=tf.broadcast_to(real_label, pred_e.shape),
+            )
+        )
+        g_loss = tf.reduce_mean(
+            tf.nn.sigmoid_cross_entropy_with_logits(
+                logits=pred_g, labels=tf.broadcast_to(real_label, pred_g.shape),
+            ),
+        )
+        e_loss = tf.reduce_mean(
+            tf.nn.sigmoid_cross_entropy_with_logits(
+                logits=pred_e, labels=tf.broadcast_to(fake_label, pred_e.shape),
+            ),
+        )
+
+        #d_loss = tf.reduce_mean(-tf.math.log(pred_g + 1e-8)) + \
+        #         tf.reduce_mean(-tf.math.log(1 - pred_e + 1e-8))
+        #g_loss = tf.reduce_mean(-tf.math.log(1 - pred_g + 1e-8))
+        #e_loss = tf.reduce_mean(-tf.math.log(pred_e + 1e-8))
         d_gradients = tf.gradients(d_loss, discriminator.trainable_variables)
         g_gradients = tf.gradients(g_loss, generator.trainable_variables)
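Note on the hunk above (editorial, not part of the patch): the encoder output is projected onto the unit sphere per sample, and all three losses are rewritten as sigmoid cross-entropy on logits. With `fake_label = 1` and `real_label = 0`, the discriminator is trained to score (G(z), z) pairs as 1 and (x, E(x)) pairs as 0, while the generator and encoder are trained against the swapped targets, i.e. the familiar non-saturating adversarial objective with the class names flipped. The sketch below shows the per-sample normalization two equivalent ways; the `epsilon` guard and the use of `tf.math.l2_normalize` are suggestions of this note, not part of the patch, which divides by the raw norm.

```python
import tensorflow as tf

fake_code = tf.random.normal([16, 2, 2, 32])  # (batch,) + latent_code_length

# As in the patch: flatten each sample, take its L2 norm, divide.
norms = tf.norm(tf.reshape(fake_code, [tf.shape(fake_code)[0], -1]), axis=-1)
unit_a = fake_code / tf.reshape(norms, [-1, 1, 1, 1])

# Equivalent one-liner with a built-in division-by-zero guard.
unit_b = tf.math.l2_normalize(fake_code, axis=[1, 2, 3], epsilon=1e-12)

print(float(tf.reduce_max(tf.abs(unit_a - unit_b))))  # ~0.0
```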
@@ -104,7 +130,7 @@ def build_train_step(generator, encoder, discriminator):
 def train():
     check_point = 1000
-    iters = 200 * check_point
+    iters = 1000 * check_point
     image_size = (32,32,3)
     latent_code_length = (2,2,32)
     batch_size = 16
@@ -115,8 +141,10 @@ def train():
     x_train = np.reshape(x_train, (-1, )+image_size)
     x_train = (x_train.astype("float32") / 255) * 2 - 1
-    z_train = np.random.uniform(-1.0, 1.0, (num_of_data, )+latent_code_length).astype("float32")
-    z_test = np.random.uniform(-1.0, 1.0, (100, )+latent_code_length).astype("float32")
+    z_train = np.random.randn(num_of_data, *latent_code_length).astype("float32")
+    z_train = z_train / (np.sum(z_train ** 2, axis=(1, 2, 3), keepdims=True) ** 0.5)
+    z_test = np.random.randn(100, *latent_code_length).astype("float32")
+    z_test = z_test / (np.sum(z_test ** 2, axis=(1, 2, 3), keepdims=True) ** 0.5)
     # ==================== save x images ====================
     image = np.reshape(x_train[:100], (10, 10, 32, 32, 3))
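Note on the hunk above (editorial, not part of the patch): the latent codes are no longer drawn from Uniform(-1, 1) but from a standard Gaussian and then rescaled to unit L2 norm per sample, which places them uniformly on the unit sphere and matches the normalization now applied to E(x) in the train step. A NumPy sketch of the same idea, using `np.linalg.norm` instead of the explicit sum of squares:

```python
import numpy as np

latent_code_length = (2, 2, 32)
z = np.random.randn(100, *latent_code_length).astype("float32")
z /= np.linalg.norm(z.reshape(100, -1), axis=1).reshape(-1, 1, 1, 1)

# Every sample now has unit L2 norm, like the normalized encoder output.
print(np.allclose(np.sum(z ** 2, axis=(1, 2, 3)), 1.0, atol=1e-5))  # True
```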
@@ -125,7 +153,7 @@ def train():
     image = 255 * (image + 1) / 2
     image = np.clip(image, 0, 255)
     image = image.astype("uint8")
-    Image.fromarray(image, "RGB").save("x.png")
+    Image.fromarray(image, "RGB").save("results/x.png")
     # =======================================================
     generator = build_generator(image_size, latent_code_length)
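Note on the hunk above (editorial, not part of the patch): the sample images now go under a `results/` directory. PIL's `Image.save` does not create missing directories, so the script assumes `results/` already exists; creating it once near the top of `train()` would avoid a FileNotFoundError. A one-line sketch, not in the patch:

```python
import os

os.makedirs("results", exist_ok=True)  # safe if the directory already exists
```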
@@ -138,19 +166,21 @@ def train():
         real_code = z_train[np.random.permutation(num_of_data)[:batch_size]]
         d_loss, g_loss, e_loss = train_step(real_images, real_code)
-        print("\r[{}/{}] d_loss: {:.4}, g_loss: {:.4}, e_loss: {:.4}".format(i,iters, d_loss, g_loss, e_loss),end="")
+        print("[{}/{}] d_loss: {:.4}, g_loss: {:.4}, e_loss: {:.4}".format(i,iters, d_loss, g_loss, e_loss))
-        if (i+1)%check_point == 0:
+        if (i+1)%check_point == 0 and False:
             # save G(x) images
-            image = generator.predict(encoder.predict(x_train[:100]))
+            code = np.reshape(encoder.predict(x_train[:100]), (100, -1))
+            code = code / ((code ** 2).sum(axis=-1, keepdims=True) ** 0.5)
+            image = generator.predict(np.reshape(code, (100, *latent_code_length)))
             image = np.reshape(image, (10, 10, 32, 32, 3))
             image = np.transpose(image, (0, 2, 1, 3, 4))
             image = np.reshape(image, (10 * 32, 10 * 32, 3))
             image = 255 * (image + 1) / 2
             image = np.clip(image,0,255)
             image = image.astype("uint8")
-            Image.fromarray(image, "RGB").save("G_E_x-{}.png".format(i//check_point))
+            Image.fromarray(image, "RGB").save(f"results/G_E_x-{(i // check_point):08d}.png")
             # save G(z) images
             image = generator.predict(z_test)
@@ -160,7 +190,7 @@ def train():
             image = 255 * (image + 1) / 2
             image = np.clip(image,0,255)
             image = image.astype("uint8")
-            Image.fromarray(image, "RGB").save("G_z-{}.png".format(i//check_point))
+            Image.fromarray(image, "RGB").save(f"results/G_z-{(i // check_point):08d}.png")
 if __name__ == "__main__":
-    train()
\ No newline at end of file
+    train()
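Note on the last two hunks (editorial, not part of the patch): progress is now printed one line per step instead of being overwritten with a carriage return, the periodic image dumps are hard-disabled by the `and False` in the checkpoint test, the G(E(x)) path re-normalizes the predicted codes before decoding, and the checkpoint filenames gain zero-padded indices so they sort correctly. The sketch below only illustrates the filename formatting and a named flag as an alternative to the hard-coded `and False`; `SAVE_SAMPLES` is hypothetical and does not appear in the script.

```python
SAVE_SAMPLES = True  # hypothetical toggle; set False to mimic the patch's `and False`

check_point = 1000
for i in (999, 19999):
    if (i + 1) % check_point == 0 and SAVE_SAMPLES:
        print(f"results/G_z-{(i // check_point):08d}.png")
# -> results/G_z-00000000.png
# -> results/G_z-00000019.png
```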