Example TensorFlow script for fine-tuning a VGG model (uses tf.contrib.data)
"""
Example TensorFlow script for fine-tuning a VGG model on your own data.
Uses the tf.contrib.data module, which is available in release v1.2.
Based on the PyTorch example from Justin Johnson
(https://gist.github.com/jcjohnson/6e41e8512c17eae5da50aebef3378a4c)
Required packages: tensorflow (v1.2)
Download the weights trained on ImageNet for VGG:
```
wget http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz
tar -xvf vgg_16_2016_08_28.tar.gz
rm vgg_16_2016_08_28.tar.gz
```
For this example we will use a tiny dataset of images from the COCO dataset.
We have chosen eight types of animals (bear, bird, cat, dog, giraffe, horse,
sheep, and zebra); for each of these categories we have selected 100 training
images and 25 validation images from the COCO dataset. You can download and
unpack the data (176 MB) by running:
```
wget cs231n.stanford.edu/coco-animals.zip
unzip coco-animals.zip
rm coco-animals.zip
```
The training data is stored on disk; each category has its own folder,
and the images for that category are stored as .jpg files in the category folder.
In other words, the directory structure looks something like this:

coco-animals/
  train/
    bear/
      COCO_train2014_000000005785.jpg
      COCO_train2014_000000015870.jpg
      [...]
    bird/
    cat/
    dog/
    giraffe/
    horse/
    sheep/
    zebra/
  val/
    bear/
    bird/
    cat/
    dog/
    giraffe/
    horse/
    sheep/
    zebra/
"""
import argparse
import os
import tensorflow as tf
import tensorflow.contrib.slim as slim
import tensorflow.contrib.slim.nets
parser = argparse.ArgumentParser()
parser.add_argument('--train_dir', default='coco-animals/train')
parser.add_argument('--val_dir', default='coco-animals/val')
parser.add_argument('--model_path', default='vgg_16.ckpt', type=str)
parser.add_argument('--batch_size', default=32, type=int)
parser.add_argument('--num_workers', default=4, type=int)
parser.add_argument('--num_epochs1', default=10, type=int)
parser.add_argument('--num_epochs2', default=10, type=int)
parser.add_argument('--learning_rate1', default=1e-3, type=float)
parser.add_argument('--learning_rate2', default=1e-5, type=float)
parser.add_argument('--dropout_keep_prob', default=0.5, type=float)
parser.add_argument('--weight_decay', default=5e-4, type=float)
VGG_MEAN = [123.68, 116.78, 103.94]


def list_images(directory):
    """
    Get all the images and labels in directory/label/*.jpg
    """
    labels = os.listdir(directory)
    # Sort the labels so that training and validation get them in the same order
    labels.sort()

    files_and_labels = []
    for label in labels:
        for f in os.listdir(os.path.join(directory, label)):
            files_and_labels.append((os.path.join(directory, label, f), label))

    filenames, labels = zip(*files_and_labels)
    filenames = list(filenames)
    labels = list(labels)
    # Sort the unique labels too: the iteration order of a set is arbitrary and can
    # change between runs, so without sorting the same label could get different ints
    unique_labels = sorted(set(labels))

    label_to_int = {}
    for i, label in enumerate(unique_labels):
        label_to_int[label] = i

    labels = [label_to_int[l] for l in labels]

    return filenames, labels


def check_accuracy(sess, correct_prediction, is_training, dataset_init_op):
    """
    Check the accuracy of the model on either train or val (depending on dataset_init_op).
    """
    # Initialize the correct dataset
    sess.run(dataset_init_op)
    num_correct, num_samples = 0, 0
    while True:
        try:
            correct_pred = sess.run(correct_prediction, {is_training: False})
            num_correct += correct_pred.sum()
            num_samples += correct_pred.shape[0]
        except tf.errors.OutOfRangeError:
            break

    # Return the fraction of datapoints that were correctly classified
    acc = float(num_correct) / num_samples
    return acc


def main(args):
    # Get the list of filenames and corresponding list of labels for training and validation
    train_filenames, train_labels = list_images(args.train_dir)
    val_filenames, val_labels = list_images(args.val_dir)
    # Note: labels are already integers here, so this only checks that both
    # splits have the same number of classes
    assert set(train_labels) == set(val_labels),\
        "Train and val labels don't correspond:\n{}\n{}".format(set(train_labels),
                                                                set(val_labels))

    num_classes = len(set(train_labels))

    # --------------------------------------------------------------------------
    # In TensorFlow, you first want to define the computation graph with all the
    # necessary operations: loss, training op, accuracy...
    # Any tensor created in the `graph.as_default()` scope will be part of `graph`
    graph = tf.Graph()
    with graph.as_default():
        # Standard preprocessing for VGG on ImageNet taken from here:
        # https://github.com/tensorflow/models/blob/master/research/slim/preprocessing/vgg_preprocessing.py
        # Also see the VGG paper for more details: https://arxiv.org/pdf/1409.1556.pdf

        # Preprocessing (for both training and validation):
        # (1) Decode the image from jpg format
        # (2) Resize the image so its smaller side is 256 pixels long
        def _parse_function(filename, label):
            image_string = tf.read_file(filename)
            image_decoded = tf.image.decode_jpeg(image_string, channels=3)  # (1)
            image = tf.cast(image_decoded, tf.float32)

            smallest_side = 256.0
            height, width = tf.shape(image)[0], tf.shape(image)[1]
            height = tf.to_float(height)
            width = tf.to_float(width)

            scale = tf.cond(tf.greater(height, width),
                            lambda: smallest_side / width,
                            lambda: smallest_side / height)
            new_height = tf.to_int32(height * scale)
            new_width = tf.to_int32(width * scale)

            resized_image = tf.image.resize_images(image, [new_height, new_width])  # (2)
            return resized_image, label

        # Preprocessing (for training)
        # (3) Take a random 224x224 crop of the scaled image
        # (4) Horizontally flip the image with probability 1/2
        # (5) Subtract the per-color mean `VGG_MEAN`
        # Note: we don't normalize the data here, as VGG was trained without normalization
        def training_preprocess(image, label):
            crop_image = tf.random_crop(image, [224, 224, 3])  # (3)
            flip_image = tf.image.random_flip_left_right(crop_image)  # (4)

            means = tf.reshape(tf.constant(VGG_MEAN), [1, 1, 3])
            centered_image = flip_image - means  # (5)

            return centered_image, label

        # Preprocessing (for validation)
        # (3) Take a central 224x224 crop of the scaled image
        # (4) Subtract the per-color mean `VGG_MEAN`
        # Note: we don't normalize the data here, as VGG was trained without normalization
        def val_preprocess(image, label):
            crop_image = tf.image.resize_image_with_crop_or_pad(image, 224, 224)  # (3)

            means = tf.reshape(tf.constant(VGG_MEAN), [1, 1, 3])
            centered_image = crop_image - means  # (4)

            return centered_image, label

        # ----------------------------------------------------------------------
        # DATASET CREATION using tf.contrib.data.Dataset
        # https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/data

        # The tf.contrib.data.Dataset framework uses queues in the background to feed in
        # data to the model.
        # We initialize the dataset with a list of filenames and labels, and then apply
        # the preprocessing functions described above.
        # Behind the scenes, queues will load the filenames, apply the preprocessing
        # in parallel with multiple threads, and then batch the data.

        # Training dataset
        train_dataset = tf.contrib.data.Dataset.from_tensor_slices((train_filenames, train_labels))
        # Don't forget to shuffle! Shuffling the (filename, label) pairs before decoding
        # keeps short strings in the shuffle buffer instead of decoded images.
        train_dataset = train_dataset.shuffle(buffer_size=10000)
        train_dataset = train_dataset.map(_parse_function,
            num_threads=args.num_workers, output_buffer_size=args.batch_size)
        train_dataset = train_dataset.map(training_preprocess,
            num_threads=args.num_workers, output_buffer_size=args.batch_size)
        batched_train_dataset = train_dataset.batch(args.batch_size)

        # Validation dataset
        val_dataset = tf.contrib.data.Dataset.from_tensor_slices((val_filenames, val_labels))
        val_dataset = val_dataset.map(_parse_function,
            num_threads=args.num_workers, output_buffer_size=args.batch_size)
        val_dataset = val_dataset.map(val_preprocess,
            num_threads=args.num_workers, output_buffer_size=args.batch_size)
        batched_val_dataset = val_dataset.batch(args.batch_size)

        # Now we define an iterator that can operate on either dataset.
        # The iterator can be reinitialized by calling:
        #     - sess.run(train_init_op) for 1 epoch on the training set
        #     - sess.run(val_init_op)   for 1 epoch on the validation set
        # Once this is done, we don't need to feed any value for images and labels
        # as they are automatically pulled out from the iterator queues.

        # A reinitializable iterator is defined by its structure. We could use the
        # `output_types` and `output_shapes` properties of either `batched_train_dataset`
        # or `batched_val_dataset` here, because they are compatible.
        iterator = tf.contrib.data.Iterator.from_structure(batched_train_dataset.output_types,
                                                           batched_train_dataset.output_shapes)
        images, labels = iterator.get_next()

        train_init_op = iterator.make_initializer(batched_train_dataset)
        val_init_op = iterator.make_initializer(batched_val_dataset)

        # Indicates whether we are in training or in test mode
        is_training = tf.placeholder(tf.bool)

        # ---------------------------------------------------------------------
        # Now that we have set up the data, it's time to set up the model.
        # For this example, we'll use VGG-16 pretrained on ImageNet. We will remove the
        # last fully connected layer (fc8) and replace it with our own, with an
        # output size of num_classes (8 for the COCO animals subset).
        # We will first train the last layer for a few epochs.
        # Then we will train the entire model on our dataset for a few epochs.

        # Get the pretrained model, specifying the num_classes argument to create a new
        # fully connected layer replacing the last one, called "vgg_16/fc8".
        # Each model has a different architecture, so "vgg_16/fc8" will change in another model.
        # Here, logits gives us directly the predicted scores we wanted from the images.
        # `vgg_arg_scope` applies L2 weight decay (`args.weight_decay`) to the layer weights.
        vgg = tf.contrib.slim.nets.vgg
        with slim.arg_scope(vgg.vgg_arg_scope(weight_decay=args.weight_decay)):
            logits, _ = vgg.vgg_16(images, num_classes=num_classes, is_training=is_training,
                                   dropout_keep_prob=args.dropout_keep_prob)

        # Specify where the model checkpoint is (pretrained weights).
        model_path = args.model_path
        assert os.path.isfile(model_path)

        # Restore only the layers up to fc7 (included)
        # Calling function `init_fn(sess)` will load all the pretrained weights except fc8.
        variables_to_restore = tf.contrib.framework.get_variables_to_restore(exclude=['vgg_16/fc8'])
        init_fn = tf.contrib.framework.assign_from_checkpoint_fn(model_path, variables_to_restore)

        # Initialization operation from scratch for the new "fc8" layers
        # `get_variables` will only return the variables whose name starts with the given pattern
        fc8_variables = tf.contrib.framework.get_variables('vgg_16/fc8')
        fc8_init = tf.variables_initializer(fc8_variables)

        # ---------------------------------------------------------------------
        # Using tf.losses, any loss is added to the tf.GraphKeys.LOSSES collection.
        # `tf.losses.get_total_loss()` then returns their sum plus the regularization
        # losses (here, the L2 weight decay).
        tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
        loss = tf.losses.get_total_loss()

        # First we want to train only the reinitialized last layer fc8 for a few epochs.
        # We minimize the loss only with respect to the fc8 variables (weights and biases).
        fc8_optimizer = tf.train.GradientDescentOptimizer(args.learning_rate1)
        fc8_train_op = fc8_optimizer.minimize(loss, var_list=fc8_variables)

        # Then we want to fine-tune the entire model for a few epochs.
        # We minimize the loss with respect to all the variables.
        full_optimizer = tf.train.GradientDescentOptimizer(args.learning_rate2)
        full_train_op = full_optimizer.minimize(loss)

        # Evaluation metrics
        prediction = tf.to_int32(tf.argmax(logits, 1))
        correct_prediction = tf.equal(prediction, labels)
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        tf.get_default_graph().finalize()

    # --------------------------------------------------------------------------
    # Now that we have built the graph and finalized it, we define the session.
    # The session is the interface to *run* the computational graph.
    # We can call our training operations with `sess.run(train_op)` for instance
    with tf.Session(graph=graph) as sess:
        init_fn(sess)  # load the pretrained weights
        sess.run(fc8_init)  # initialize the new fc8 layer

        # Update only the last layer for a few epochs.
        for epoch in range(args.num_epochs1):
            # Run an epoch over the training data.
            print('Starting epoch %d / %d' % (epoch + 1, args.num_epochs1))
            # Here we initialize the iterator with the training set.
            # This means that we can go through an entire epoch until the iterator becomes empty.
            sess.run(train_init_op)
            while True:
                try:
                    _ = sess.run(fc8_train_op, {is_training: True})
                except tf.errors.OutOfRangeError:
                    break

            # Check accuracy on the train and val sets every epoch.
            train_acc = check_accuracy(sess, correct_prediction, is_training, train_init_op)
            val_acc = check_accuracy(sess, correct_prediction, is_training, val_init_op)
            print('Train accuracy: %f' % train_acc)
            print('Val accuracy: %f\n' % val_acc)

        # Train the entire model for a few more epochs, continuing with the *same* weights.
        for epoch in range(args.num_epochs2):
            print('Starting epoch %d / %d' % (epoch + 1, args.num_epochs2))
            sess.run(train_init_op)
            while True:
                try:
                    _ = sess.run(full_train_op, {is_training: True})
                except tf.errors.OutOfRangeError:
                    break

            # Check accuracy on the train and val sets every epoch
            train_acc = check_accuracy(sess, correct_prediction, is_training, train_init_op)
            val_acc = check_accuracy(sess, correct_prediction, is_training, val_init_op)
            print('Train accuracy: %f' % train_acc)
            print('Val accuracy: %f\n' % val_acc)


if __name__ == '__main__':
    args = parser.parse_args()
    main(args)
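The resize and crop arithmetic in `_parse_function` and `val_preprocess` can be checked by hand. The sketch below is plain Python, not TensorFlow: the helper names `resized_shape` and `center_crop_offsets` are hypothetical, truncation to int mimics `tf.to_int32`, and the centered crop assumes `tf.image.resize_image_with_crop_or_pad` splits the excess evenly using floor division. TensorFlow computes the scale in float32, so for some image sizes its result may differ from this float64 version by one pixel.

```python
def resized_shape(height, width, smallest_side=256.0):
    """Return (new_height, new_width) so that the smaller side equals smallest_side.

    Hypothetical helper mirroring the scale computation in `_parse_function`.
    """
    if height > width:
        scale = smallest_side / width
    else:
        scale = smallest_side / height
    # tf.to_int32 truncates toward zero, like int() on a positive float
    return int(height * scale), int(width * scale)


def center_crop_offsets(height, width, target=224):
    """Return (top, left) of a centered target x target crop.

    Assumes the image is at least target x target, which holds after the
    resize above because its smaller side is 256.
    """
    return (height - target) // 2, (width - target) // 2


# A landscape image: height 512, width 1024
h, w = resized_shape(512, 1024)
print(h, w)                       # 256 512
print(center_crop_offsets(h, w))  # (16, 144)
```

The training path skips the offset computation: `tf.random_crop` picks a random 224x224 window inside the same resized image instead of the central one.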
@tianzq

commented Sep 9, 2017

It's my first time using TensorFlow. Could you provide some deployment code for testing? Thanks in advance.

@IgorMihajlovic

commented Sep 15, 2017

What accuracy should we expect? With these hyperparameters the second phase doesn't seem to have much influence, as both train and val accuracy remain around 85%.

@eggie5

commented Sep 22, 2017

On line 133 you note "# Standard preprocessing for VGG on ImageNet taken from here: https://github.com/tensorflow/models/blob/master/slim/preprocessing/vgg_preprocessing.py"

Why do you implement this yourself instead of just using vgg_preprocessing.preprocess_image from the aforementioned vgg_preprocessing.py?

@simo23

commented Sep 28, 2017

Hi omoindrot, thanks for this very useful code!

I noticed that this code is quite fast during the training steps but gets very slow in the check_accuracy function. For this reason, I instead evaluate the accuracy operation on the training batch that is fed to the optimizer, to plot the training accuracy during iterations.

My question: is there a way to compute the accuracy on the validation dataset inside the same training loop? In other words, can we evaluate correct_prediction on a batch coming from the validation dataset iterator, just like we do for the training dataset, while still inside the training loop? It would be great to have its value (sampled on a batch) at each iteration rather than at each epoch, to see how things are going.

I think the issue is using two iterators at the same time to evaluate the same operation (correct_prediction). If this were possible, it would be much easier to monitor validation accuracy and overfitting during training.

Another question: what is the correct way to use an optimizer other than SGD? I tried this script with SGD+momentum, but when switching between the two phases the training goes crazy. Do you know why?

Thanks, Andrea

@howardyclo

commented Oct 1, 2017

On line 322, it should be print('Starting epoch %d / %d' % (epoch + 1, args.num_epochs2)) <---- ("epoch1" to "epoch2")

@omoindrot

Owner Author

commented Oct 23, 2017

@IgorMihajlovic: I didn't try a lot of hyperparameters. In this case, because ImageNet and the small animal dataset we use are very close, fine-tuning might not be very useful, hence the low gain in accuracy. If the new dataset were very different (e.g. medical images), fine-tuning might give a bigger boost in accuracy.

@eggie5: I wanted the code to be self-contained, whereas the vgg_preprocessing.py file already has more than 300 lines. The goal was also to show some techniques from tf.contrib.data.

@simo23: The check_accuracy is super slow because we iterate over the whole training set + validation set at each epoch, which might not be necessary. I agree that using the predictions from training to have a running average of the training accuracy is better. For the validation accuracy, I'm not sure you can have both datasets running at the same time. One idea could be to evaluate the validation accuracy more often (ex: 10 times each epoch) and on a smaller subset of the validation set.
If you use momentum, you might need to re-initialize the local variables before fine-tuning? Maybe keeping the cached momentum values from step 1 makes the beginning of step 2 go crazy.

@YuChunLOL: thanks, I corrected it.
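Following the suggestion above, a running training accuracy can be accumulated from the `correct_prediction` values already computed for each training batch (for example by fetching `[fc8_train_op, correct_prediction]` in the same `sess.run` call), instead of re-running `check_accuracy` over the whole training set. Below is a minimal sketch; the `RunningAccuracy` class is hypothetical and assumes each update receives the per-example booleans for one batch. Like `check_accuracy`, it sums counts rather than averaging per-batch accuracies, so the smaller final batch of an epoch is weighted correctly.

```python
class RunningAccuracy:
    """Accumulates accuracy over batches of per-example correctness flags."""

    def __init__(self):
        self.num_correct = 0
        self.num_samples = 0

    def update(self, correct_flags):
        # Count examples, not batches, so a smaller last batch is weighted fairly
        flags = list(correct_flags)
        self.num_correct += sum(1 for c in flags if c)
        self.num_samples += len(flags)

    @property
    def value(self):
        if self.num_samples == 0:
            return 0.0
        return self.num_correct / self.num_samples

    def reset(self):
        self.num_correct = 0
        self.num_samples = 0


# Made-up batches: two batches of 4 and a final batch of 2
acc = RunningAccuracy()
acc.update([True, True, False, True])  # 3/4 correct
acc.update([True, True, True, True])   # 4/4 correct
acc.update([False, False])             # 0/2 correct
print(acc.value)  # 0.7 (7 / 10); the mean of per-batch accuracies would be ~0.583
```

These predictions are made with dropout active (`is_training: True`) while the weights keep changing during the epoch, so this is a rough training signal, not a replacement for `check_accuracy`.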

@jinangela

commented Nov 7, 2017

Hi omoindrot, thanks for sharing this piece of code, I found it very helpful!
I have one little question though: when I tried to print out the labels for the train set and val set, I found that the order of the labels is different. In my case, I have five categories A, B, C, D, E, and the train labels are A:3, B:1, C:2, D:0, E:4, while the val labels are A:1, B:0, C:3, D:4, E:2. Is this expected?
Btw, I found this problem when I tried to plot some predictions: the performance was very poor on my test dataset (around 20% accuracy) even though my validation accuracy was around 75%. After printing the label mappings, I realized that with the train labels the test accuracy is close to 80%, but with the val labels it is only about 20%.

@jfilter

commented Nov 7, 2017

The script doesn't save the learned model to disk, right? Wouldn't it make sense to add it?

@munir01

commented Nov 8, 2017

Hello @omoindrot, Your code is very helpful! Thanks a lot for sharing it!
I agree with @jfilter that saving and restoring the model would be useful as well.

I added this code to save the model:
saver = tf.train.Saver()

And then after the training is over:
saver.save(sess, 'abc-model')

It does save the model. But when I restore the model in a separate script, I am having trouble setting up iterators to feed test data to the restored model. If you could add the functionality of saving the model to your code, plus another script that restores it to use the model on a different dataset, that would be extremely helpful! Thanks!
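Whatever restore approach is used, the script that loads the model also needs the exact label-to-int mapping used during training (see the label-ordering issue reported above by @jinangela). One option, sketched below, is to write the class names in label order to a JSON file next to the checkpoint. The file name `labels.json` and both helpers are hypothetical, and the sketch assumes labels were assigned in sorted folder order, which requires sorting `unique_labels` in `list_images` (as @sandareka points out further down).

```python
import json
import os
import tempfile


def save_class_names(class_names, path):
    """Write class names in label-index order (label i -> class_names[i])."""
    with open(path, "w") as f:
        json.dump(list(class_names), f)


def load_label_maps(path):
    """Return (label_to_int, int_to_label) dictionaries from a saved list."""
    with open(path) as f:
        class_names = json.load(f)
    label_to_int = {name: i for i, name in enumerate(class_names)}
    int_to_label = dict(enumerate(class_names))
    return label_to_int, int_to_label


# In the training script: save the sorted folder names used for the labels
classes = sorted(["zebra", "bear", "cat", "dog", "bird", "giraffe", "horse", "sheep"])
path = os.path.join(tempfile.mkdtemp(), "labels.json")
save_class_names(classes, path)

# In the restore script: rebuild both directions of the mapping
label_to_int, int_to_label = load_label_maps(path)
print(label_to_int["bear"], int_to_label[7])  # 0 zebra
```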

@munir01

commented Nov 17, 2017

It is fixed. Never mind.

@xiaofanglegoc

commented Nov 30, 2017

@munir01: Hi,
I have also encountered the same problem. When I restore the saved model, the last layer's name is not in the restored weights.
Could you please share how you solved the problem?

@eggie5

commented Dec 11, 2017

@omoindrot curious if you've had luck running this on the inception slim implementation?

@omoindrot

Owner Author

commented Dec 23, 2017

@jinangela: fixed it with labels.sort()
@eggie5: I didn't try but it should work in the same way

@tensorAI

commented Jan 8, 2018

Is there any way to implement an equivalent of allow_smaller_batch=True, so that we can get rid of the except clause in the epoch loop?
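If the number of examples is known, one way to avoid catching `tf.errors.OutOfRangeError` is to run a fixed number of steps per epoch. `Dataset.batch()` keeps the final, smaller batch, so one epoch has ceil(N / batch_size) steps. A plain-Python sketch of that arithmetic (the helper names are hypothetical):

```python
import math


def steps_per_epoch(num_examples, batch_size):
    """Number of batches in one pass when the last partial batch is kept."""
    return math.ceil(num_examples / batch_size)


def batch_sizes(num_examples, batch_size):
    """Size of each batch in one epoch, e.g. to sanity-check the last one."""
    full, rest = divmod(num_examples, batch_size)
    return [batch_size] * full + ([rest] if rest else [])


# COCO animals subset from the docstring: 8 classes x (100 train, 25 val) images
print(steps_per_epoch(800, 32))  # 25
print(steps_per_epoch(200, 32))  # 7
print(batch_sizes(200, 32)[-1])  # 8
```

In the script, the `while True` loop would then become `for _ in range(steps_per_epoch(len(train_filenames), args.batch_size)): sess.run(fc8_train_op, {is_training: True})`, still preceded by `sess.run(train_init_op)`. This relies on the count being exact; if it were one too high, the last `sess.run` would raise the same `OutOfRangeError`.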

@dreamflasher

commented Mar 1, 2018

@munir01 Could you please post your solution?
@xiaofanglegoc did you manage to restore the saved model?

@yangjingyi

commented Apr 4, 2018

I found this
https://stackoverflow.com/questions/38947658/tensorflow-saving-into-loading-a-graph-from-a-file
It says: "To use them in another program, you must re-create the associated graph structure (e.g. by running code to build it again, or calling tf.import_graph_def()), which tells TensorFlow what to do with those weights." I think it may solve @dreamflasher's problem.
But I still haven't found where to add "saver = tf.train.Saver()" and "saver.save(sess, 'abc-model')" to avoid mistakes...

@ZhenzhuZheng

commented Apr 5, 2018

I'm getting this error. Could anyone tell how to fix this?

Traceback (most recent call last):
  File "/Users/erica/Downloads/tensorflow_finetune_org.py", line 341, in <module>
    main(args)
  File "/Users/erica/Downloads/tensorflow_finetune_org.py", line 202, in main
    train_dataset = tf.contrib.data.Dataset.from_tensor_slices((train_filenames, train_labels))
  File "/Users/erica/Workspace/cs231n/env/lib/python3.6/site-packages/tensorflow/contrib/data/python/ops/dataset_ops.py", line 473, in from_tensor_slices
    return TensorSliceDataset(tensors)
  File "/Users/erica/Workspace/cs231n/env/lib/python3.6/site-packages/tensorflow/contrib/data/python/ops/dataset_ops.py", line 896, in __init__
    batch_dim = flat_tensors[0].get_shape()[0]
  File "/Users/erica/Workspace/cs231n/env/lib/python3.6/site-packages/tensorflow/python/framework/tensor_shape.py", line 500, in __getitem__
    return self._dims[key]
IndexError: list index out of range

Process finished with exit code 1
@dreamflasher

commented Apr 7, 2018

@yangjingyi Thanks! I added "saver.save" after printing the accuracy scores, and "saver = " after the creation of the session.

This is how you restore:

sess = tf.Session()
saver = tf.train.import_meta_graph('vgg-retrain.ckpt.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
@rimitalahiri

commented Jun 14, 2018

@ZhenzhuZheng were you able to solve the problem? I am getting the same issue. Please help if you have found a solution.

@hspark84

commented Jul 19, 2018

@rimitalahiri I solved the problem by putting tf.constant operation in front of the train_filenames, train_labels, val_filenames, and val_labels.
It may be wrong, but I can't see the error anymore.

@pradeepkr1303

commented Aug 1, 2018

How do I execute this script? How should I run it from the Windows command prompt? Thanks in advance.
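On any platform, including the Windows command prompt, the script is run with Python from the folder that contains `vgg_16.ckpt` and `coco-animals/` (the default paths are relative), e.g. `python path\to\script.py --batch_size 16`, where `path\to\script.py` is wherever you saved the gist. Any flag left out keeps the default defined with `argparse` at the top of the script. The sketch below re-creates only a few of those flags to show how overrides are parsed; passing an explicit list to `parse_args` stands in for the real command line.

```python
import argparse

# A subset of the script's parser, copied from its argparse definitions
parser = argparse.ArgumentParser()
parser.add_argument('--train_dir', default='coco-animals/train')
parser.add_argument('--val_dir', default='coco-animals/val')
parser.add_argument('--model_path', default='vgg_16.ckpt', type=str)
parser.add_argument('--batch_size', default=32, type=int)
parser.add_argument('--learning_rate1', default=1e-3, type=float)

# Equivalent to: python script.py --batch_size 16 --learning_rate1 5e-4
args = parser.parse_args(['--batch_size', '16', '--learning_rate1', '5e-4'])
print(args.batch_size, args.learning_rate1, args.train_dir)  # 16 0.0005 coco-animals/train
```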

@anilesec

commented Aug 16, 2018

Hello,
I am getting the following error; does anyone have an idea what causes it?

train_dataset = tf.contrib.data.Dataset.from_tensor_slices((train_filenames, train_labels))
AttributeError: module 'tensorflow.contrib.data' has no attribute 'Dataset'

Thank you!

@WASamK

commented Sep 25, 2018

@AnilKumarES That is a tensorflow version problem. Try the version 1.2.
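An alternative to pinning 1.2: from TensorFlow 1.4 on, `Dataset` and `Iterator` are available in core `tf.data`, and `map(..., num_threads=..., output_buffer_size=...)` becomes `map(..., num_parallel_calls=...)` followed by a separate `.prefetch(...)`. The plain-Python sketch below only encodes that version check (it does not import TensorFlow; you would pass it `tf.__version__`). The helper names are hypothetical, and the 1.4 cut-off is the only version fact it assumes. It compares numeric tuples because a plain string comparison would rank '1.10' below '1.4'.

```python
def version_tuple(version):
    """'1.10.0' -> (1, 10, 0); ignores suffixes such as '-rc1'."""
    parts = []
    for piece in version.split('.')[:3]:
        digits = ''
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def dataset_module_name(tf_version):
    """Which namespace holds Dataset/Iterator for a TF 1.x version string.

    Assumption: tf.data is core from 1.4 on; earlier releases (like the 1.2
    this script targets) only have tf.contrib.data.
    """
    if version_tuple(tf_version) >= (1, 4):
        return 'tf.data'
    return 'tf.contrib.data'


print(dataset_module_name('1.2.1'))   # tf.contrib.data
print(dataset_module_name('1.10.0'))  # tf.data
```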

@sujeet-gandhi

commented Oct 11, 2018

@omoindrot Hi, how would you convince someone that training the last layer for a few epochs and then training the entire model for a few epochs is a good fine-tuning approach?

@thanh1985

commented Oct 20, 2018

Hi, has anyone tried to reconfigure the last 3 fc layers, not just the last one?

@thanh1985

commented Oct 20, 2018

@omoindrot do you have any idea about reconfiguring the last 3 fc layers?

@vishalghor

commented Oct 28, 2018

@omoindrot How can I restore only the weights of the first 3 conv layers from the checkpoint file, add new layers after them, and train those new layers? Kindly help me with this.
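The script restores everything except fc8 with `get_variables_to_restore(exclude=['vgg_16/fc8'])`, and that function also accepts an `include` list of scopes. For the two questions above (re-creating the last three fc layers, or keeping only early conv layers), what changes is which scopes go into those lists. The plain-Python sketch below roughly mimics scope-based selection on a hand-written, partial list of VGG-16 variable names that follows slim's `vgg_16/convN/convN_M/weights` naming; the list and the `select_variables` helper are illustrative, not read from the checkpoint. Building new layers on top of conv2_1 would additionally require defining a different network, which this does not cover.

```python
def select_variables(names, include=None, exclude=None):
    """Keep names under any `include` scope, then drop those under `exclude`.

    A scope matches a name if the name starts with the scope followed by '/'.
    """
    def under(name, scopes):
        return any(name.startswith(scope.rstrip('/') + '/') for scope in scopes)

    selected = [n for n in names if include is None or under(n, include)]
    if exclude:
        selected = [n for n in selected if not under(n, exclude)]
    return selected


# Illustrative subset of slim VGG-16 variable names (weights and biases per layer)
layers = ['conv1/conv1_1', 'conv1/conv1_2', 'conv2/conv2_1', 'conv2/conv2_2',
          'conv3/conv3_1', 'conv5/conv5_3', 'fc6', 'fc7', 'fc8']
names = ['vgg_16/%s/%s' % (layer, v) for layer in layers for v in ('weights', 'biases')]

# Re-create fc6, fc7 and fc8: restore everything else
kept = select_variables(names, exclude=['vgg_16/fc6', 'vgg_16/fc7', 'vgg_16/fc8'])
print(len(kept))  # 12

# Keep only the first three conv layers (conv1_1, conv1_2, conv2_1)
early = select_variables(names, include=['vgg_16/conv1', 'vgg_16/conv2/conv2_1'])
print(early[-1])  # vgg_16/conv2/conv2_1/biases
```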

@sandareka

commented Dec 18, 2018

@omoindrot Thanks a lot for your script. I think you should sort the 'unique_labels' list as well, to make sure that the training and validation datasets use the same order of labels (the same integer for a given label).
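A sketch of `list_images` along those lines: the class folders are sorted before integers are assigned, so train, val and any later test script derive the same mapping from the folder names. This hypothetical variant also skips hidden entries and non-directories (such as a macOS `.DS_Store` file), keeps only `.jpg`/`.jpeg` files, and returns the class names as a third value; none of these extras are in the original script. Returning the names also allows a stricter check than the script's assert, which compares integer sets and so cannot tell two splits with different folder names apart.

```python
import os
import tempfile


def list_images(directory, extensions=('.jpg', '.jpeg')):
    """Return (filenames, int_labels, class_names) for directory/<class>/<image>.

    Class folders are sorted, so class_names[i] is the class with label i.
    """
    class_names = sorted(
        d for d in os.listdir(directory)
        if not d.startswith('.') and os.path.isdir(os.path.join(directory, d)))
    label_to_int = {name: i for i, name in enumerate(class_names)}

    filenames, labels = [], []
    for name in class_names:
        class_dir = os.path.join(directory, name)
        for f in sorted(os.listdir(class_dir)):
            if f.lower().endswith(extensions):
                filenames.append(os.path.join(class_dir, f))
                labels.append(label_to_int[name])
    return filenames, labels, class_names


# Usage on a throwaway tree: two classes, created in non-alphabetical order
root = tempfile.mkdtemp()
for cls, img in [('zebra', 'z1.jpg'), ('bear', 'b1.jpg'), ('bear', 'b2.JPG')]:
    os.makedirs(os.path.join(root, cls), exist_ok=True)
    open(os.path.join(root, cls, img), 'w').close()
open(os.path.join(root, '.DS_Store'), 'w').close()  # stray file, ignored

files, labels, classes = list_images(root)
print(classes, labels)  # ['bear', 'zebra'] [0, 0, 1]
```

With this variant, `assert train_classes == val_classes` (comparing the returned name lists) fails whenever the two splits do not contain exactly the same folders.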

@HumberMe

commented Dec 29, 2018

Hi, does anyone have the problem "ResourceExhaustedError: OOM when allocating tensor with shape"? When I reduce the batch size, it sometimes works.
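Reducing the batch size helps because the memory needed for activations grows linearly with it. Below is a rough plain-Python estimate that counts only the float32 outputs of VGG-16's thirteen 3x3 conv layers ('SAME' padding) and five 2x2 max-pool layers in the forward pass for 224x224 inputs. Real usage is noticeably higher: the backward pass, the weights themselves (over 130M parameters), their gradients and TensorFlow's own buffers all add to it, so treat these numbers as a lower bound rather than a prediction.

```python
# VGG-16 conv blocks: (number of 3x3 conv layers, output channels).
# Each block is followed by a 2x2 max-pool that halves height and width.
VGG16_BLOCKS = [(2, 64), (2, 128), (3, 256), (3, 512), (3, 512)]


def activation_floats_per_image(size=224):
    """Count conv + pool output values for one size x size image (forward only)."""
    total = 0
    for num_convs, channels in VGG16_BLOCKS:
        total += num_convs * size * size * channels  # conv outputs keep the spatial size
        size //= 2
        total += size * size * channels              # max-pool output
    return total


def activation_megabytes(batch_size, bytes_per_float=4):
    """Forward-pass conv/pool activation memory for a batch, in MiB."""
    return batch_size * activation_floats_per_image() * bytes_per_float / 2**20


print(activation_floats_per_image())    # 15077888
print(round(activation_megabytes(32)))  # 1841
print(round(activation_megabytes(8)))   # 460
```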

@shauryad15

commented Mar 6, 2019

@omoindrot can you write a script to check the model's accuracy on test images downloaded from the internet? The input and output layers of the graph are ambiguous.

@nael74

commented May 6, 2019

Hello,
has anyone succeeded in making a prediction on a single image after restoring the model?
If so, could someone help me? I am stuck.
Thank you!
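For a single image, the preprocessing must match `val_preprocess` (resize so the smaller side is 256, central 224x224 crop, subtract `VGG_MEAN`), and the image needs a batch dimension, since the network expects inputs of shape [batch, 224, 224, 3]. Once the logits for that one image are fetched, turning them into class names only needs the class list in label order. Below is a plain-Python sketch of that last step; the logits are made up, `top_k` is a hypothetical helper, and the class list assumes labels were assigned in sorted folder order.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits, class_names, k=3):
    """Return the k most likely (class_name, probability) pairs."""
    probs = softmax(logits)
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(class_names[i], probs[i]) for i in order[:k]]


CLASS_NAMES = ['bear', 'bird', 'cat', 'dog', 'giraffe', 'horse', 'sheep', 'zebra']
logits = [0.1, -1.2, 2.3, 4.0, -0.5, 0.0, 1.1, -2.0]  # made-up output for one image

for name, p in top_k(logits, CLASS_NAMES):
    print('%s: %.3f' % (name, p))
```

Subtracting the maximum logit before exponentiating leaves the probabilities unchanged but prevents overflow for large logits.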
