Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save hiromu/cce292b0dd17331f475e5c0b72ecc6e6 to your computer and use it in GitHub Desktop.
From 0bf4196dfa0caba0cee28c7edc751c145c834bc1 Mon Sep 17 00:00:00 2001
From: Hiromu Yakura <hiromu1996@gmail.com>
Date: Tue, 24 May 2016 11:33:33 +0900
Subject: [PATCH] add an option to specify the log directory
---
main.py | 5 +++--
model.py | 5 +++--
2 files changed, 6 insertions(+), 4 deletions(-)
diff --git a/main.py b/main.py
index ac7aaab..0dabc38 100644
--- a/main.py
+++ b/main.py
@@ -17,6 +17,7 @@ flags.DEFINE_integer("image_size", 108, "The size of image to use (will be cente
flags.DEFINE_string("dataset", "celebA", "The name of dataset [celebA, mnist, lsun]")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]")
+flags.DEFINE_string("log_dir", "logs", "Directory name to save the log files for tensorboard [logs]")
flags.DEFINE_boolean("is_train", False, "True for training, False for testing [False]")
flags.DEFINE_boolean("is_crop", False, "True for training, False for testing [False]")
flags.DEFINE_boolean("visualize", False, "True for visualizing, False for nothing [False]")
@@ -33,10 +34,10 @@ def main(_):
with tf.Session() as sess:
if FLAGS.dataset == 'mnist':
dcgan = DCGAN(sess, image_size=FLAGS.image_size, batch_size=FLAGS.batch_size, y_dim=10,
- dataset_name=FLAGS.dataset, is_crop=FLAGS.is_crop, checkpoint_dir=FLAGS.checkpoint_dir)
+ dataset_name=FLAGS.dataset, is_crop=FLAGS.is_crop, checkpoint_dir=FLAGS.checkpoint_dir, log_dir=FLAGS.log_dir)
else:
dcgan = DCGAN(sess, image_size=FLAGS.image_size, batch_size=FLAGS.batch_size,
- dataset_name=FLAGS.dataset, is_crop=FLAGS.is_crop, checkpoint_dir=FLAGS.checkpoint_dir)
+ dataset_name=FLAGS.dataset, is_crop=FLAGS.is_crop, checkpoint_dir=FLAGS.checkpoint_dir, log_dir=FLAGS.log_dir)
if FLAGS.is_train:
dcgan.train(FLAGS)
diff --git a/model.py b/model.py
index 4dd8e1c..06d32a5 100644
--- a/model.py
+++ b/model.py
@@ -11,7 +11,7 @@ class DCGAN(object):
batch_size=64, sample_size = 64, image_shape=[64, 64, 3],
y_dim=None, z_dim=100, gf_dim=64, df_dim=64,
gfc_dim=1024, dfc_dim=1024, c_dim=3, dataset_name='default',
- checkpoint_dir=None):
+ checkpoint_dir=None, log_dir=None):
"""
Args:
@@ -59,6 +59,7 @@ class DCGAN(object):
self.dataset_name = dataset_name
self.checkpoint_dir = checkpoint_dir
+ self.log_dir = log_dir
self.build_model()
def build_model(self):
@@ -118,7 +119,7 @@ class DCGAN(object):
self.g_sum = tf.merge_summary([self.z_sum, self.d__sum,
self.G_sum, self.d_loss_fake_sum, self.g_loss_sum])
self.d_sum = tf.merge_summary([self.z_sum, self.d_sum, self.d_loss_real_sum, self.d_loss_sum])
- self.writer = tf.train.SummaryWriter("./logs", self.sess.graph_def)
+ self.writer = tf.train.SummaryWriter(self.log_dir, self.sess.graph_def)
sample_z = np.random.uniform(-1, 1, size=(self.sample_size , self.z_dim))
sample_files = data[0:self.sample_size]
--
2.5.4 (Apple Git-61)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment