# Generates predicted video samples from a trained CDNA/DNA/STP video
# prediction model: restores a checkpoint, runs one batch through the
# model, and writes each predicted sequence out as a GIF.
import os

import imageio
import moviepy.editor as mpy
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import app
from tensorflow.python.platform import flags

from prediction_input import build_tfrecord_input
from prediction_model import construct_model
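# Note: build_tfrecord_input and construct_model are assumed here to be the
# data pipeline and model builder from the tensorflow/models video_prediction
# code (Finn et al.), which this script appears to be written against.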
# tf record data location:
DATA_DIR = '/push/push_testnovel'
# local output directory
OUT_DIR = '/tmp/data'
FLAGS = flags.FLAGS
flags.DEFINE_string('data_dir', DATA_DIR, 'directory containing data.')
flags.DEFINE_string('output_dir', OUT_DIR, 'directory for model checkpoints.')
flags.DEFINE_string('event_log_dir', OUT_DIR, 'directory for writing summary.')
flags.DEFINE_integer('sequence_length', 10,
                     'sequence length, including context frames.')
flags.DEFINE_integer('context_frames', -1, '# of frames before predictions.')
flags.DEFINE_integer('use_state', 1,
                     'Whether or not to give the state+action to the model.')
flags.DEFINE_string('model', 'CDNA',
                    'model architecture to use - CDNA, DNA, or STP.')
flags.DEFINE_integer('num_masks', 10,
                     'number of masks, usually 1 for DNA, 10 for CDNA, STP.')
flags.DEFINE_float('schedsamp_k', 900.0,
                   'The k hyperparameter for scheduled sampling, '
                   '-1 for no scheduled sampling.')
flags.DEFINE_float('train_val_split', 0.95,
                   'The percentage of files to use for the training set '
                   'vs. the validation set.')
flags.DEFINE_integer('batch_size', 16, 'batch size for training.')
flags.DEFINE_float('learning_rate', 0.001,
                   'the base learning rate of the generator.')
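# Example invocation (script name and paths are placeholders, not part of
# the original gist):
#   python generate_samples.py --data_dir=/path/to/tfrecords \
#       --output_dir=/path/to/checkpoints --model=CDNA --num_masks=10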
class Model(object):

  def __init__(self,
               images=None,
               actions=None,
               states=None,
               sequence_length=None,
               reuse_scope=None,
               prefix=None):
    if sequence_length is None:
      sequence_length = FLAGS.sequence_length

    # Split each (batch, T, ...) tensor along the time axis into a
    # length-T list of (batch, ...) tensors, one per timestep.
    actions = tf.split(axis=1, num_or_size_splits=int(actions.get_shape()[1]),
                       value=actions)
    actions = [tf.squeeze(act) for act in actions]
    states = tf.split(axis=1, num_or_size_splits=int(states.get_shape()[1]),
                      value=states)
    states = [tf.squeeze(st) for st in states]
    images = tf.split(axis=1, num_or_size_splits=int(images.get_shape()[1]),
                      value=images)
    images = [tf.squeeze(img) for img in images]

    if reuse_scope is None:
      gen_images, gen_states, gen_mask, gen_mask_lists = construct_model(
          images,
          actions,
          states,
          iter_num=-1,
          k=FLAGS.schedsamp_k,
          use_state=FLAGS.use_state,
          num_masks=FLAGS.num_masks,
          cdna=FLAGS.model == 'CDNA',
          dna=FLAGS.model == 'DNA',
          stp=FLAGS.model == 'STP',
          context_frames=FLAGS.context_frames)
    else:  # If it's a validation or test model, reuse the training variables.
      with tf.variable_scope(reuse_scope, reuse=True):
        gen_images, gen_states, gen_mask, gen_mask_lists = construct_model(
            images,
            actions,
            states,
            iter_num=-1,
            k=FLAGS.schedsamp_k,
            use_state=FLAGS.use_state,
            num_masks=FLAGS.num_masks,
            cdna=FLAGS.model == 'CDNA',
            dna=FLAGS.model == 'DNA',
            stp=FLAGS.model == 'STP',
            context_frames=FLAGS.context_frames)

    self.gen_mask = gen_mask
    self.gen_mask_lists = gen_mask_lists
    self.gen_images = gen_images
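# Note (not in the original gist): in TF 1.x, tf.unstack is an equivalent,
# more concise way to do the split-and-squeeze above, e.g.:
#   images = tf.unstack(images, axis=1)
# which also turns a (batch, T, H, W, C) tensor into a length-T list of
# (batch, H, W, C) tensors.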
def load(sess, saver, checkpoint_dir):
  """Restores the latest checkpoint from <checkpoint_dir>/<model>, if any."""
  print(" [*] Reading checkpoints...")
  model_dir = FLAGS.model
  checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
  ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
  if ckpt and ckpt.model_checkpoint_path:
    ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
    saver.restore(sess, os.path.join(checkpoint_dir, ckpt_name))
    return True
  else:
    return False
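# Checkpoints are therefore expected under <output_dir>/<model>/, e.g.
# /tmp/data/CDNA/ with the default flags; load() returns False rather than
# raising if nothing is found there.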
def npy_to_gif(npy, filename):
  clip = mpy.ImageSequenceClip(list(npy), fps=10)
  clip.write_gif(filename)
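# Usage sketch (frames_uint8 is a hypothetical name): npy should be
# indexable into (H, W, 3) frames, e.g. a uint8 array of shape (T, H, W, 3):
#   npy_to_gif(frames_uint8, 'sample.gif')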
def interval_mapping(image, from_min, from_max, to_min, to_max):
  # Linearly rescale values from [from_min, from_max] to [to_min, to_max].
  from_range = from_max - from_min
  to_range = to_max - to_min
  scaled = np.array((image - from_min) / float(from_range), dtype=float)
  return to_min + (scaled * to_range)
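# Worked example: the model emits float images in [0, 1], so mapping them to
# [0, 255] and casting gives displayable 8-bit frames:
#   frames_uint8 = interval_mapping(frames, 0.0, 1.0, 0, 255).astype('uint8')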
def main(unused_argv):
  print('Constructing models and inputs.')
  with tf.variable_scope('model', reuse=None) as training_scope:
    images, actions, states = build_tfrecord_input(training=True,
                                                   shuffle=False)
    model = Model(images, actions, states, FLAGS.sequence_length,
                  prefix='train')
  with tf.variable_scope('val_model', reuse=None):
    val_images, val_actions, val_states = build_tfrecord_input(
        training=False, shuffle=False)
    val_model = Model(val_images, val_actions, val_states,
                      FLAGS.sequence_length, training_scope, prefix='val')

  print('Constructing saver.')
  # Make saver.
  saver = tf.train.Saver(
      tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES), max_to_keep=0)

  # Make session.
  sess = tf.InteractiveSession()
  summary_writer = tf.summary.FileWriter(
      FLAGS.event_log_dir, graph=sess.graph, flush_secs=10)

  # Initialize first, then restore: running the initializer after the
  # restore (as the original ordering did) would overwrite the loaded
  # weights with fresh random values.
  sess.run(tf.global_variables_initializer())
  if load(sess, saver, FLAGS.output_dir):
    print(' [!] File loaded!')
  tf.train.start_queue_runners(sess)
  # Run one batch through the model. model.gen_images is a list with one
  # (batch, H, W, C) array per predicted timestep, so stacking the list
  # gives a (timesteps, batch, H, W, C) array.
  gen_images = sess.run([model.gen_images])
  sample = np.asarray(gen_images[0])
  print(sample.shape)
  # Write one GIF per example in the batch; sample[:, i] selects all
  # timesteps for batch element i. Convert the float [0, 1] frames to
  # uint8 before saving.
  for i in range(FLAGS.batch_size):
    frames = interval_mapping(sample[:, i], 0.0, 1.0, 0, 255).astype('uint8')
    imageio.mimsave('sample_' + str(i) + '.gif', frames)
if __name__ == '__main__':
app.run()