Last active: April 21, 2017 10:11
-
-
Save tegg89/b7b1fea865302b3fc9776220d365b1c0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import tensorflow as tf | |
import os | |
from tensorflow.python.platform import app | |
from tensorflow.python.platform import flags | |
from prediction_input import build_tfrecord_input | |
from prediction_model import construct_model | |
# Location of the TFRecord input data (default for --data_dir).
DATA_DIR = '/push/push_testnovel'
# Local output directory; default for both --output_dir (checkpoints) and
# --event_log_dir (summaries).
OUT_DIR = '/tmp/data'

FLAGS = flags.FLAGS

flags.DEFINE_string('data_dir', DATA_DIR, 'directory containing data.')
flags.DEFINE_string('output_dir', OUT_DIR, 'directory for model checkpoints.')
flags.DEFINE_string('event_log_dir', OUT_DIR, 'directory for writing summary.')
flags.DEFINE_integer('sequence_length', 10,
                     'sequence length, including context frames.')
# NOTE(review): a negative default is unusual for a frame count; confirm that
# construct_model (prediction_model) handles context_frames=-1 as intended.
flags.DEFINE_integer('context_frames', -1, '# of frames before predictions.')
# Integer used as a boolean switch (0 = off, non-zero = on).
flags.DEFINE_integer('use_state', 1,
                     'Whether or not to give the state+action to the model')
flags.DEFINE_string('model', 'CDNA',
                    'model architecture to use - CDNA, DNA, or STP')
# Help text now names STP, matching the --model choices above.
flags.DEFINE_integer('num_masks', 10,
                     'number of masks, usually 1 for DNA, 10 for CDNA, STP.')
# The adjacent literals are concatenated, so the first one needs a trailing
# space (previously rendered as "sampling,-1").
flags.DEFINE_float('schedsamp_k', 900.0,
                   'The k hyperparameter for scheduled sampling, '
                   '-1 for no scheduled sampling.')
flags.DEFINE_float('train_val_split', 0.95,
                   'The fraction of files to use for the training set,'
                   ' vs. the validation set.')
flags.DEFINE_integer('batch_size', 16, 'batch size for training')
flags.DEFINE_float('learning_rate', 0.001,
                   'the base learning rate of the generator')
class Model(object):
    """Builds the prediction network for one batch of (images, actions, states).

    The sequence tensors are split along axis 1 into per-timestep tensors and
    passed to ``construct_model``. When ``reuse_scope`` is given, the network
    is built inside that variable scope with ``reuse=True``, so it shares the
    variables of a model built earlier (used for the validation/test model).

    Attributes:
      gen_images: generated images returned by ``construct_model``.
      gen_states: generated states returned by ``construct_model``.
      gen_mask: mask output returned by ``construct_model``.
      gen_mask_lists: per-step mask lists returned by ``construct_model``.
      sequence_length: the requested sequence length (FLAGS default if None).
      prefix: the caller-supplied name prefix (not used when building).
    """

    def __init__(self,
                 images=None,
                 actions=None,
                 states=None,
                 sequence_length=None,
                 reuse_scope=None,
                 prefix=None):
        """Split the inputs per timestep and construct the model.

        Args:
          images: image-sequence tensor whose axis 1 is time. Presumably
            [batch, time, height, width, channels]; confirm against
            build_tfrecord_input. Axis 1 must have a static size.
          actions: action-sequence tensor whose axis 1 is time.
          states: state-sequence tensor whose axis 1 is time.
          sequence_length: kept for API compatibility. Defaults to
            FLAGS.sequence_length and is only stored on the instance.
          reuse_scope: variable scope to reuse, or None to create new variables.
          prefix: kept for API compatibility; only stored on the instance.

        Raises:
          ValueError: if images, actions or states is None.
        """
        if images is None or actions is None or states is None:
            raise ValueError('Model requires images, actions and states tensors.')
        if sequence_length is None:
            sequence_length = FLAGS.sequence_length
        self.sequence_length = sequence_length
        self.prefix = prefix

        # Split into timesteps. tf.unstack(x, axis=1) removes ONLY axis 1. The
        # previous split + bare tf.squeeze() also dropped every other size-1
        # dimension, e.g. the batch axis when batch_size == 1.
        actions = tf.unstack(actions, axis=1)
        states = tf.unstack(states, axis=1)
        images = tf.unstack(images, axis=1)

        def _build():
            # Single definition of the construct_model call, shared by both the
            # fresh-variable and the reuse branch below.
            return construct_model(
                images,
                actions,
                states,
                iter_num=-1,
                k=FLAGS.schedsamp_k,
                use_state=FLAGS.use_state,
                num_masks=FLAGS.num_masks,
                cdna=FLAGS.model == 'CDNA',
                dna=FLAGS.model == 'DNA',
                stp=FLAGS.model == 'STP',
                context_frames=FLAGS.context_frames)

        if reuse_scope is None:
            outputs = _build()
        else:  # Validation or test model: share the existing variables.
            with tf.variable_scope(reuse_scope, reuse=True):
                outputs = _build()

        gen_images, gen_states, gen_mask, gen_mask_lists = outputs
        self.gen_mask = gen_mask
        self.gen_mask_lists = gen_mask_lists
        self.gen_images = gen_images
        self.gen_states = gen_states
def load(sess, saver, checkpoint_dir):
    """Restore the checkpoint recorded for the configured model, if there is one.

    The checkpoint state is read from ``<checkpoint_dir>/<FLAGS.model>``.

    Args:
      sess: session whose variables are restored.
      saver: the tf.train.Saver used to perform the restore.
      checkpoint_dir: parent directory holding one subdirectory per model.

    Returns:
      True if a checkpoint was restored into ``sess``, False if none was found.
    """
    print(" [*] Reading checkpoints...")
    model_ckpt_dir = os.path.join(checkpoint_dir, FLAGS.model)
    state = tf.train.get_checkpoint_state(model_ckpt_dir)
    if not (state and state.model_checkpoint_path):
        return False

    # Only the file name from the recorded path is kept and re-rooted under
    # model_ckpt_dir (presumably so a relocated checkpoint directory still loads).
    latest_name = os.path.basename(state.model_checkpoint_path)
    saver.restore(sess, os.path.join(model_ckpt_dir, latest_name))
    return True
def npy_to_gif(npy, filename, fps=10):
    """Write a sequence of frames to an animated GIF using moviepy.

    Args:
      npy: iterable of frames (e.g. an array iterated over its first axis),
        in a form accepted by moviepy's ImageSequenceClip.
      filename: path of the GIF to write.
      fps: playback frame rate of the GIF (default 10, as before).
    """
    # Imported lazily: moviepy is only needed by this helper, and a module-level
    # import made the whole script fail to start when moviepy was not installed.
    import moviepy.editor as mpy

    clip = mpy.ImageSequenceClip(list(npy), fps=fps)
    clip.write_gif(filename)
def interval_mapping(image, from_min, from_max, to_min, to_max):
    """Linearly map values from [from_min, from_max] onto [to_min, to_max].

    Values outside the source interval are extrapolated, not clipped.

    Args:
      image: a number, a sequence of numbers or a numpy array.
      from_min: lower bound of the source interval.
      from_max: upper bound of the source interval; must differ from from_min.
      to_min: lower bound of the target interval.
      to_max: upper bound of the target interval.

    Returns:
      Float values (numpy) with the broadcast shape of ``image``.

    Raises:
      ValueError: if from_min == from_max, which makes the mapping undefined.
    """
    from_range = from_max - from_min
    if np.any(np.asarray(from_range) == 0):
        raise ValueError('interval_mapping: from_min and from_max must differ.')
    to_range = to_max - to_min
    # Convert to float BEFORE subtracting. With unsigned-integer input (e.g.
    # uint8 frames), `image - from_min` would otherwise wrap around for values
    # below from_min. asarray also accepts plain Python sequences.
    scaled = (np.asarray(image, dtype=float) - from_min) / float(from_range)
    return to_min + scaled * to_range
def main(unused_argv):
    """Restore the trained model and write its predicted frames out as GIFs.

    The function runs these steps in order:
      * Build the model on the training input pipeline, plus a validation model
        that reuses its variables.
      * Initialize all global variables.
      * Restore a checkpoint from ``FLAGS.output_dir`` via ``load`` on top of
        the initialized variables.
      * Run one batch through the model.
      * Write ``sample_<i>.gif`` to the current directory, one file per index
        along axis 1 of the stacked output.

    Args:
      unused_argv: leftover command-line arguments from ``app.run``; ignored.
    """
    import imageio  # Only needed for the GIF export at the end.

    print('Constructing models and inputs.')
    with tf.variable_scope('model', reuse=None) as training_scope:
        images, actions, states = build_tfrecord_input(training=True, shuffle=False)
        model = Model(images, actions, states, FLAGS.sequence_length, prefix='train')
    with tf.variable_scope('val_model', reuse=None):
        val_images, val_actions, val_states = build_tfrecord_input(
            training=False, shuffle=False)
        # Still built (reusing training_scope's variables) so the graph matches
        # the original script; its outputs are not used below.
        Model(val_images, val_actions, val_states,
              FLAGS.sequence_length, training_scope, prefix='val')

    print('Constructing saver.')
    saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES),
                           max_to_keep=0)

    sess = tf.InteractiveSession()
    summary_writer = tf.summary.FileWriter(FLAGS.event_log_dir,
                                           graph=sess.graph, flush_secs=10)

    # Initialize FIRST, then restore. The original ran the initializer after
    # load(), which overwrote every restored weight with a fresh random value,
    # so the GIFs came from an untrained model.
    sess.run(tf.global_variables_initializer())
    if load(sess, saver, FLAGS.output_dir):
        print(' [!] File loaded!')
    else:
        print(' [!] No checkpoint found under %s; using freshly initialized '
              'weights.' % FLAGS.output_dir)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord=coord)
    try:
        # model.gen_images is fetched as a sequence of arrays (the original
        # iterated over it element-wise). Stacking puts that sequence on axis 0.
        # Axis 1 is presumably the batch: the original loop ran over
        # 16 == the default --batch_size. Confirm against construct_model.
        sample = np.asarray(sess.run(model.gen_images))
        print(sample.shape)
        # Take the count from the data instead of hard-coding 16, which broke
        # for any other --batch_size. range() also works on Python 3.
        for i in range(sample.shape[1]):
            imageio.mimsave('sample_' + str(i) + '.gif', sample[:, i])
    finally:
        # Stop the input threads and release the writer and session, even if
        # the export fails.
        coord.request_stop()
        coord.join(threads)
        summary_writer.close()
        sess.close()
if __name__ == '__main__':
    # Hand control to the TF app runner, passing this module's main explicitly;
    # main receives the leftover command-line arguments.
    app.run(main=main)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.