from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import sys
import time
import tensorflow as tf
from google.protobuf import text_format
from tensorflow.python.platform import app
from delf import delf_config_pb2
from delf import feature_extractor
from delf import feature_io
from delf import delf_v1
from nets import resnet_v1
cmd_args = None
slim = tf.contrib.slim
# Extension of feature files.
_DELF_EXT = '.delf'
# Pace to report extraction log.
_STATUS_CHECK_ITERATIONS = 100

def _ReadImageList(list_path):
    """Helper function to read image paths.

    Args:
      list_path: Path to list of images, one image path per line.

    Returns:
      image_paths: List of image paths.
    """
    with tf.gfile.GFile(list_path, 'r') as f:
        image_paths = f.readlines()
    image_paths = [entry.rstrip() for entry in image_paths]
    return image_paths
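
# Note: _ReadImageList appears to be carried over from the DELF feature-extraction
# pipeline; it is not called anywhere in this script.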

# added by ys
def label_to_int(labels):
    """Map each distinct label string to an integer, in order of first appearance."""
    result = []
    label_book = {}
    numbering = 0
    for label in labels:
        try:
            result.append(label_book[label])
        except KeyError:
            label_book[label] = numbering
            result.append(numbering)
            numbering += 1
    return result
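
# Illustrative example (not executed): label_to_int(["cat", "dog", "cat"]) returns
# [0, 1, 0]; each distinct label string is assigned an integer in order of first
# appearance.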

def list_images(directory, convert=False):
    """
    Get all the images and labels in directory/label/*.jpg
    """
    labels = os.listdir(directory)
    # Sort the labels so that training and validation get them in the same order
    # labels.sort()
    files_and_labels = []
    for label in labels:
        for f in os.listdir(os.path.join(directory, label)):
            files_and_labels.append((os.path.join(directory, label, f), label))

    filenames, labels = zip(*files_and_labels)
    filenames = list(filenames)
    labels = list(labels)
    # [hs]
    # labels = map(int, labels)
    # labels = [int(label) for label in labels]
    # Replaced: map label strings to integers with label_to_int.
    int_labels = label_to_int(labels)
    return filenames, int_labels
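
# Expected layout (inferred from the docstring): one sub-directory per class under
# `directory`, e.g. train/<label>/<image>.jpg; the sub-directory name is used as the
# label string. The `convert` argument is currently unused.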

def _parse_function(filename, label):
    image_string = tf.read_file(filename)
    image_decoded = tf.image.decode_jpeg(image_string, channels=3)  # (1)
    image = tf.cast(image_decoded, tf.float32)
    return image, label
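
# Note: images are decoded at their native resolution, so batching will fail when the
# source images differ in size. A minimal fix (sketch, assuming 224x224 inputs) would
# resize before casting:
#
#   image_resized = tf.image.resize_images(image_decoded, [224, 224])
#   image = tf.cast(image_resized, tf.float32)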

def read_varaibles_from_file(filepath):
    with open(filepath) as f:
        lines = f.readlines()
    lines = [x.strip() for x in lines]
    return lines

def check_accuracy(sess, correct_prediction, dataset_init_op):
    """
    Check the accuracy of the model on either train or val (depending on dataset_init_op).
    """
    # Initialize the correct dataset
    sess.run(dataset_init_op)
    num_correct, num_samples = 0, 0
    count = 0
    while True:
        try:
            correct_pred = sess.run(correct_prediction)
            num_correct += correct_pred.sum()
            num_samples += correct_pred.shape[0]
            if count % 100 == 0:
                acc = float(num_correct) / num_samples
                # print("finished reading " + str(num_samples) + " data")
                # print("current accuracy: " + str(acc))
            count += 1
        except tf.errors.OutOfRangeError:
            break
    # Return the fraction of datapoints that were correctly classified
    acc = float(num_correct) / num_samples
    return acc

def build_model(images, num_classes, is_training=True, reuse=None):
    model = delf_v1.DelfV1()
    net, end_points = model.GetResnet50Subnetwork(
        images, global_pool=True, is_training=is_training, reuse=reuse)
    with slim.arg_scope(
            resnet_v1.resnet_arg_scope(
                weight_decay=0.0001, batch_norm_scale=True)):
        with slim.arg_scope([slim.batch_norm], is_training=True):
            feature_map = end_points['resnet_v1_50/block3']
            feature_map = slim.conv2d(
                feature_map,
                512,
                1,
                rate=1,
                activation_fn=tf.nn.relu,
                scope='conv1')
            feature_map = tf.reduce_mean(feature_map, [1, 2])
            feature_map = tf.expand_dims(tf.expand_dims(feature_map, 1), 2)
            logits = slim.conv2d(
                feature_map,
                num_classes, [1, 1],
                activation_fn=None,
                normalizer_fn=None,
                scope='logits')
            logits = tf.squeeze(logits, [1, 2], name='spatial_squeeze')
    return logits
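
# The classification head above applies a 512-channel 1x1 conv (with ReLU) to the
# ResNet-50 block3 feature map, global-average-pools it over the spatial axes, projects
# to num_classes logits with a final 1x1 conv, and squeezes out the dummy spatial dims.
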
# Local train directory
train_data = "/home/soma03/projects/data/stanford_clean/train"
# Set the parameters here
batch_size = 128
num_preprocess_threads = 32
learning_rate = 0.00005
epochs = 1000
# Get the list of training filenames and labels
convert = False
train_filenames, train_labels = list_images(train_data, convert)
num_classes = len(set(train_labels))
print("there are " + str(num_classes) + " classes")
# Total number of training images, used to compute the number of batches per epoch
num_train_data = len(train_filenames)
# Set up the training data pipeline
train_dataset = tf.data.Dataset.from_tensor_slices((train_filenames, train_labels))
train_dataset = train_dataset.map(
    _parse_function, num_parallel_calls=num_preprocess_threads).prefetch(batch_size)
train_dataset = train_dataset.shuffle(buffer_size=300000)
batched_train_dataset = train_dataset.batch(batch_size)
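
# Note: shuffle() here buffers up to 300000 decoded images, which can be very memory
# hungry; shuffling the (filename, label) dataset before map() would give the same
# randomization over much smaller elements.
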
# Set up the iterator
iterator = tf.data.Iterator.from_structure(batched_train_dataset.output_types,
                                           batched_train_dataset.output_shapes)
# Build the model
images, labels = iterator.get_next()
# logits = build_model(images, num_classes)
"""revised by ys"""
# with slim.arg_scope(resnet_v2.resnet_arg_scope())
with slim.arg_scope(resnet_v1.resnet_arg_scope(use_batch_norm=True)):
logits = build_model(images, num_classes)
# Add loss function
tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
loss = tf.losses.get_total_loss()
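
# tf.losses.sparse_softmax_cross_entropy registers the cross-entropy loss in the
# LOSSES collection, and tf.losses.get_total_loss() returns its sum together with any
# regularization losses (e.g. the weight decay configured in resnet_arg_scope).
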
# Get pre-trained weights
variables_path = 'load_names.txt'
variables_to_be_load = read_varaibles_from_file(variables_path)
# restore_var = [v for v in tf.global_variables() if v.name[:-2]
# in variables_to_be_load]
restore_var = [v for v in tf.global_variables() if 'resnet' in v.name]
train_variables = [
    v for v in tf.global_variables() if 'resnet' not in v.name]
train_init_op = iterator.make_initializer(batched_train_dataset)
att = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)
# output = att.minimize(loss, var_list=train_variables)
output = att.minimize(loss)
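
# Note: minimize(loss) updates all trainable variables, including the restored ResNet
# weights; the commented variant above would restrict updates to train_variables and
# keep the backbone frozen.
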
# Evaluation metrics
prediction = tf.to_int32(tf.argmax(logits, 1))
correct_prediction = tf.equal(prediction, labels)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
init_op = tf.global_variables_initializer()
"""revised by ys"""
# to restore variable Saver needs variable list what to restore.
saver = tf.train.Saver(restore_var)
saver2 = tf.train.Saver()
num_batches = int(num_train_data / batch_size) + 1
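
# Note: when num_train_data is an exact multiple of batch_size this count overshoots by
# one and the final sess.run will hit an uncaught tf.errors.OutOfRangeError; rounding up
# with math.ceil(num_train_data / float(batch_size)) would avoid the off-by-one.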

with tf.Session() as sess:
    sess.run(init_op)

    # revised by ys
    # You can get resnet_v1_50.ckpt from /home/projects/ys/codes/delf+knn/
    saver.restore(sess, "resnet_v1_50.ckpt")
    print("weights loaded")
    print("there are " + str(num_batches) + " batches")

    for epoch in range(epochs):
        sess.run(train_init_op)
        print('Starting epoch %d / %d' % (epoch + 1, epochs))
        t = time.time()
        acc_at_each_epoch = []
        for batch in range(num_batches):
            # sess.run(train_init_op)
            _, acc, batch_loss = sess.run([output, accuracy, loss])
            acc_at_each_epoch.append(acc)
            print(acc)
            if batch % 100 == 0:
                print("At batch " + str(batch))
                accumulated_acc = (sum(acc_at_each_epoch)
                                   / float(len(acc_at_each_epoch)))
                print("accumulated accuracy: " + str(accumulated_acc))
                print("loss at this batch: " + str(batch_loss))
                elapsed = time.time() - t
                print("it takes " + str(elapsed)
                      + " seconds to train these 100 batches")
                t = time.time()
                print("==========================================================")
        # train_acc = check_accuracy(sess, correct_prediction, train_init_op)
        # print("training acc is: " + str(train_acc))
        # val_acc = check_accuracy(sess, correct_prediction, val_init_op)
        # print("validation acc is: " + str(val_acc))
        # sess.run(train_init_op)
        print("epoch: " + str(epoch) + " is done!")
        # train_acc = check_accuracy(sess, correct_prediction, train_init_op)
        # val_acc = check_accuracy(sess, correct_prediction, val_init_op)
        # print('Train accuracy: %f' % train_acc)
        # print("validation acc is: " + str(val_acc) + "\n")  # val_acc is undefined here, so keep this commented out
        saver2.save(sess, "my_weights/trained_resnet_model")