Skip to content

Instantly share code, notes, and snippets.

@d3rezz
Last active June 7, 2019 16:39
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save d3rezz/7ad3e2b364e3460fca4de155c4a01db4 to your computer and use it in GitHub Desktop.
Save d3rezz/7ad3e2b364e3460fca4de155c4a01db4 to your computer and use it in GitHub Desktop.
Finetuning a tensorflow slim model (Resnet v1 50) with a dataset in TFRecord format
# Finetune a tensorflow slim model (Resnet v1 50) on the flowers dataset in TFRecord format
# TFRecord files created using the script from https://github.com/kwotsin/create_tfrecords
# Training done in a Keras-like way, where training and validation accuracy are computed every epoch
# Download resnet checkpoint from: http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz
import tensorflow as tf
import numpy as np
import os
import glob
from tqdm import tqdm
from tensorflow.contrib.slim.nets import resnet_v1
from models.research.slim.preprocessing import vgg_preprocessing
from tensorflow.contrib import slim
# constants
DATASET_DIR = "data/flowers/"  # root folder holding the TFRecord shards and labels.txt
_FILE_PATTERN = 'flowers_%s_*.tfrecord'  # %s is the split name: "train" or "validation"
batch_size = 16
num_epochs = 1
# Input side length expected by the network (slim's resnet_v1 default).
image_size = resnet_v1.resnet_v1.default_image_size
# Returns number of lines in a text file.
# Used to determine the number of classes from the labels.txt file.
def file_len(filename):
    """Return the number of lines in the text file at ``filename``."""
    # Context manager closes the handle deterministically; the original
    # left the file open until garbage collection.
    with open(filename) as f:
        return sum(1 for _ in f)
# Count how many serialized examples are stored across a list of TFRecord files.
def get_num_records_tfecords(filenames):
    """Return the total number of records contained in ``filenames``."""
    total = 0
    for path in filenames:
        # tf_record_iterator yields every serialized record in the file;
        # we only need the count, not the payloads.
        total += sum(1 for _ in tf.python_io.tf_record_iterator(path))
    return total
def _parse_function(example_proto):
    """Parse one serialized tf.Example into a preprocessed image and label.

    Args:
        example_proto: scalar string tensor holding a serialized tf.Example
            written by the create_tfrecords script.

    Returns:
        (image, label): a preprocessed image of shape
        [image_size, image_size, 3] and an int32 class label.
    """
    features = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
        'image/class/label': tf.FixedLenFeature([], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
        'image/height': tf.FixedLenFeature((), tf.int64),
        'image/width': tf.FixedLenFeature((), tf.int64),
    }
    parsed_features = tf.parse_single_example(example_proto, features)
    # Force 3 output channels: without channels=3, decode_jpeg keeps the
    # source channel count, so a grayscale JPEG would produce [h, w, 1]
    # and make the reshape below fail at runtime.
    image = tf.image.decode_jpeg(parsed_features["image/encoded"], channels=3)
    width = tf.cast(parsed_features["image/width"], tf.int32)
    height = tf.cast(parsed_features["image/height"], tf.int32)
    label = tf.cast(parsed_features["image/class/label"], tf.int32)
    # Reshape image data into the original shape
    image = tf.reshape(image, [height, width, 3])
    # Images need the same spatial dimensions before batching; the VGG
    # preprocessing resizes/crops to image_size x image_size.
    image = vgg_preprocessing.preprocess_image(image, image_size, image_size)
    return image, label
# Directory to save summaries to
logdir = "logs/"
os.makedirs(logdir, exist_ok=True)

graph = tf.Graph()
with graph.as_default():
    tf.logging.set_verbosity(tf.logging.INFO)

    # Load datasets
    print("Loading dataset")
    train_filenames = glob.glob(DATASET_DIR+_FILE_PATTERN % ("train"))
    train_dataset = tf.data.TFRecordDataset(train_filenames)
    train_dataset = train_dataset.map(_parse_function)
    train_dataset = train_dataset.shuffle(buffer_size=10000)  # don't forget to shuffle
    batched_train_dataset = train_dataset.batch(batch_size)

    val_filenames = glob.glob(DATASET_DIR+_FILE_PATTERN % ("validation"))
    val_dataset = tf.data.TFRecordDataset(val_filenames)
    val_dataset = val_dataset.map(_parse_function)
    batched_val_dataset = val_dataset.batch(batch_size)

    num_classes = file_len(os.path.join(DATASET_DIR, "labels.txt"))
    num_train_records = get_num_records_tfecords(train_filenames)
    print("Loaded train dataset with %d images belonging to %d classes" % (num_train_records, num_classes))
    # int() so the tqdm progress bar gets an integer total
    # (np.ceil returns a float).
    num_batches = int(np.ceil(num_train_records / batch_size))
    num_val_records = get_num_records_tfecords(val_filenames)
    print("Loaded val dataset with %d images belonging to %d classes" % (num_val_records, num_classes))

    # Reinitializable iterator: a single iterator fed by either the train
    # or validation dataset, switched with the init ops defined below.
    iterator = tf.data.Iterator.from_structure(batched_train_dataset.output_types,
                                               batched_train_dataset.output_shapes)
    images, labels = iterator.get_next()
    train_init_op = iterator.make_initializer(batched_train_dataset)
    val_init_op = iterator.make_initializer(batched_val_dataset)

    with slim.arg_scope(resnet_v1.resnet_arg_scope(weight_decay=0.00001)):
        logits, _ = resnet_v1.resnet_v1_50(images,
                                           num_classes=num_classes,
                                           is_training=True)
    # resnet_v1_50 emits logits as [batch, 1, 1, num_classes] (the global
    # pool keeps 1x1 spatial dims).  Squeeze ONLY the spatial axes: an
    # axis-less tf.squeeze would also drop the batch dimension whenever an
    # epoch's final batch contains a single image, breaking argmax below.
    logits = tf.squeeze(logits, axis=[1, 2])

    # Restore every pretrained variable except the classification head,
    # whose output size differs from the ImageNet checkpoint.
    variables_to_restore = tf.contrib.framework.get_variables_to_restore(exclude=["resnet_v1_50/logits", "resnet_v1_50/AuxLogits"])
    init_fn = tf.contrib.framework.assign_from_checkpoint_fn("resnet_v1_50.ckpt", variables_to_restore)
    logits_variables = tf.contrib.framework.get_variables("resnet_v1_50/logits") + tf.contrib.framework.get_variables("resnet_v1_50/AuxLogits")
    logits_init = tf.variables_initializer(logits_variables)

    # Loss function:
    predictions = tf.to_int32(tf.argmax(logits, 1))
    tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    total_loss = tf.losses.get_total_loss()

    # Snapshot existing variables so the optimizer's slot variables can be
    # isolated (and initialized) after minimize() creates them.
    temp = set(tf.all_variables())
    optimizer = tf.train.AdamOptimizer(learning_rate=0.0001)
    logits_train_op = optimizer.minimize(total_loss, var_list=logits_variables)  # use this op to only train the last layer
    full_train_op = optimizer.minimize(total_loss)  # use this op to train the whole network
    # This needs to come after defining the training ops: Adam's slot
    # variables only exist once minimize() has been called.
    adam_init_op = tf.initialize_variables(set(tf.all_variables()) - temp)

    # Define the metric and update operations (taken from http://ronny.rest/blog/post_2017_09_11_tf_metrics/)
    tf_metric, tf_metric_update = tf.metrics.accuracy(labels, predictions, name="accuracy_metric")
    # Isolate the variables stored behind the scenes by the metric operation
    running_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope="accuracy_metric")
    # Define initializer to initialize/reset running variables
    running_vars_initializer = tf.variables_initializer(var_list=running_vars)
    acc_summary = tf.summary.scalar('accuracy', tf_metric)

    # To save the trained model
    saver = tf.train.Saver()

    # Freeze the graph so later code cannot accidentally add nodes.
    tf.get_default_graph().finalize()
with tf.Session(graph=graph) as sess:
    # Initializations
    init_fn(sess)  # restore pretrained backbone weights from resnet_v1_50.ckpt
    sess.run(logits_init)  # randomly initialize the replaced classification head
    sess.run(adam_init_op)  # initialize the Adam optimizer's slot variables

    print("Writing summaries to %s" % logdir)
    train_writer = tf.summary.FileWriter(os.path.join(logdir,"train/"), sess.graph)
    val_writer = tf.summary.FileWriter(os.path.join(logdir,"valid/"), sess.graph)

    # Training
    for epoch in range(num_epochs):
        print('Starting training epoch %d / %d' % (epoch + 1, num_epochs))
        # initialize the iterator with the training set
        sess.run(train_init_op)
        pbar = tqdm(total=num_batches) #progress bar showing how many batches remain
        # One pass over the training set: the iterator raises
        # OutOfRangeError when the dataset is exhausted.
        while True:
            try:
                # train on one batch of data
                _ = sess.run(full_train_op)
                pbar.update(1)
            except tf.errors.OutOfRangeError:
                break
        pbar.close()

        # Compute training and validation accuracy
        # NOTE(review): the network was built with is_training=True, so
        # batch norm uses batch statistics during these accuracy passes —
        # confirm this is intended.
        sess.run(train_init_op)
        # initialize/reset the accuracy running variables
        sess.run(running_vars_initializer)
        while True:
            try:
                # Each call accumulates counts into the metric's local vars.
                sess.run(tf_metric_update)
            except tf.errors.OutOfRangeError:
                break
        train_acc = sess.run(tf_metric)
        summary = sess.run(acc_summary)
        print('Train accuracy: %f' % train_acc)
        train_writer.add_summary(summary,epoch +1)
        train_writer.flush()

        # Same accuracy pass over the validation split.
        sess.run(val_init_op)
        # initialize/reset the accuracy running variables
        sess.run(running_vars_initializer)
        while True:
            try:
                sess.run(tf_metric_update)
            except tf.errors.OutOfRangeError:
                break
        # Calculate the score
        val_acc = sess.run(tf_metric)
        summary = sess.run(acc_summary)
        print('Val accuracy: %f' % val_acc)
        val_writer.add_summary(summary,epoch +1)
        val_writer.flush()

    # Save model
    saver.save(sess, os.path.join(logdir, "trained_model.ckpt" ))
@michirup
Copy link

michirup commented Dec 20, 2018

After running this code I obtained the new checkpoint (properly composed of .meta .data .index and the checkpoint one). But when I used it to extract layers to retrieve features or predictions I encountered this error:

ValueError: The passed save_path is not a valid checkpoint

or sometimes

Key global_step not found in checkpoint

Any idea about the problem?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment