#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import math
import os
import time

import cv2
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
BASE_PATH = "/tmp/heatmap/"
trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev)
def euclidean_loss(y_true, y_pred):
    loss = slim.losses.mean_squared_error(y_true, y_pred)
    return loss
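# The model below is a two-branch, fully convolutional heatmap regressor that
# appears to follow the "spatial net + spatial fusion" design of Pfister et
# al., "Flowing ConvNets for Human Pose Estimation in Videos": the spatial_net
# branch downsamples 248x248 inputs to 62x62 via two max-pools and predicts a
# 2-channel heatmap, while the spatial_fusion branch concatenates the early
# conv3 features with the late conv7 features and refines them into a second
# 2-channel prediction. Both branches are supervised with the same MSE loss.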
def heatmap(data, y_labels):
    data = tf.image.resize_images(data, (248, 248))
    with tf.variable_scope('heatmap', 'heatmap', [data, y_labels]):
        with slim.arg_scope([slim.conv2d],
                            weights_initializer=tf.truncated_normal_initializer(stddev=0.01)):
            with tf.variable_scope("spatial_net"):
                net = slim.conv2d(data, 128, [5, 5], 1, padding='SAME',
                                  activation_fn=tf.nn.relu,
                                  scope='conv1')
                net = slim.max_pool2d(net, [2, 2], 2, scope='pool1', padding='SAME')
                net = slim.conv2d(net, 128, [5, 5], 1, padding='SAME',
                                  activation_fn=tf.nn.relu,
                                  scope='conv2')
                net = slim.max_pool2d(net, [2, 2], 2, scope='pool2', padding='SAME')
                conv3 = slim.conv2d(net, 128, [5, 5], 1, padding='SAME',
                                    activation_fn=tf.nn.relu,
                                    scope='conv3')
                net = slim.conv2d(conv3, 256, [5, 5], 1, padding='SAME',
                                  activation_fn=tf.nn.relu,
                                  scope='conv4')
                net = slim.conv2d(net, 512, [9, 9], 1, padding='SAME',
                                  activation_fn=tf.nn.relu,
                                  scope='conv5')
                net = slim.conv2d(net, 256, [1, 1], 1, padding='SAME',
                                  activation_fn=tf.nn.relu,
                                  scope='conv6')
                conv7 = slim.conv2d(net, 256, [1, 1], 1, padding='SAME',
                                    activation_fn=tf.nn.relu,
                                    scope='conv7')
                spatial_net = slim.conv2d(conv7, 2, [1, 1], 1, padding='SAME',
                                          activation_fn=tf.nn.relu,
                                          scope='conv8')
                loss_spatial_net = euclidean_loss(y_labels, spatial_net)
            with tf.variable_scope("spatial_fusion"):
                # Fuse early (conv3) and late (conv7) features before refining.
                concat = tf.concat([conv3, conv7], axis=3)
                net = slim.conv2d(concat, 64, [7, 7], 1, padding='SAME',
                                  activation_fn=tf.nn.relu,
                                  scope='conv1')
                net = slim.conv2d(net, 64, [13, 13], 1, padding='SAME',
                                  activation_fn=tf.nn.relu,
                                  scope='conv2')
                net = slim.conv2d(net, 128, [13, 13], 1, padding='SAME',
                                  activation_fn=tf.nn.relu,
                                  scope='conv3')
                net = slim.conv2d(net, 256, [1, 1], 1, padding='SAME',
                                  activation_fn=tf.nn.relu,
                                  scope='conv4')
                spatial_fusion_net = slim.conv2d(net, 2, [1, 1], 1, padding='SAME',
                                                 activation_fn=tf.nn.relu,
                                                 scope='conv5')
                loss_spatial_fusion = euclidean_loss(y_labels, spatial_fusion_net)
    # euclidean_loss() registers each MSE in slim's loss collection, so
    # get_total_loss() sums both branch losses.
    # total_loss = loss_spatial_net + 3 * loss_spatial_fusion
    total_loss = tf.losses.get_total_loss(add_regularization_losses=False)
    return spatial_net, spatial_fusion_net, loss_spatial_net, loss_spatial_fusion, total_loss
def read_and_decode(filename_queue,
                    image_size,
                    image_channel,
                    label_size,
                    label_channel,
                    normalized=False):
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        # Defaults are not specified since both keys are required.
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'label_image_raw': tf.FixedLenFeature([], tf.string),
        })
    if normalized:
        image = tf.decode_raw(features['image_raw'], tf.float32)
        label = tf.decode_raw(features['label_image_raw'], tf.float32)
    else:
        image = tf.decode_raw(features['image_raw'], tf.int64)
        label = tf.decode_raw(features['label_image_raw'], tf.int64)
    image_length = image_size[0] * image_size[1] * image_channel
    label_length = label_size[0] * label_size[1] * label_channel
    image.set_shape([image_length])
    label.set_shape([label_length])
    image = tf.reshape(image, (image_size[0], image_size[1], image_channel))
    label = tf.reshape(label, (label_size[0], label_size[1], label_channel))
    if not normalized:
        # Scale raw pixel values from [0, 255] to [-1, 1].
        image = tf.cast(image, tf.float32) * (1. / 127.5) - 1
    return image, label
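# For reference, a minimal sketch of how a record compatible with
# read_and_decode() could be serialized. This helper is not part of the
# original gist; make_example() and its arguments are illustrative.
def make_example(image, label):
    # image/label: numpy arrays whose dtype matches the `normalized` flag
    # (float32 when normalized=True, int64 otherwise).
    def _bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
    return tf.train.Example(features=tf.train.Features(feature={
        'image_raw': _bytes_feature(image.tobytes()),
        'label_image_raw': _bytes_feature(label.tobytes()),
    }))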
def inputs(filename, batch_size, num_epochs, meta_info):
    if not num_epochs:
        num_epochs = None
    with tf.name_scope('input'):
        filename_queue = tf.train.string_input_producer(
            [filename], num_epochs=num_epochs)
        # Even when reading in multiple threads, share the filename queue.
        image, label = read_and_decode(filename_queue,
                                       image_size=meta_info["image_size"],
                                       image_channel=meta_info["image_channel"],
                                       label_size=meta_info["label_size"],
                                       label_channel=meta_info["label_channel"],
                                       normalized=meta_info["normalization"])
        # Shuffle the examples and collect them into batch_size batches.
        # (Internally uses a RandomShuffleQueue.)
        # We run this in multiple threads to avoid being a bottleneck.
        images, labels = tf.train.shuffle_batch(
            [image, label],
            batch_size=batch_size,
            num_threads=5,
            capacity=1000 + 3 * batch_size,
            min_after_dequeue=1000)
    return images, labels
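# Note: string_input_producer and shuffle_batch only start enqueuing data once
# tf.train.start_queue_runners() is called, which the training and exploration
# loops below do after initializing variables.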
def print_summary():
    print()
    print("=" * 40)
    sum_params = 0
    for var in tf.trainable_variables():
        num_param = 1
        for elem in var.shape:
            num_param *= elem.value
        sum_params += num_param
        print(var.name, "\t", var.shape, "\t", num_param)
    print("-" * 40)
    print("trainable variables:", len(tf.trainable_variables()))
    print("trainable parameters:", sum_params)
    print("=" * 40)
    print()
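# Training driver. Logging happens once per (approximate) epoch; the schedule
# below decays the learning rate by 0.9 whenever the newest epoch loss is not
# at least 1 lower than the oldest one in a window of `loss_queue_size`
# losses, and stops early once the rate has reached min_learning_rate and the
# windowed losses have flattened to within 1 of each other.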
def train_from_tfrecords(dataset_folder):
    meta_path = os.path.join(dataset_folder, "meta.json")
    with open(meta_path, "r") as f:
        meta_info = json.load(f)
    # valid_path = os.path.join(dataset_folder, meta_info["valid_file"])
    train_path = os.path.join(dataset_folder, meta_info["train_file"])
    batch_size = 8
    num_epochs = 10000
    current_rate = 10e-3
    loss_queue_size = 50
    min_learning_rate = 10e-8
    num_train_samples = int(meta_info["length_train"])
    with tf.Graph().as_default():
        images, labels = inputs(train_path,
                                batch_size=batch_size,
                                num_epochs=num_epochs,
                                meta_info=meta_info)
        spatial_net, spatial_fusion_net, loss_1, loss_2, total_loss = heatmap(images, labels)
        # A scalar placeholder so the learning rate can be decayed from Python.
        learning_rate = tf.placeholder(tf.float64, shape=[], name="lr")
        tf.summary.scalar("loss_spatial_net", loss_1)
        tf.summary.scalar("loss_spatial_fusion", loss_2)
        tf.summary.scalar("lr", learning_rate)
        merged = tf.summary.merge_all()
        print_summary()
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
        training_operation = slim.learning.create_train_op(total_loss, optimizer)
        # training_operation = optimizer.minimize(total_loss)
        saver = tf.train.Saver(tf.trainable_variables())
        save_path = BASE_PATH + 'alexnet.ckpt'
        with tf.Session() as sess:
            train_writer = tf.summary.FileWriter(BASE_PATH + "/log/", sess.graph)
            init_op = tf.group(tf.global_variables_initializer(),
                               tf.local_variables_initializer())
            sess.run(init_op)
            model_path = tf.train.latest_checkpoint(BASE_PATH)
            if model_path is not None:
                print("Reading model parameters from %s" % model_path)
                saver.restore(sess, model_path)
            else:
                print("Creating model with fresh parameters.")
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            step = 0
            last_epoch = 0
            loss_queue = []
            epoch_start_time = time.time()
            try:
                while not coord.should_stop():
                    step += 1
                    epoch = int(step * batch_size / num_train_samples)
                    if last_epoch != epoch:
                        last_epoch = epoch
                        loss_value, summary = sess.run(
                            [total_loss, merged],
                            feed_dict={learning_rate: current_rate})
                        loss_value *= 300
                        train_writer.add_summary(summary, step)
                        duration = time.time() - epoch_start_time
                        epoch_start_time = time.time()
                        print('Epoch %d: loss = %.2f (%.3f sec)' % (epoch,
                                                                    loss_value,
                                                                    duration))
                        # Keep a sliding window of the most recent epoch losses.
                        if len(loss_queue) >= loss_queue_size:
                            loss_queue.pop(0)
                        loss_queue.append(loss_value)
                        if (epoch + 1) % 1000 == 0:
                            snap_save_path = BASE_PATH + 'alexnet_epoch_{}_loss_{:.4f}_lr_{:.7f}.ckpt'
                            snap_save_path = snap_save_path.format(epoch,
                                                                   loss_value,
                                                                   current_rate)
                            print("saving a snapshot to {}.".format(snap_save_path))
                            saver.save(sess=sess, save_path=snap_save_path)
                        if math.fabs(current_rate - min_learning_rate) < 10e-10 \
                                and (max(loss_queue) - min(loss_queue)) < 1:
                            print("** early stop **")
                            break
                        if len(loss_queue) == loss_queue_size \
                                and loss_queue[-1] + 1 > loss_queue[0] \
                                and current_rate > min_learning_rate:
                            current_rate *= 0.9
                            print("** change learning rate to {:.7f},"
                                  " first loss: {:.2f}, last loss: {:.2f}".format(current_rate,
                                                                                  loss_queue[0],
                                                                                  loss_queue[-1]))
                            loss_queue = []
                    sess.run([training_operation],
                             feed_dict={learning_rate: current_rate})
            except (KeyboardInterrupt, tf.errors.OutOfRangeError):
                print('Done training for %d epochs, %d steps.' % (num_epochs, step))
            finally:
                coord.request_stop()
                coord.join(threads)
            saver.save(sess=sess, save_path=save_path)
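# Sanity check for the input pipeline: decodes a few batches, draws the
# labelled line segment on each image, and writes up to 20 PNGs for visual
# inspection. Note that it interprets labels as four point coordinates in
# [-1, 1], which appears to match the commented-out 600x600 point-label
# dataset in __main__ rather than the 62x62 heatmap labels.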
def explorer_dataset(dataset_folder):
    meta_path = os.path.join(dataset_folder, "meta.json")
    with open(meta_path, "r") as f:
        meta_info = json.load(f)
    # valid_path = os.path.join(dataset_folder, meta_info["valid_file"])
    train_path = os.path.join(dataset_folder, meta_info["train_file"])
    batch_size = 16
    num_epochs = 1
    with tf.Graph().as_default():
        images, labels = inputs(train_path,
                                batch_size=batch_size,
                                num_epochs=num_epochs,
                                meta_info=meta_info)
        with tf.Session() as sess:
            init_op = tf.group(tf.global_variables_initializer(),
                               tf.local_variables_initializer())
            sess.run(init_op)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                result_path = BASE_PATH + "explorer_dataset/"
                if not os.path.exists(result_path):
                    os.makedirs(result_path)
                index = 0
                while not coord.should_stop():
                    imgs, lbs = sess.run([images, labels])
                    for img, lb in zip(imgs, lbs):
                        img = cv2.resize(np.squeeze(img), (600, 600))
                        # Map labels from [-1, 1] back to 600x600 pixel coordinates.
                        actual_points = [int((pt + 1) * 300) for pt in np.squeeze(lb)]
                        img = (img + 1) * 127.5
                        cv2.line(img,
                                 (actual_points[0], actual_points[1]),
                                 (actual_points[2], actual_points[3]),
                                 (0, 255, 0), 2)
                        index += 1
                        cv2.imwrite(result_path + "/index_{}.png".format(index), img)
                        if index >= 20:
                            break
                    if index >= 20:
                        break
                print("** dumped train dataset samples to {}".format(result_path))
            except tf.errors.OutOfRangeError:
                pass
            finally:
                coord.request_stop()
                coord.join(threads)
if __name__ == "__main__":
    if not os.path.exists(BASE_PATH):
        os.makedirs(BASE_PATH)
    # dataset_folder = "/dataset/door_tfrecords/img_size_600_600-dt_True-normal_True/"
    dataset_folder = "/dataset/door/scan_to_image_3/img_size_248_248_3-label_size_62_62_2-dt_True-normal_True/"
    train_from_tfrecords(dataset_folder)
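    # To visually spot-check the decoded samples instead of training, call
    # explorer_dataset(dataset_folder) here.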