#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import cv2
import numpy as np
import os
import tensorflow.contrib.slim as slim
import tensorflow as tf
import json
import math
import time

BASE_PATH = "/tmp/heatmap/"

trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev)
def euclidean_loss(y_true, y_pred):
    loss = slim.losses.mean_squared_error(y_true, y_pred)
    return loss
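
# Note: slim.losses.mean_squared_error registers the value in the standard
# LOSSES collection; that is how tf.losses.get_total_loss() inside heatmap()
# below can pick up both stage losses without summing them explicitly.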
def heatmap(data, y_labels):
    """Two-stage heatmap regression net; both stages are trained against the
    same target heatmaps with a Euclidean (MSE) loss."""
    data = tf.image.resize_images(data, (248, 248))
    with tf.variable_scope('heatmap', 'heatmap', [data, y_labels]):
        with slim.arg_scope([slim.conv2d],
                            weights_initializer=tf.truncated_normal_initializer(stddev=0.01)):
            # Stage 1: the spatial net regresses coarse heatmaps.
            with tf.variable_scope("spatial_net"):
                net = slim.conv2d(data, 128, [5, 5], 1, padding='SAME',
                                  activation_fn=tf.nn.relu,
                                  scope='conv1')
                net = slim.max_pool2d(net, [2, 2], 2, scope='pool1', padding='SAME')
                net = slim.conv2d(net, 128, [5, 5], 1, padding='SAME',
                                  activation_fn=tf.nn.relu,
                                  scope='conv2')
                net = slim.max_pool2d(net, [2, 2], 2, scope='pool2', padding='SAME')
                conv3 = slim.conv2d(net, 128, [5, 5], 1, padding='SAME',
                                    activation_fn=tf.nn.relu,
                                    scope='conv3')
                net = slim.conv2d(conv3, 256, [5, 5], 1, padding='SAME',
                                  activation_fn=tf.nn.relu,
                                  scope='conv4')
                net = slim.conv2d(net, 512, [9, 9], 1, padding='SAME',
                                  activation_fn=tf.nn.relu,
                                  scope='conv5')
                net = slim.conv2d(net, 256, [1, 1], 1, padding='SAME',
                                  activation_fn=tf.nn.relu,
                                  scope='conv6')
                conv7 = slim.conv2d(net, 256, [1, 1], 1, padding='SAME',
                                    activation_fn=tf.nn.relu,
                                    scope='conv7')
                spatial_net = slim.conv2d(conv7, 2, [1, 1], 1, padding='SAME',
                                          activation_fn=tf.nn.relu,
                                          scope='conv8')
                loss_spatial_net = euclidean_loss(y_labels, spatial_net)
            # Stage 2: the fusion net refines the prediction from intermediate
            # features of stage 1 (conv3 and conv7).
            with tf.variable_scope("spatial_fusion"):
                concat = tf.concat([conv3, conv7], axis=3)
                net = slim.conv2d(concat, 64, [7, 7], 1, padding='SAME',
                                  activation_fn=tf.nn.relu,
                                  scope='conv1')
                net = slim.conv2d(net, 64, [13, 13], 1, padding='SAME',
                                  activation_fn=tf.nn.relu,
                                  scope='conv2')
                net = slim.conv2d(net, 128, [13, 13], 1, padding='SAME',
                                  activation_fn=tf.nn.relu,
                                  scope='conv3')
                net = slim.conv2d(net, 256, [1, 1], 1, padding='SAME',
                                  activation_fn=tf.nn.relu,
                                  scope='conv4')
                spatial_fusion_net = slim.conv2d(net, 2, [1, 1], 1, padding='SAME',
                                                 activation_fn=tf.nn.relu,
                                                 scope='conv5')
                loss_spatial_fusion = euclidean_loss(y_labels, spatial_fusion_net)
    # total_loss = loss_spatial_net + 3 * loss_spatial_fusion
    total_loss = tf.losses.get_total_loss(add_regularization_losses=False)
    return spatial_net, spatial_fusion_net, loss_spatial_net, loss_spatial_fusion, total_loss
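
# Minimal sanity-check sketch (not part of the training flow): the two
# stride-2 pools reduce the resized 248x248 input to 124 and then 62, so both
# heads should emit 62x62x2 maps, matching the label shape implied by the
# dataset folder name used in __main__ below.
def _check_heatmap_shapes():
    with tf.Graph().as_default():
        x = tf.placeholder(tf.float32, [None, 248, 248, 3])
        y = tf.placeholder(tf.float32, [None, 62, 62, 2])
        spatial, fusion, _, _, _ = heatmap(x, y)
        assert spatial.shape.as_list()[1:] == [62, 62, 2]
        assert fusion.shape.as_list()[1:] == [62, 62, 2]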
def read_and_decode(filename_queue,
                    image_size,
                    image_channel,
                    label_size,
                    label_channel,
                    normalized=False):
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        # Defaults are not specified since both keys are required.
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'label_image_raw': tf.FixedLenFeature([], tf.string),
        })
    if normalized:
        image = tf.decode_raw(features['image_raw'], tf.float32)
        label = tf.decode_raw(features['label_image_raw'], tf.float32)
    else:
        image = tf.decode_raw(features['image_raw'], tf.int64)
        label = tf.decode_raw(features['label_image_raw'], tf.int64)
    image_length = image_size[0] * image_size[1] * image_channel
    label_length = label_size[0] * label_size[1] * label_channel
    image.set_shape([image_length])
    label.set_shape([label_length])
    image = tf.reshape(image, (image_size[0], image_size[1], image_channel))
    label = tf.reshape(label, (label_size[0], label_size[1], label_channel))
    if not normalized:
        # Un-normalized records hold raw pixel values; map them into [-1, 1]
        # so both record formats feed the net the same range.
        image = tf.cast(image, tf.float32) * (1. / 127.5) - 1
    return image, label
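
# Writer-side sketch of the record layout read_and_decode() assumes
# (hypothetical helper; the actual conversion script is not in this gist).
# `writer` is a tf.python_io.TFRecordWriter; `image` and `label` are numpy
# arrays whose dtype matches the `normalized` flag (float32 or int64).
def _write_example_sketch(writer, image, label):
    def _bytes_feature(array):
        return tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[array.tostring()]))
    example = tf.train.Example(features=tf.train.Features(feature={
        'image_raw': _bytes_feature(image),
        'label_image_raw': _bytes_feature(label),
    }))
    writer.write(example.SerializeToString())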
def inputs(filename, batch_size, num_epochs, meta_info):
    if not num_epochs:
        num_epochs = None
    with tf.name_scope('input'):
        filename_queue = tf.train.string_input_producer(
            [filename], num_epochs=num_epochs)
        # Even when reading in multiple threads, share the filename queue.
        image, label = read_and_decode(filename_queue,
                                       image_size=meta_info["image_size"],
                                       image_channel=meta_info["image_channel"],
                                       label_size=meta_info["label_size"],
                                       label_channel=meta_info["label_channel"],
                                       normalized=meta_info["normalization"])
        # Shuffle the examples and collect them into batch_size batches.
        # (Internally uses a RandomShuffleQueue.)
        # We run this in several threads to avoid becoming a bottleneck.
        images, labels = tf.train.shuffle_batch(
            [image, label],
            batch_size=batch_size,
            num_threads=5,
            capacity=1000 + 3 * batch_size,
            min_after_dequeue=1000)
    return images, labels
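
# Usage sketch (hypothetical values; the field names mirror the meta.json
# consumed by train_from_tfrecords below):
#   meta = {"image_size": [248, 248], "image_channel": 3,
#           "label_size": [62, 62], "label_channel": 2,
#           "normalization": True}
#   images, labels = inputs("train.tfrecords", batch_size=8,
#                           num_epochs=1, meta_info=meta)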
def print_summary():
    print()
    print("=" * 40)
    sum_params = 0
    for var in tf.trainable_variables():
        num_param = 1
        for elem in var.shape:
            num_param *= elem.value
        sum_params += num_param
        print(var.name, "\t", var.shape, "\t", num_param)
    print("-" * 40)
    print("trainable variables:", len(tf.trainable_variables()))
    print("trainable parameters:", sum_params)
    print("=" * 40)
    print()
def train_from_tfrecords(dataset_folder):
    meta_path = os.path.join(dataset_folder, "meta.json")
    with open(meta_path, "r") as f:
        meta_info = json.load(f)
    # valid_path = os.path.join(dataset_folder, meta_info["valid_file"])
    train_path = os.path.join(dataset_folder, meta_info["train_file"])
    batch_size = 8
    num_epochs = 10000
    current_rate = 1e-2  # initial learning rate
    loss_queue_size = 50
    min_learning_rate = 1e-7  # lower bound for the decayed learning rate
    num_train_samples = int(meta_info["length_train"])
    with tf.Graph().as_default():
        images, labels = inputs(train_path,
                                batch_size=batch_size,
                                num_epochs=num_epochs,
                                meta_info=meta_info)
        spatial_net, spatial_fusion_net, loss_1, loss_2, total_loss = heatmap(images, labels)
        # The learning rate is fed per step so it can be decayed from Python.
        # It must be a scalar placeholder (shape=[]) both to accept a plain
        # float in feed_dict and to be valid for tf.summary.scalar below.
        learning_rate = tf.placeholder(tf.float64, shape=[], name="lr")
        tf.summary.scalar("loss_spatial_net", loss_1)
        tf.summary.scalar("loss_spatial_fusion", loss_2)
        tf.summary.scalar("lr", learning_rate)
        merged = tf.summary.merge_all()
        print_summary()
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
        training_operation = slim.learning.create_train_op(total_loss, optimizer)
        # training_operation = optimizer.minimize(total_loss)
        saver = tf.train.Saver(tf.trainable_variables())
        save_path = BASE_PATH + 'alexnet.ckpt'
        with tf.Session() as sess:
            train_writer = tf.summary.FileWriter(BASE_PATH + "log/", sess.graph)
            init_op = tf.group(tf.global_variables_initializer(),
                               tf.local_variables_initializer())
            sess.run(init_op)
            model_path = tf.train.latest_checkpoint(BASE_PATH)
            if model_path is not None:
                print("Reading model parameters from %s" % model_path)
                saver.restore(sess, model_path)
            else:
                print("Creating model with fresh parameters.")
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            step = 0
            last_epoch = 0
            loss_queue = []
            epoch_start_time = time.time()
            try:
                while not coord.should_stop():
                    step += 1
                    epoch = int(step * batch_size / num_train_samples)
                    if last_epoch != epoch:
                        last_epoch = epoch
                        loss_value, summary = sess.run(
                            [total_loss, merged],
                            feed_dict={learning_rate: current_rate})
                        loss_value *= 300  # scale the logged loss for readability
                        train_writer.add_summary(summary, step)
                        duration = time.time() - epoch_start_time
                        epoch_start_time = time.time()
                        print('Epoch %d: loss = %.2f (%.3f sec)' % (epoch,
                                                                    loss_value,
                                                                    duration))
                        if len(loss_queue) >= loss_queue_size:
                            loss_queue.pop(0)
                        loss_queue.append(loss_value)
                        if (epoch + 1) % 1000 == 0:
                            snap_save_path = BASE_PATH + 'alexnet_epoch_{}_loss_{:.4f}_lr_{:.7f}.ckpt'
                            snap_save_path = snap_save_path.format(epoch,
                                                                   loss_value,
                                                                   current_rate)
                            print("saving a snapshot to {}.".format(snap_save_path))
                            saver.save(sess=sess, save_path=snap_save_path)
                        # Stop once the learning rate has bottomed out and the
                        # loss window is flat.
                        if current_rate <= min_learning_rate \
                                and (max(loss_queue) - min(loss_queue)) < 1:
                            print("** early stop **")
                            break
                        # Decay the learning rate when a full loss window shows
                        # no improvement of more than 1.0.
                        if len(loss_queue) == loss_queue_size \
                                and loss_queue[-1] + 1 > loss_queue[0] \
                                and current_rate > min_learning_rate:
                            current_rate *= 0.9
                            print("** change learning rate to {:.7f},"
                                  " first loss: {:.2f}, last loss: {:.2f}".format(
                                      current_rate, loss_queue[0], loss_queue[-1]))
                            loss_queue = []
                    sess.run([training_operation],
                             feed_dict={learning_rate: current_rate})
            except (KeyboardInterrupt, tf.errors.OutOfRangeError):
                print('Done training for %d epochs, %d steps.' % (num_epochs, step))
            finally:
                coord.request_stop()
                coord.join(threads)
                saver.save(sess=sess, save_path=save_path)
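
# Note on the schedule above: the learning rate decays geometrically by 0.9
# whenever a full 50-epoch loss window fails to improve by more than 1.0, and
# training stops early once it falls to min_learning_rate with a flat window.
# Going from 1e-2 down past 1e-7 takes about
# ceil(log(1e-7 / 1e-2) / log(0.9)) == 110 decay events.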
def explorer_dataset(dataset_folder):
    meta_path = os.path.join(dataset_folder, "meta.json")
    with open(meta_path, "r") as f:
        meta_info = json.load(f)
    # valid_path = os.path.join(dataset_folder, meta_info["valid_file"])
    train_path = os.path.join(dataset_folder, meta_info["train_file"])
    batch_size = 16
    num_epochs = 1
    with tf.Graph().as_default():
        images, labels = inputs(train_path,
                                batch_size=batch_size,
                                num_epochs=num_epochs,
                                meta_info=meta_info)
        with tf.Session() as sess:
            init_op = tf.group(tf.global_variables_initializer(),
                               tf.local_variables_initializer())
            sess.run(init_op)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                result_path = BASE_PATH + "explorer_dataset/"
                if not os.path.exists(result_path):
                    os.makedirs(result_path)
                index = 0
                while not coord.should_stop():
                    imgs, lbs = sess.run([images, labels])
                    for img, lb in zip(imgs, lbs):
                        img = cv2.resize(np.squeeze(img), (600, 600))
                        # NOTE: this visualization assumes the label is four
                        # point coordinates in [-1, 1] (the older point-label
                        # dataset), not the 62x62x2 heatmaps used for training.
                        actual_points = [int((pt + 1) * 300) for pt in np.squeeze(lb)]
                        # Undo the [-1, 1] normalization before drawing/writing.
                        img = ((img + 1) * 127.5).astype(np.uint8)
                        cv2.line(img,
                                 (actual_points[0], actual_points[1]),
                                 (actual_points[2], actual_points[3]),
                                 (0, 255, 0), 2)
                        index += 1
                        cv2.imwrite(result_path + "index_{}.png".format(index), img)
                        if index >= 20:
                            break
                    if index >= 20:
                        break
                print("** explored the train dataset, samples are at {}".format(result_path))
            except tf.errors.OutOfRangeError:
                pass
            finally:
                coord.request_stop()
                coord.join(threads)
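
# Hypothetical helper sketch: if the labels were 62x62x2 heatmaps (as the
# dataset folder name in __main__ suggests) rather than endpoint coordinates,
# the two endpoints could be recovered with one argmax per channel before
# drawing. `out_size` is the visualization resolution (e.g. 600).
def _heatmap_to_points(heatmap_label, out_size):
    h, w, c = heatmap_label.shape
    points = []
    for ch in range(c):
        idx = int(np.argmax(heatmap_label[:, :, ch]))
        y, x = divmod(idx, w)  # row-major flat index -> (row, col)
        points.append((x * out_size // w, y * out_size // h))
    return points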
if __name__ == "__main__":
    if not os.path.exists(BASE_PATH):
        os.makedirs(BASE_PATH)
    # dataset_folder = "/dataset/door_tfrecords/img_size_600_600-dt_True-normal_True/"
    dataset_folder = "/dataset/door/scan_to_image_3/img_size_248_248_3-label_size_62_62_2-dt_True-normal_True/"
    train_from_tfrecords(dataset_folder)