# Source: GitHub gist yunsu3042/7e8145fe6508547556d5b640b4bedb46
# (last active February 16, 2019).
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import argparse | |
import os | |
import sys | |
import time | |
import tensorflow as tf | |
from google.protobuf import text_format | |
from tensorflow.python.platform import app | |
from delf import delf_config_pb2 | |
from delf import feature_extractor | |
from delf import feature_io | |
from delf import delf_v1 | |
import time | |
from nets import resnet_v1 | |
# Short alias for TF-Slim (tf.contrib), used by build_model and the
# model-construction code further down.
slim = tf.contrib.slim

# Placeholder for parsed command-line arguments; not assigned elsewhere in
# this file.
cmd_args = None

# File extension used for DELF feature files.
_DELF_EXT = '.delf'
# How often (in iterations) to report extraction progress.
# NOTE(review): this and _DELF_EXT look unused in this training script -
# presumably carried over from the DELF extraction example; confirm.
_STATUS_CHECK_ITERATIONS = 100
def _ReadImageList(list_path):
    """Read a text file that lists one image path per line.

    Args:
      list_path: Path of the list file (opened through tf.gfile.GFile).

    Returns:
      A list with one entry per line of the file, trailing whitespace
      (including the newline) removed.
    """
    with tf.gfile.GFile(list_path, 'r') as list_file:
        raw_lines = list_file.readlines()
    return [line.rstrip() for line in raw_lines]
# added by ys
def label_to_int(labels):
    """Map arbitrary hashable labels to consecutive integer ids.

    Ids are assigned in order of first appearance, starting at 0, so
    ``['b', 'a', 'b']`` becomes ``[0, 1, 0]``.

    Args:
      labels: Iterable of hashable labels (e.g. class-directory names).

    Returns:
      A list of ints, one per input label.
    """
    label_book = {}
    # setdefault only inserts (with the next free id, the current dict size)
    # the first time a label is seen; later occurrences reuse that id.
    return [label_book.setdefault(label, len(label_book)) for label in labels]
def list_images(directory, convert=False):
    """List every image file under ``directory/<label>/`` with an integer label.

    Each immediate sub-directory of ``directory`` is one class; every regular
    file directly inside it is taken as an image of that class. Other
    entries (stray files in ``directory``, nested sub-directories) are
    skipped.

    Args:
      directory: Root directory containing one sub-directory per class.
      convert: Unused; kept so existing callers keep working.

    Returns:
      (filenames, int_labels): two parallel lists. Class names are visited
      in sorted order and numbered with label_to_int, so the same directory
      tree always yields the same class -> id mapping. Both lists are empty
      when no images are found.
    """
    # os.listdir returns entries in arbitrary order; sort so the class-id
    # assignment is reproducible across runs and machines (training and any
    # later evaluation must agree on it).
    labels = sorted(
        entry for entry in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, entry)))
    filenames = []
    class_names = []
    for label in labels:
        class_dir = os.path.join(directory, label)
        for name in sorted(os.listdir(class_dir)):
            path = os.path.join(class_dir, name)
            if os.path.isfile(path):
                filenames.append(path)
                class_names.append(label)
    return filenames, label_to_int(class_names)
def _parse_function(filename, label):
    """Load one image file as a float32, 3-channel tensor.

    Used with tf.data.Dataset.map: reads the file, decodes it with
    tf.image.decode_jpeg (channels=3), casts to float32, and passes
    ``label`` through unchanged.
    """
    raw_bytes = tf.read_file(filename)
    pixels = tf.cast(tf.image.decode_jpeg(raw_bytes, channels=3), tf.float32)
    return pixels, label
def read_varaibles_from_file(filepath):
    """Return the lines of ``filepath`` with surrounding whitespace stripped.

    Blank lines are kept as empty strings. (The misspelled function name is
    preserved for existing callers.)
    """
    with open(filepath) as handle:
        return [line.strip() for line in handle]
def check_accuracy(sess, correct_prediction, dataset_init_op):
    """Return the fraction of correct predictions over one full dataset pass.

    Runs ``dataset_init_op`` to (re)start the iterator, then evaluates
    ``correct_prediction`` batch by batch until the iterator signals the end
    of the data with ``tf.errors.OutOfRangeError``.

    Args:
      sess: Session holding the graph.
      correct_prediction: Boolean tensor with one entry per example in a batch.
      dataset_init_op: Iterator initializer selecting which dataset (e.g.
        train or val) to evaluate.

    Returns:
      Accuracy as a float in [0, 1]; 0.0 if the dataset yielded no examples.
    """
    sess.run(dataset_init_op)
    num_correct, num_samples = 0, 0
    while True:
        # Only the session call can signal end-of-data; keep the try narrow.
        try:
            correct_pred = sess.run(correct_prediction)
        except tf.errors.OutOfRangeError:
            break
        num_correct += correct_pred.sum()
        num_samples += correct_pred.shape[0]
    if num_samples == 0:
        # Avoid ZeroDivisionError on an empty dataset.
        return 0.0
    return float(num_correct) / num_samples
def build_model(images, num_classes, is_training=True, reuse=None):
    """Build a ResNet-50 classifier with a small head on top of block3.

    The head is a 1x1 conv (512 channels, ReLU), global average pooling over
    the spatial dims, and a 1x1 conv producing ``num_classes`` logits.

    Args:
      images: Image batch fed to DelfV1's ResNet-50 subnetwork.
      num_classes: Number of output classes.
      is_training: Passed to the backbone and used for the head's batch norm.
      reuse: Variable-reuse flag, passed to the backbone and to the head's
        conv layers so the graph can be built a second time (e.g. for an
        eval tower) with ``reuse=True``.

    Returns:
      Logits tensor of shape [batch, num_classes].
    """
    model = delf_v1.DelfV1()
    # Only the intermediate end points are used; the pooled output is ignored.
    _, end_points = model.GetResnet50Subnetwork(
        images, global_pool=True, is_training=is_training, reuse=reuse)
    with slim.arg_scope(
            resnet_v1.resnet_arg_scope(
                weight_decay=0.0001, batch_norm_scale=True)):
        # Honour the caller's is_training (it used to be hard-coded to True,
        # which kept the head's batch norm in training mode for inference).
        with slim.arg_scope([slim.batch_norm], is_training=is_training):
            feature_map = end_points['resnet_v1_50/block3']
            feature_map = slim.conv2d(
                feature_map,
                512,
                1,
                rate=1,
                activation_fn=tf.nn.relu,
                scope='conv1',
                reuse=reuse)
            # Global average pool over H and W, then restore 1x1 spatial dims
            # so the classifier can be expressed as a 1x1 convolution.
            feature_map = tf.reduce_mean(feature_map, [1, 2])
            feature_map = tf.expand_dims(tf.expand_dims(feature_map, 1), 2)
            logits = slim.conv2d(
                feature_map,
                num_classes, [1, 1],
                activation_fn=None,
                normalizer_fn=None,
                scope='logits',
                reuse=reuse)
            logits = tf.squeeze(logits, [1, 2], name='spatial_squeeze')
    return logits
# Local train directory: one sub-directory per class (see list_images).
train_data = "/home/soma03/projects/data/stanford_clean/train"

# Set the parameters here
batch_size = 128
num_preprocess_threads = 32
learning_rate = 0.00005
epochs = 1000

# Get the list of training data (this script has no validation split).
convert = False
train_filenames, train_labels = list_images(train_data, convert)
num_classes = len(set(train_labels))
print("there are " + str(num_classes) + " classes")
# Get the number of data in total for batch calculation later
num_train_data = len(train_filenames)

# Set up the training data pipeline.
# Shuffle the cheap (filename, label) pairs *before* decoding: shuffling
# after map() keeps up to buffer_size decoded float32 images in memory.
# A buffer as large as the dataset gives a full shuffle on every pass.
train_dataset = tf.data.Dataset.from_tensor_slices(
    (train_filenames, train_labels))
train_dataset = train_dataset.shuffle(buffer_size=max(num_train_data, 1))
train_dataset = train_dataset.map(
    _parse_function, num_parallel_calls=num_preprocess_threads)
# NOTE(review): batch() requires every decoded image to share one HxW and
# _parse_function does not resize, so this assumes the images on disk are
# already a uniform size - confirm.
# Prefetch whole batches at the end so decoding overlaps training steps.
batched_train_dataset = train_dataset.batch(batch_size).prefetch(1)

# Set up a reinitializable iterator over the batched dataset structure.
iterator = tf.data.Iterator.from_structure(batched_train_dataset.output_types,
                                           batched_train_dataset.output_shapes)
# Build the model
images, labels = iterator.get_next()
with slim.arg_scope(resnet_v1.resnet_arg_scope(use_batch_norm=True)):
    logits = build_model(images, num_classes)

# Add loss function. sparse_softmax_cross_entropy registers the loss in the
# losses collection; get_total_loss() sums it with any regularization losses
# (such as the weight decay configured through resnet_arg_scope).
tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
loss = tf.losses.get_total_loss()

# Variables restored from the pretrained checkpoint: every variable whose
# name contains 'resnet' (presumably the whole resnet_v1_50 backbone). Both
# lists are taken before minimize(), so optimizer slot variables are excluded.
restore_var = [v for v in tf.global_variables() if 'resnet' in v.name]
# The newly added head. Pass var_list=train_variables to minimize() to
# fine-tune only the head while keeping the backbone frozen.
train_variables = [
    v for v in tf.global_variables() if 'resnet' not in v.name]

train_init_op = iterator.make_initializer(batched_train_dataset)
att = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)
# Slim batch norm (is_training=True) adds its moving-mean/variance updates
# to UPDATE_OPS; they only execute if the train op depends on them.
# Without this the moving statistics never change from their initial values.
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    output = att.minimize(loss)

# Evaluation metrics
prediction = tf.to_int32(tf.argmax(logits, 1))
correct_prediction = tf.equal(prediction, labels)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

init_op = tf.global_variables_initializer()
# saver restores only restore_var; saver2 checkpoints every variable.
saver = tf.train.Saver(restore_var)
saver2 = tf.train.Saver()
# Ceiling division: number of (possibly partial) batches per epoch.
# int(N / B) + 1 over-counted by one whenever B divides N.
num_batches = (num_train_data + batch_size - 1) // batch_size
with tf.Session() as sess:
    sess.run(init_op)
    # Overwrite the freshly initialised variables in restore_var with the
    # pretrained checkpoint; all other variables keep their initialisation.
    # You can get resnet_v1_50.ckpt /home/projects/ys/codes/delf+knn/
    saver.restore(sess, "resnet_v1_50.ckpt")
    print("weights loaded")
    print("there are " + str(num_batches) + " batches")

    checkpoint_prefix = "my_weights/trained_resnet_model"
    # tf.train.Saver.save refuses to write when the parent directory is
    # missing, so create it up front (os.makedirs without exist_ok keeps
    # this Python 2 compatible, matching the __future__ imports).
    checkpoint_dir = os.path.dirname(checkpoint_prefix)
    if checkpoint_dir and not os.path.isdir(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    for epoch in range(epochs):
        # Re-initialising the iterator starts a fresh pass over the data.
        sess.run(train_init_op)
        print('Starting epoch %d / %d' % (epoch + 1, epochs))
        t = time.time()
        acc_at_each_epoch = []
        batch = 0
        # Run until the dataset is exhausted rather than for a precomputed
        # count, so a miscounted num_batches can never raise OutOfRangeError.
        while True:
            try:
                _, acc, batch_loss = sess.run([output, accuracy, loss])
            except tf.errors.OutOfRangeError:
                break
            acc_at_each_epoch.append(acc)
            if batch % 100 == 0:
                print("At batch " + str(batch))
                accumated_acc = sum(acc_at_each_epoch) / \
                    float(len(acc_at_each_epoch))
                print("accuracy accumalated: " + str(accumated_acc))
                print("loss at this batch: " + str(batch_loss))
                elapsed = time.time() - t
                print("it takes " + str(elapsed)
                      + " seconds to train this 100 batches")
                t = time.time()
                print("==========================================================")
            batch += 1
        print("epoch: " + str(epoch) + " is done!")
        # This script has no validation split (val_acc was never defined and
        # raised NameError here), so report the epoch's mean training accuracy.
        if acc_at_each_epoch:
            train_acc = sum(acc_at_each_epoch) / float(len(acc_at_each_epoch))
            print("training acc is: " + str(train_acc) + "\n")
        saver2.save(sess, checkpoint_prefix)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment