@tegg89 · Created April 21, 2017 07:07
video-prediction
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Code for building the input for the prediction model."""
import os
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import flags
from tensorflow.python.platform import gfile
FLAGS = flags.FLAGS
# Original image dimensions
ORIGINAL_WIDTH = 640
ORIGINAL_HEIGHT = 512
COLOR_CHAN = 3
# Default image dimensions.
IMG_WIDTH = 64
IMG_HEIGHT = 64
# Dimension of the state and action.
STATE_DIM = 5
def build_tfrecord_input(training=True, shuffle=False):
"""Create input tfrecord tensors.
Args:
training: training or validation data.
shuffle: whether to shuffle the order of the input files.
Returns:
list of tensors corresponding to images, actions, and states. The images
tensor is 5D, batch x time x height x width x channels. The state and
action tensors are 3D, batch x time x dimension.
Raises:
RuntimeError: if no files found.
"""
filenames = gfile.Glob(os.path.join(FLAGS.data_dir, '*'))
if not filenames:
raise RuntimeError('No data files found.')
index = int(np.floor(FLAGS.train_val_split * len(filenames)))
if training:
filenames = filenames[:index]
else:
filenames = filenames[index:]
filename_queue = tf.train.string_input_producer(filenames, shuffle=shuffle)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
image_seq, state_seq, action_seq = [], [], []
for i in range(FLAGS.sequence_length):
image_name = 'move/' + str(i) + '/image/encoded'
action_name = 'move/' + str(i) + '/commanded_pose/vec_pitch_yaw'
state_name = 'move/' + str(i) + '/endeffector/vec_pitch_yaw'
if FLAGS.use_state:
features = {image_name: tf.FixedLenFeature([1], tf.string),
action_name: tf.FixedLenFeature([STATE_DIM], tf.float32),
state_name: tf.FixedLenFeature([STATE_DIM], tf.float32)}
else:
features = {image_name: tf.FixedLenFeature([1], tf.string)}
features = tf.parse_single_example(serialized_example, features=features)
image_buffer = tf.reshape(features[image_name], shape=[])
image = tf.image.decode_jpeg(image_buffer, channels=COLOR_CHAN)
image.set_shape([ORIGINAL_HEIGHT, ORIGINAL_WIDTH, COLOR_CHAN])
if IMG_HEIGHT != IMG_WIDTH:
raise ValueError('Unequal height and width unsupported')
crop_size = min(ORIGINAL_HEIGHT, ORIGINAL_WIDTH)
image = tf.image.resize_image_with_crop_or_pad(image, crop_size, crop_size)
image = tf.reshape(image, [1, crop_size, crop_size, COLOR_CHAN])
image = tf.image.resize_bicubic(image, [IMG_HEIGHT, IMG_WIDTH])
image = tf.cast(image, tf.float32) / 255.0
image_seq.append(image)
if FLAGS.use_state:
state = tf.reshape(features[state_name], shape=[1, STATE_DIM])
state_seq.append(state)
action = tf.reshape(features[action_name], shape=[1, STATE_DIM])
action_seq.append(action)
image_seq = tf.concat(axis=0, values=image_seq)
if FLAGS.use_state:
state_seq = tf.concat(axis=0, values=state_seq)
action_seq = tf.concat(axis=0, values=action_seq)
[image_batch, action_batch, state_batch] = tf.train.batch(
[image_seq, action_seq, state_seq],
FLAGS.batch_size,
num_threads=1,
capacity=1)
# test_image_batch, test_action_batch, test_state_batch = tf.train.batch(
# [image_seq, action_seq, state_seq],
# FLAGS.batch_size,
# num_threads=1,
# capacity=1)
return image_batch, action_batch, state_batch
else:
image_batch = tf.train.batch(
[image_seq],
FLAGS.batch_size,
num_threads=1,
capacity=1)
# test_image_batch = tf.train.batch(
# [image_seq],
# FLAGS.batch_size,
# num_threads=1,
# capacity=1)
zeros_batch = tf.zeros([FLAGS.batch_size, FLAGS.sequence_length, STATE_DIM])
return image_batch, zeros_batch, zeros_batch
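# --- Usage sketch (illustrative, not part of the original gist) ---
# A minimal example of how this queue-based input pipeline is typically
# consumed under TF 1.x. It assumes the flags referenced above (data_dir,
# sequence_length, use_state, batch_size, train_val_split) are defined by the
# training script, as in prediction_train.py below; the function name is
# hypothetical.
def _example_read_one_batch():
  images, actions, states = build_tfrecord_input(training=True, shuffle=True)
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    # images: [batch, time, 64, 64, 3]; actions, states: [batch, time, STATE_DIM]
    img, act, st = sess.run([images, actions, states])
    print(img.shape, act.shape, st.shape)
    coord.request_stop()
    coord.join(threads)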
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model architecture for predictive model, including CDNA, DNA, and STP."""
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.layers.python import layers as tf_layers
from lstm_ops import basic_conv_lstm_cell
# Amount to use when lower bounding tensors
RELU_SHIFT = 1e-12
# kernel size for DNA and CDNA.
DNA_KERN_SIZE = 5
def construct_model(images,
actions=None,
states=None,
iter_num=-1.0,
k=-1,
use_state=True,
num_masks=10,
stp=False,
cdna=True,
dna=False,
context_frames=-1):
"""Build convolutional lstm video predictor using STP, CDNA, or DNA.
Args:
images: tensor of ground truth image sequences
actions: tensor of action sequences
states: tensor of ground truth state sequences
iter_num: tensor of the current training iteration (for sched. sampling)
k: constant used for scheduled sampling. -1 to feed in own prediction.
use_state: True to include state and action in prediction
num_masks: the number of different pixel motion predictions (and
the number of masks for each of those predictions)
stp: True to use Spatial Transformer Predictor (STP)
cdna: True to use Convolutional Dynamic Neural Advection (CDNA)
dna: True to use Dynamic Neural Advection (DNA)
context_frames: number of ground truth frames to pass in before
feeding in own predictions
Returns:
gen_images: predicted future image frames
gen_states: predicted future states
Raises:
ValueError: if more than one network option specified or more than 1 mask
specified for DNA model.
"""
if stp + cdna + dna != 1:
raise ValueError('More than one, or no network option specified.')
batch_size, img_height, img_width, color_channels = images[0].get_shape()[0:4]
lstm_func = basic_conv_lstm_cell
# Generated robot states and images.
gen_states, gen_images = [], []
current_state = states[0]
# if k == -1:
# feedself = True
# else:
# Scheduled sampling:
# Calculate number of ground-truth frames to pass in.
# num_ground_truth = tf.to_int32(
# tf.round(tf.to_float(batch_size) * (k / (k + tf.exp(iter_num / k)))))
# feedself = False
# LSTM state sizes and states.
lstm_size = np.int32(np.array([32, 32, 64, 64, 128, 64, 32]))
lstm_state1, lstm_state2, lstm_state3, lstm_state4 = None, None, None, None
lstm_state5, lstm_state6, lstm_state7 = None, None, None
for image, action in zip(images[:-1], actions[:-1]):
# Reuse variables after the first timestep.
reuse = bool(gen_images)
# done_warm_start = len(gen_images) > context_frames - 1
with slim.arg_scope(
[lstm_func, slim.layers.conv2d, slim.layers.fully_connected,
tf_layers.layer_norm, slim.layers.conv2d_transpose],
reuse=reuse):
# if feedself and done_warm_start:
# Feed in generated image.
# prev_image = gen_images[-1]
# elif done_warm_start:
# Scheduled sampling
# prev_image = scheduled_sample(image, gen_images[-1], batch_size,
# num_ground_truth)
# else:
# Always feed in ground_truth
prev_image = image
# Predicted state is always fed back in
#state_action = tf.concat(axis=1, values=[action, current_state])
state_action = tf.concat(values=[action, current_state], axis=1)
enc0 = slim.layers.conv2d(
prev_image,
32, [5, 5],
stride=2,
scope='scale1_conv1',
normalizer_fn=tf_layers.layer_norm,
normalizer_params={'scope': 'layer_norm1'})
hidden1, lstm_state1 = lstm_func(enc0, lstm_state1, lstm_size[0], scope='state1')
hidden1 = tf_layers.layer_norm(hidden1, scope='layer_norm2')
hidden2, lstm_state2 = lstm_func(hidden1, lstm_state2, lstm_size[1], scope='state2')
hidden2 = tf_layers.layer_norm(hidden2, scope='layer_norm3')
enc1 = slim.layers.conv2d(hidden2, hidden2.get_shape()[3], [3, 3], stride=2, scope='conv2')
hidden3, lstm_state3 = lstm_func(enc1, lstm_state3, lstm_size[2], scope='state3')
hidden3 = tf_layers.layer_norm(hidden3, scope='layer_norm4')
hidden4, lstm_state4 = lstm_func(hidden3, lstm_state4, lstm_size[3], scope='state4')
hidden4 = tf_layers.layer_norm(hidden4, scope='layer_norm5')
enc2 = slim.layers.conv2d(hidden4, hidden4.get_shape()[3], [3, 3], stride=2, scope='conv3')
# Pass in state and action.
smear = tf.reshape(
state_action,
[int(batch_size), 1, 1, int(state_action.get_shape()[1])])
smear = tf.tile(
smear, [1, int(enc2.get_shape()[1]), int(enc2.get_shape()[2]), 1])
if use_state:
enc2 = tf.concat(axis=3, values=[enc2, smear])
enc3 = slim.layers.conv2d(enc2, hidden4.get_shape()[3], [1, 1], stride=1, scope='conv4')
hidden5, lstm_state5 = lstm_func(enc3, lstm_state5, lstm_size[4], scope='state5') # last 8x8
hidden5 = tf_layers.layer_norm(hidden5, scope='layer_norm6')
enc4 = slim.layers.conv2d_transpose(hidden5, hidden5.get_shape()[3], 3, stride=2, scope='convt1')
hidden6, lstm_state6 = lstm_func(enc4, lstm_state6, lstm_size[5], scope='state6') # 16x16
hidden6 = tf_layers.layer_norm(hidden6, scope='layer_norm7')
# Skip connection.
hidden6 = tf.concat(axis=3, values=[hidden6, enc1]) # both 16x16
enc5 = slim.layers.conv2d_transpose(hidden6, hidden6.get_shape()[3], 3, stride=2, scope='convt2')
hidden7, lstm_state7 = lstm_func(enc5, lstm_state7, lstm_size[6], scope='state7') # 32x32
hidden7 = tf_layers.layer_norm(hidden7, scope='layer_norm8')
# Skip connection.
hidden7 = tf.concat(axis=3, values=[hidden7, enc0]) # both 32x32
enc6 = slim.layers.conv2d_transpose(
hidden7,
hidden7.get_shape()[3], 3, stride=2, scope='convt3',
normalizer_fn=tf_layers.layer_norm,
normalizer_params={'scope': 'layer_norm9'})
if dna:
# Using largest hidden state for predicting untied conv kernels.
enc7 = slim.layers.conv2d_transpose(enc6, DNA_KERN_SIZE**2, 1, stride=1, scope='convt4')
else:
# Using largest hidden state for predicting a new image layer.
enc7 = slim.layers.conv2d_transpose(enc6, color_channels, 1, stride=1, scope='convt4')
# This allows the network to also generate one image from scratch,
# which is useful when regions of the image become unoccluded.
transformed = [tf.nn.sigmoid(enc7)]
if stp:
stp_input0 = tf.reshape(hidden5, [int(batch_size), -1])
stp_input1 = slim.layers.fully_connected(stp_input0, 100, scope='fc_stp')
transformed += stp_transformation(prev_image, stp_input1, num_masks)
elif cdna:
cdna_input = tf.reshape(hidden5, [int(batch_size), -1])
transformed += cdna_transformation(prev_image, cdna_input, num_masks,
int(color_channels))
elif dna:
# Only one mask is supported (more should be unnecessary).
if num_masks != 1:
raise ValueError('Only one mask is supported for DNA model.')
transformed = [dna_transformation(prev_image, enc7)]
masks = slim.layers.conv2d_transpose(enc6, num_masks + 1, 1, stride=1, scope='convt7')
masks = tf.reshape(tf.nn.softmax(tf.reshape(masks, [-1, num_masks + 1])),
[int(batch_size), int(img_height), int(img_width), num_masks + 1])
#mask_list = tf.split(axis=3, num_or_size_splits=num_masks + 1, value=masks)
mask_list = tf.split(masks, num_masks + 1, 3)
output = mask_list[0] * prev_image
for layer, mask in zip(transformed, mask_list[1:]):
output += layer * mask
gen_images.append(output)
current_state = slim.layers.fully_connected(
state_action,
int(current_state.get_shape()[1]),
scope='state_pred',
activation_fn=None)
gen_states.append(current_state)
return gen_images, gen_states, masks, mask_list
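# --- Usage sketch (illustrative, not part of the original gist) ---
# construct_model expects Python lists of per-timestep tensors, which is how
# the Model class in prediction_train.py calls it after tf.split/tf.squeeze.
# The dummy shapes below (batch 16, 64x64x3 images, 5-dim actions/states) and
# the function name are assumptions for illustration only.
def _example_construct_model(sequence_length=10, batch_size=16):
  images = [tf.zeros([batch_size, 64, 64, 3]) for _ in range(sequence_length)]
  actions = [tf.zeros([batch_size, 5]) for _ in range(sequence_length)]
  states = [tf.zeros([batch_size, 5]) for _ in range(sequence_length)]
  gen_images, gen_states, masks, mask_list = construct_model(
      images, actions, states,
      iter_num=-1.0, k=-1, use_state=True, num_masks=10,
      cdna=True, dna=False, stp=False, context_frames=2)
  # One prediction per input transition: len(gen_images) == sequence_length - 1.
  return gen_images, gen_states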
## Utility functions
def stp_transformation(prev_image, stp_input, num_masks):
"""Apply spatial transformer predictor (STP) to previous image.
Args:
prev_image: previous image to be transformed.
stp_input: hidden layer to be used for computing STN parameters.
num_masks: number of masks and hence the number of STP transformations.
Returns:
List of images transformed by the predicted STP parameters.
"""
# Only import spatial transformer if needed.
from spatial_transformer import transformer
identity_params = tf.convert_to_tensor(
np.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0], np.float32))
transformed = []
for i in range(num_masks - 1):
params = slim.layers.fully_connected(
stp_input, 6, scope='stp_params' + str(i),
activation_fn=None) + identity_params
transformed.append(transformer(prev_image, params))
return transformed
def cdna_transformation(prev_image, cdna_input, num_masks, color_channels):
"""Apply convolutional dynamic neural advection to previous image.
Args:
prev_image: previous image to be transformed.
cdna_input: hidden layer to be used for computing CDNA kernels.
num_masks: the number of masks and hence the number of CDNA transformations.
color_channels: the number of color channels in the images.
Returns:
List of images transformed by the predicted CDNA kernels.
"""
batch_size = int(cdna_input.get_shape()[0])
# Predict kernels using linear function of last hidden layer.
cdna_kerns = slim.layers.fully_connected(
cdna_input,
DNA_KERN_SIZE * DNA_KERN_SIZE * num_masks,
scope='cdna_params',
activation_fn=None)
# Reshape and normalize.
cdna_kerns = tf.reshape(
cdna_kerns, [batch_size, DNA_KERN_SIZE, DNA_KERN_SIZE, 1, num_masks])
cdna_kerns = tf.nn.relu(cdna_kerns - RELU_SHIFT) + RELU_SHIFT
norm_factor = tf.reduce_sum(cdna_kerns, [1, 2, 3], keep_dims=True)
cdna_kerns /= norm_factor
cdna_kerns = tf.tile(cdna_kerns, [1, 1, 1, color_channels, 1])
#cdna_kerns = tf.split(axis=0, num_or_size_splits=batch_size, value=cdna_kerns)
cdna_kerns = tf.split(cdna_kerns, batch_size, 0)
#prev_images = tf.split(axis=0, num_or_size_splits=batch_size, value=prev_image)
prev_images = tf.split(prev_image, batch_size, 0)
# Transform image.
transformed = []
for kernel, preimg in zip(cdna_kerns, prev_images):
kernel = tf.squeeze(kernel)
if len(kernel.get_shape()) == 3:
kernel = tf.expand_dims(kernel, -1)
transformed.append(
tf.nn.depthwise_conv2d(preimg, kernel, [1, 1, 1, 1], 'SAME'))
transformed = tf.concat(axis=0, values=transformed)
#transformed = tf.split(axis=3, num_or_size_splits=num_masks, value=transformed)
transformed = tf.split(transformed, num_masks, 3)
return transformed
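# --- Kernel normalization sketch (illustrative, not part of the original gist) ---
# The CDNA kernels are made positive with a shifted ReLU and then normalized so
# each 5x5 kernel sums to one, i.e. every output pixel is a convex combination
# of previous-frame pixels. A NumPy analogue of that step (function name is
# hypothetical):
def _example_normalize_cdna_kernels(raw_kerns):
  # raw_kerns: [batch, DNA_KERN_SIZE, DNA_KERN_SIZE, 1, num_masks], unconstrained.
  kerns = np.maximum(raw_kerns - RELU_SHIFT, 0.0) + RELU_SHIFT
  kerns /= np.sum(kerns, axis=(1, 2, 3), keepdims=True)
  return kerns  # non-negative, each kernel sums to 1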
def dna_transformation(prev_image, dna_input):
"""Apply dynamic neural advection to previous image.
Args:
prev_image: previous image to be transformed.
dna_input: hidden layer to be used for computing DNA transformation.
Returns:
Image transformed by the predicted DNA kernels.
"""
# Construct translated images.
prev_image_pad = tf.pad(prev_image, [[0, 0], [2, 2], [2, 2], [0, 0]])
image_height = int(prev_image.get_shape()[1])
image_width = int(prev_image.get_shape()[2])
inputs = []
for xkern in range(DNA_KERN_SIZE):
for ykern in range(DNA_KERN_SIZE):
inputs.append(
tf.expand_dims(
tf.slice(prev_image_pad, [0, xkern, ykern, 0],
[-1, image_height, image_width, -1]), [3]))
inputs = tf.concat(axis=3, values=inputs)
# Normalize channels to 1.
kernel = tf.nn.relu(dna_input - RELU_SHIFT) + RELU_SHIFT
kernel = tf.expand_dims(
kernel / tf.reduce_sum(
kernel, [3], keep_dims=True), [4])
return tf.reduce_sum(kernel * inputs, [3], keep_dims=False)
def scheduled_sample(ground_truth_x, generated_x, batch_size, num_ground_truth):
"""Sample batch with specified mix of ground truth and generated data points.
Args:
ground_truth_x: tensor of ground-truth data points.
generated_x: tensor of generated data points.
batch_size: batch size
num_ground_truth: number of ground-truth examples to include in batch.
Returns:
New batch with num_ground_truth sampled from ground_truth_x and the rest
from generated_x.
"""
idx = tf.random_shuffle(tf.range(int(batch_size)))
ground_truth_idx = tf.gather(idx, tf.range(num_ground_truth))
generated_idx = tf.gather(idx, tf.range(num_ground_truth, int(batch_size)))
ground_truth_examps = tf.gather(ground_truth_x, ground_truth_idx)
generated_examps = tf.gather(generated_x, generated_idx)
return tf.dynamic_stitch([ground_truth_idx, generated_idx],
[ground_truth_examps, generated_examps])
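# --- Scheduled sampling schedule sketch (illustrative, not part of the original gist) ---
# The commented-out code in construct_model derives num_ground_truth from an
# inverse-sigmoid schedule in the iteration number:
#   num_ground_truth = round(batch_size * k / (k + exp(iter_num / k)))
# A NumPy illustration with the default schedsamp_k of 900 (function name is
# hypothetical):
def _example_num_ground_truth(iter_num, batch_size=16, k=900.0):
  return int(np.round(batch_size * (k / (k + np.exp(iter_num / k)))))
# e.g. _example_num_ground_truth(0) == 16 (all ground truth), and the count
# decays toward 0 as iter_num grows, so the model gradually feeds in its own
# predictions.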
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Code for training the prediction model."""
import numpy as np
import tensorflow as tf
import os
from tensorflow.python.platform import app
from tensorflow.python.platform import flags
from prediction_input import build_tfrecord_input
from prediction_model import construct_model
# How often to record tensorboard summaries.
SUMMARY_INTERVAL = 40
# How often to run a batch through the validation model.
VAL_INTERVAL = 100
# How often to save a model checkpoint
SAVE_INTERVAL = 50
# TEST_INTERVAL = 1000
# tf record data location:
DATA_DIR = '/push/push_train'
# local output directory
OUT_DIR = '/tmp/data'
FLAGS = flags.FLAGS
flags.DEFINE_string('data_dir', DATA_DIR, 'directory containing data.')
flags.DEFINE_string('output_dir', OUT_DIR, 'directory for model checkpoints.')
flags.DEFINE_string('event_log_dir', OUT_DIR, 'directory for writing summary.')
flags.DEFINE_integer('num_iterations', 1000, 'number of training iterations.')
flags.DEFINE_string('pretrained_model', '',
'filepath of a pretrained model to initialize from.')
flags.DEFINE_integer('sequence_length', 10,
'sequence length, including context frames.')
flags.DEFINE_integer('context_frames', -1, '# of frames before predictions.')
flags.DEFINE_integer('use_state', 1,
'Whether or not to give the state+action to the model')
flags.DEFINE_string('model', 'CDNA',
'model architecture to use - CDNA, DNA, or STP')
flags.DEFINE_integer('num_masks', 10,
'number of masks, usually 1 for DNA, 10 for CDNA, STN.')
flags.DEFINE_float('schedsamp_k', 900.0,
'The k hyperparameter for scheduled sampling, '
'-1 for no scheduled sampling.')
flags.DEFINE_float('train_val_split', 0.95,
'The percentage of files to use for the training set,'
' vs. the validation set.')
flags.DEFINE_integer('batch_size', 16, 'batch size for training')
flags.DEFINE_float('learning_rate', 0.001,
'the base learning rate of the generator')
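# Example invocation (illustrative; paths are placeholders):
#   python prediction_train.py --data_dir=/push/push_train --output_dir=/tmp/data \
#     --model=CDNA --num_masks=10 --batch_size=16 --sequence_length=10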
## Helper functions
def peak_signal_to_noise_ratio(true, pred):
"""Image quality metric based on maximal signal power vs. power of the noise.
Args:
true: the ground truth image.
pred: the predicted image.
Returns:
peak signal to noise ratio (PSNR)
"""
return 10.0 * tf.log(1.0 / mean_squared_error(true, pred)) / tf.log(10.0)
def mean_squared_error(true, pred):
"""L2 distance between tensors true and pred.
Args:
true: the ground truth image.
pred: the predicted image.
Returns:
mean squared error between ground truth and predicted image.
"""
return tf.reduce_sum(tf.square(true - pred)) / tf.to_float(tf.size(pred))
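# --- PSNR sketch (illustrative, not part of the original gist) ---
# For images scaled to [0, 1], PSNR = 10 * log10(1 / MSE), so an MSE of 0.01
# corresponds to 20 dB and an MSE of 0.001 to 30 dB. A quick check (function
# name is hypothetical):
def _example_psnr_value():
  true = tf.zeros([1, 64, 64, 3])
  pred = tf.fill([1, 64, 64, 3], 0.1)  # constant error of 0.1 -> MSE = 0.01
  with tf.Session() as sess:
    return sess.run(peak_signal_to_noise_ratio(true, pred))  # ~20.0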
class Model(object):
def __init__(self,
images=None,
actions=None,
states=None,
sequence_length=None,
reuse_scope=None,
prefix=None):
if sequence_length is None:
sequence_length = FLAGS.sequence_length
#self.prefix = prefix = tf.placeholder(tf.string, [])
if prefix is None:
prefix = tf.placeholder(tf.string, [])
self.prefix = prefix
self.iter_num = tf.placeholder(tf.float32, [])
summaries = []
# Split into timesteps.
actions = tf.split(axis=1, num_or_size_splits=int(actions.get_shape()[1]), value=actions)
actions = [tf.squeeze(act) for act in actions]
states = tf.split(axis=1, num_or_size_splits=int(states.get_shape()[1]), value=states)
states = [tf.squeeze(st) for st in states]
images = tf.split(axis=1, num_or_size_splits=int(images.get_shape()[1]), value=images)
images = [tf.squeeze(img) for img in images]
if reuse_scope is None:
gen_images, gen_states, gen_mask, gen_mask_lists = construct_model(
images,
actions,
states,
iter_num=self.iter_num,
k=FLAGS.schedsamp_k,
use_state=FLAGS.use_state,
num_masks=FLAGS.num_masks,
cdna=FLAGS.model == 'CDNA',
dna=FLAGS.model == 'DNA',
stp=FLAGS.model == 'STP',
context_frames=FLAGS.context_frames)
else: # If it's a validation or test model.
with tf.variable_scope(reuse_scope, reuse=True):
gen_images, gen_states, gen_mask, gen_mask_lists = construct_model(
images,
actions,
states,
iter_num=self.iter_num,
k=FLAGS.schedsamp_k,
use_state=FLAGS.use_state,
num_masks=FLAGS.num_masks,
cdna=FLAGS.model == 'CDNA',
dna=FLAGS.model == 'DNA',
stp=FLAGS.model == 'STP',
context_frames=FLAGS.context_frames)
self.gen_mask = gen_mask
self.gen_mask_lists = gen_mask_lists
self.gen_images = gen_images
# L2 loss, PSNR for eval.
# loss, psnr_all = 0.0, 0.0
# for i, x, gx in zip(
# range(len(gen_images)), images[FLAGS.context_frames:], gen_images[FLAGS.context_frames - 1:]):
# recon_cost = mean_squared_error(x, gx)
# psnr_i = peak_signal_to_noise_ratio(x, gx)
# psnr_all += psnr_i
# summaries.append(tf.summary.scalar(prefix + '_recon_cost' + str(i), recon_cost))
# summaries.append(tf.summary.scalar(prefix + '_psnr' + str(i), psnr_i))
# loss += recon_cost
# for i, state, gen_state in zip(
# range(len(gen_states)), states[FLAGS.context_frames:], gen_states[FLAGS.context_frames - 1:]):
# state_cost = mean_squared_error(state, gen_state) * 1e-4
# summaries.append(tf.summary.scalar(prefix + '_state_cost' + str(i), state_cost))
# loss += state_cost
# summaries.append(tf.summary.scalar(prefix + '_psnr_all', psnr_all))
# self.psnr_all = psnr_all
# self.loss = loss = loss / np.float32(len(images) - FLAGS.context_frames)
# summaries.append(tf.summary.scalar(prefix + '_loss', loss))
# self.lr = tf.placeholder_with_default(FLAGS.learning_rate, ())
# self.train_op = tf.train.AdamOptimizer(self.lr).minimize(loss)
# self.summ_op = tf.summary.merge(summaries)
def save(sess, saver, checkpoint_dir, step):
model_name = "{}.model".format(FLAGS.model)
model_dir = FLAGS.model
checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir)
saver.save(sess, os.path.join(checkpoint_dir, model_name), global_step=step)
def load(sess, saver, checkpoint_dir):
print(" [*] Reading checkpoints...")
model_dir = FLAGS.model
checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
saver.restore(sess, os.path.join(checkpoint_dir, ckpt_name))
return True
else:
return False
import moviepy.editor as mpy
def npy_to_gif(npy, filename):
clip = mpy.ImageSequenceClip(list(npy), fps=10)
clip.write_gif(filename)
def interval_mapping(image, from_min, from_max, to_min, to_max):
from_range = from_max - from_min
to_range = to_max - to_min
scaled = np.array((image - from_min) / float(from_range), dtype=float)
return to_min + (scaled * to_range)
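# --- GIF export sketch (illustrative, not part of the original gist) ---
# interval_mapping linearly rescales an array from one range to another;
# combined with npy_to_gif it turns a predicted [time, 64, 64, 3] sequence in
# [0, 1] into an animated GIF. The output path and function name are
# placeholders.
def _example_export_gif(video_01):
  frames = interval_mapping(video_01, 0.0, 1.0, 0, 255).astype('uint8')
  npy_to_gif(frames, os.path.expanduser('~/example_prediction.gif'))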
def main(unused_argv):
print('Constructing models and inputs.')
with tf.variable_scope('model', reuse=None) as training_scope:
images, actions, states = build_tfrecord_input(training=True, shuffle=False)
model = Model(images, actions, states, FLAGS.sequence_length, prefix='train')
with tf.variable_scope('val_model', reuse=None):
val_images, val_actions, val_states = build_tfrecord_input(training=False, shuffle=False)
val_model = Model(val_images, val_actions, val_states,
FLAGS.sequence_length, training_scope, prefix='val')
print('Constructing saver.')
# Make saver.
saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES), max_to_keep=0)
# Make training session.
sess = tf.InteractiveSession()
summary_writer = tf.summary.FileWriter(FLAGS.event_log_dir, graph=sess.graph, flush_secs=10)
# Initialize variables before restoring so the checkpoint is not overwritten
# by the initializer.
sess.run(tf.global_variables_initializer())
# if FLAGS.pretrained_model:
load(sess, saver, FLAGS.output_dir)
print(' [!] File loaded!')
tf.train.start_queue_runners(sess)
# train_images = sess.run(images)
# train_images = interval_mapping(train_images[0], 0.0, 1.0, 0, 255).astype('uint8')
# train_val_images = sess.run(val_images)
# train_val_images = interval_mapping(train_val_images[0], 0.0, 1.0, 0, 255).astype('uint8')
# npy_to_gif(train_images, '~/Dropbox/train_images.gif')
# npy_to_gif(train_val_images, '~/Dropbox/train_val_images.gif')
tf.logging.info('iteration number, cost')
feed_dict = {model.iter_num: -1}
gen_images = sess.run([model.gen_images], feed_dict)
# print(gen_images[0][0].shape)
sample = interval_mapping(gen_images[0][0], 0.0, 1.0, 0, 255).astype('uint8')
# Run training.
# print('Start training.')
# for itr in range(FLAGS.num_iterations):
# Generate new batch of data.
# feed_dict = {model.iter_num: np.float32(itr), model.lr: FLAGS.learning_rate}
# cost, _, summary_str = sess.run([model.loss, model.train_op, model.summ_op], feed_dict)
# print('iter: ', itr, ', cost: ', cost)
# Print info: iteration #, cost.
# tf.logging.info(str(itr) + ' ' + str(cost))
# gen_images = sess.run([model.gen_test_images], feed_dict)
# print(gen_images[0].shape)
# print(gen_images[0][0].shape)
# sample_videos = sess.run(gen_images[0])
# sample = interval_mapping(gen_images[0][0], 0.0, 1.0, 0, 255).astype('uint8')
# print(sample.shape)
# # gen_test_images = tf.train.batch([sample], FLAGS.batch_size, num_threads=1, capacity=1)
# print(gen_test_images.shape)
npy_to_gif(sample, os.path.expanduser('~/sample.gif'))
# for i in range(FLAGS.batch_size):
# video = gen_test_images[i]
# npy_to_gif(video, '~/train_' + str(i) + '.gif')
# if (itr) % VAL_INTERVAL == 2:
# Run through validation set.
# feed_dict = {val_model.lr: 0.0, val_model.iter_num: np.float32(itr)}
# _, val_summary_str = sess.run([val_model.train_op, val_model.summ_op], feed_dict)
# summary_writer.add_summary(val_summary_str, itr)
# if (itr) % SAVE_INTERVAL == 2:
# print('Saving model.')
# tf.logging.info('Saving model.')
# saver.save(sess, FLAGS.output_dir + '/model' + str(itr))
# save(sess, saver, FLAGS.output_dir, itr)
# if (itr) % SUMMARY_INTERVAL:
# summary_writer.add_summary(summary_str, itr)
# if (itr) % TEST_INTERVAL == 2:
# FLAGS.batch_size = 25
# feed_dict = {model.iter_num: np.float32(itr), model.lr: FLAGS.learning_rate}
# gen_images = sess.run([model.gen_images], feed_dict)
# sample = []
# for i in range(len(gen_images[0][0])):
# sam = interval_mapping(gen_images[0][0][i], 0.0, 1.0, 0, 255).astype('uint8')
# sample.append(sam)
# from moviepy.editor import ImageSequenceClip
# image_clip = ImageSequenceClip(sample, fps=10)
# image_clip.to_gif("image_{}.gif".format(itr), fps=10)
# tf.logging.info('Saving model.')
# saver.save(sess, FLAGS.output_dir + '/model')
# save(sess, saver, FLAGS.output_dir, itr)
# print('Training complete')
#tf.logging.info('Training complete')
#tf.logging.flush()
if __name__ == '__main__':
app.run()