Skip to content

Instantly share code, notes, and snippets.

@eldar
Last active September 11, 2017 06:20
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save eldar/0ecc058670be340b92e5a1044dc8a089 to your computer and use it in GitHub Desktop.
Save eldar/0ecc058670be340b92e5a1044dc8a089 to your computer and use it in GitHub Desktop.
import datetime as dt
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.nets import resnet_v1
import threading
from PoseDataset import PoseDataset
from TrainParams import TrainParams
# Training data source and hyper-parameters (project-local classes).
# The original code instantiated `MyDataset`, which is neither defined nor
# imported in this file; `PoseDataset` is the imported dataset class.
# NOTE(review): assumes PoseDataset() needs no constructor arguments and
# provides next_batch() as used by load_and_enqueue — confirm.
dataset = PoseDataset()
train_param = TrainParams()

# Capacity of the prefetch queue filled by the background loader thread.
QUEUE_SIZE = 50
# Number of channels in the label maps and in the network's prediction.
num_classes = 14

# Feed points used only by the loader thread: a batch of exactly one image of
# any height/width with 3 channels, and its per-class target maps.
inputs = tf.placeholder(tf.float32, shape=[1, None, None, 3])
data_labels = tf.placeholder(tf.float32, shape=[1, None, None, num_classes])

# FIFO queue decoupling data loading from training: the loader thread runs
# enqueue_op, while the training graph is built on the dequeued tensors.
q = tf.FIFOQueue(QUEUE_SIZE, [tf.float32, tf.float32])
enqueue_op = q.enqueue([inputs, data_labels])
inputs_batch, targets_batch = q.dequeue()
# The queue was created without `shapes`, so dequeued tensors carry no static
# shape; restore what is known so downstream layers can infer channel counts.
inputs_batch.set_shape([1, None, None, 3])
targets_batch.set_shape([1, None, None, num_classes])
def load_and_enqueue(sess, enqueue_op, coord, dataset):
    """Feed the prefetch queue from ``dataset`` until told to stop.

    Intended to run in a background thread. Each iteration pulls one batch
    with ``dataset.next_batch()`` and runs ``enqueue_op``, feeding the
    module-level ``inputs`` and ``data_labels`` placeholders from the batch's
    ``'inputs'`` and ``'data_labels'`` entries.

    Args:
        sess: session used to run ``enqueue_op``.
        enqueue_op: op that enqueues the fed placeholders.
        coord: ``tf.train.Coordinator``; the loop ends once
            ``coord.should_stop()`` is true.
        dataset: object whose ``next_batch()`` returns a mapping with
            ``'inputs'`` and ``'data_labels'`` entries.

    Returns normally when the coordinator requests a stop, or when the queue
    is closed while this thread is enqueuing (e.g. by
    ``q.close(cancel_pending_enqueues=True)``, which is the only way to wake
    a thread blocked on a full queue).
    """
    try:
        while not coord.should_stop():
            batch = dataset.next_batch()
            sess.run(enqueue_op, feed_dict={inputs: batch['inputs'],
                                            data_labels: batch['data_labels']})
    except tf.errors.CancelledError:
        # Enqueue on a closed queue (or a pending enqueue cancelled by the
        # close) raises CancelledError: this is the normal shutdown path.
        return
# ResNet-101 backbone (no global pooling) followed by a learned 2x upsampling
# layer that produces per-class logits, trained with a sigmoid cross-entropy.
# NOTE(review): the positional `False` presumably means `is_training=False`
# in the slim version this was written for; in later slim releases the first
# parameter of resnet_arg_scope is `weight_decay` — confirm against the
# installed TensorFlow version.
with slim.arg_scope(resnet_v1.resnet_arg_scope(False)):
    # Per-channel mean subtracted from the input image. These look like the
    # ImageNet RGB channel means expected by the pretrained checkpoint — confirm.
    mean = tf.constant([123.68, 116.779, 103.939],
                       dtype=tf.float32, shape=[1, 1, 1, 3], name='img_mean')
    im_centered = inputs_batch - mean
    # global_pool=False keeps a spatial feature map instead of a pooled vector;
    # output_stride=16 asks the backbone for features at 1/16 input resolution.
    net, end_points = resnet_v1.resnet_v1_101(im_centered,
                                              global_pool=False,
                                              output_stride=16)
    # activation_fn=None is required: slim.conv2d_transpose defaults to ReLU
    # (resnet_arg_scope only overrides slim.conv2d), which would clamp the
    # logits to >= 0 so sigmoid probabilities could never fall below 0.5.
    pred_upconv = slim.conv2d_transpose(net, num_classes,
                                        kernel_size=[3, 3],
                                        stride=2,
                                        padding='SAME',
                                        activation_fn=None)
    # Keyword arguments make the logits/labels order explicit (the deprecated
    # slim/contrib loss takes logits first, tf.losses takes labels first).
    # NOTE(review): targets_batch must match pred_upconv's spatial size,
    # presumably about 1/8 of the input after the 16x backbone and 2x upsample.
    loss = slim.losses.sigmoid_cross_entropy(logits=pred_upconv,
                                             multi_class_labels=targets_batch)
model_path = 'resnet_v1_101.ckpt'

# Select the variables to restore BEFORE building the optimizer. Optimizers
# with slot variables (e.g. Momentum, Adam) name them after the primary
# variable, so they would also match the "resnet_v1" prefix and the restore
# would then fail looking for them in the checkpoint.
variables_to_restore = slim.get_variables_to_restore(include=["resnet_v1"])
restorer = tf.train.Saver(variables_to_restore)

# Build the training op before running the initializer, so that any variables
# the optimizer creates are initialized too.
optimizer = tf.train.GradientDescentOptimizer(learning_rate=.001)
train_op = optimizer.minimize(loss)
# Closing with cancel_pending_enqueues wakes a loader blocked on a full queue.
close_queue_op = q.close(cancel_pending_enqueues=True)

sess = tf.Session()
sess.run(tf.initialize_all_variables())
sess.run(tf.initialize_local_variables())
# Restore variables from disk: only variables whose names match "resnet_v1"
# (presumably the backbone) are overwritten; all others keep their
# initializer values.
restorer.restore(sess, model_path)

coord = tf.train.Coordinator()
t = threading.Thread(target=load_and_enqueue,
                     args=(sess, enqueue_op, coord, dataset))
t.start()

try:
    for _ in range(10000):
        sess.run(train_op)
finally:
    # Shut the loader down. Otherwise it stays blocked enqueueing into the full
    # queue and, being a non-daemon thread, keeps the process from exiting.
    coord.request_stop()
    sess.run(close_queue_op)
    coord.join([t])
    sess.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment