omoindrot/dedc857cdc0e680dfb1be99762990c9c

"""
Example TensorFlow script for finetuning a VGG model on your own data.
Uses the tf.contrib.data module, which is available in release v1.2.
Based on the PyTorch example from Justin Johnson
(https://gist.github.com/jcjohnson/6e41e8512c17eae5da50aebef3378a4c)

Required packages: tensorflow (v1.2)

Download the weights trained on ImageNet for VGG:
```
wget http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz
tar -xvf vgg_16_2016_08_28.tar.gz
rm vgg_16_2016_08_28.tar.gz
```

For this example we will use a tiny dataset of images from the COCO dataset.
We have chosen eight types of animals (bear, bird, cat, dog, giraffe, horse,
sheep, and zebra); for each of these categories we have selected 100 training
images and 25 validation images from the COCO dataset. You can download and
unpack the data (176 MB) by running:
```
wget cs231n.stanford.edu/coco-animals.zip
unzip coco-animals.zip
rm coco-animals.zip
```

The training data is stored on disk; each category has its own folder on disk
and the images for that category are stored as .jpg files in the category folder.
In other words, the directory structure looks something like this:

coco-animals/
  train/
    bear/
      COCO_train2014_000000005785.jpg
      COCO_train2014_000000015870.jpg
      [...]
    bird/
    cat/
    dog/
    giraffe/
    horse/
    sheep/
    zebra/
  val/
    bear/
    bird/
    cat/
    dog/
    giraffe/
    horse/
    sheep/
    zebra/
"""
import argparse
import os

import tensorflow as tf
import tensorflow.contrib.slim as slim
import tensorflow.contrib.slim.nets


parser = argparse.ArgumentParser()
parser.add_argument('--train_dir', default='coco-animals/train')
parser.add_argument('--val_dir', default='coco-animals/val')
parser.add_argument('--model_path', default='vgg_16.ckpt', type=str)
parser.add_argument('--batch_size', default=32, type=int)
parser.add_argument('--num_workers', default=4, type=int)
parser.add_argument('--num_epochs1', default=10, type=int)
parser.add_argument('--num_epochs2', default=10, type=int)
parser.add_argument('--learning_rate1', default=1e-3, type=float)
parser.add_argument('--learning_rate2', default=1e-5, type=float)
parser.add_argument('--dropout_keep_prob', default=0.5, type=float)
parser.add_argument('--weight_decay', default=5e-4, type=float)

VGG_MEAN = [123.68, 116.78, 103.94]


def list_images(directory):
    """
    Get all the images and labels in directory/label/*.jpg
    """
    labels = os.listdir(directory)
    # Sort the labels so that training and validation get them in the same order
    labels.sort()

    files_and_labels = []
    for label in labels:
        for f in os.listdir(os.path.join(directory, label)):
            files_and_labels.append((os.path.join(directory, label, f), label))

    filenames, labels = zip(*files_and_labels)
    filenames = list(filenames)
    labels = list(labels)
    # A set has no guaranteed iteration order, so sort the unique labels too;
    # otherwise train and val can map the same label name to different integers
    unique_labels = sorted(set(labels))

    label_to_int = {}
    for i, label in enumerate(unique_labels):
        label_to_int[label] = i

    labels = [label_to_int[l] for l in labels]

    return filenames, labels


def check_accuracy(sess, correct_prediction, is_training, dataset_init_op):
    """
    Check the accuracy of the model on either train or val (depending on dataset_init_op).
    """
    # Initialize the correct dataset
    sess.run(dataset_init_op)
    num_correct, num_samples = 0, 0
    while True:
        try:
            correct_pred = sess.run(correct_prediction, {is_training: False})
            num_correct += correct_pred.sum()
            num_samples += correct_pred.shape[0]
        except tf.errors.OutOfRangeError:
            break

    # Return the fraction of datapoints that were correctly classified
    acc = float(num_correct) / num_samples
    return acc


def main(args):
    # Get the list of filenames and corresponding list of labels for training and validation
    train_filenames, train_labels = list_images(args.train_dir)
    val_filenames, val_labels = list_images(args.val_dir)
    assert set(train_labels) == set(val_labels),\
           "Train and val labels don't correspond:\n{}\n{}".format(set(train_labels),
                                                                  set(val_labels))

    num_classes = len(set(train_labels))

    # --------------------------------------------------------------------------
    # In TensorFlow, you first want to define the computation graph with all the
    # necessary operations: loss, training op, accuracy...
    # Any tensor created in the `graph.as_default()` scope will be part of `graph`
    graph = tf.Graph()
    with graph.as_default():
        # Standard preprocessing for VGG on ImageNet taken from here:
        # https://github.com/tensorflow/models/blob/master/research/slim/preprocessing/vgg_preprocessing.py
        # Also see the VGG paper for more details: https://arxiv.org/pdf/1409.1556.pdf

        # Preprocessing (for both training and validation):
        # (1) Decode the image from jpg format
        # (2) Resize the image so its smaller side is 256 pixels long
        def _parse_function(filename, label):
            image_string = tf.read_file(filename)
            image_decoded = tf.image.decode_jpeg(image_string, channels=3)          # (1)
            image = tf.cast(image_decoded, tf.float32)

            smallest_side = 256.0
            height, width = tf.shape(image)[0], tf.shape(image)[1]
            height = tf.to_float(height)
            width = tf.to_float(width)

            scale = tf.cond(tf.greater(height, width),
                            lambda: smallest_side / width,
                            lambda: smallest_side / height)
            new_height = tf.to_int32(height * scale)
            new_width = tf.to_int32(width * scale)

            resized_image = tf.image.resize_images(image, [new_height, new_width])  # (2)
            return resized_image, label

        # Preprocessing (for training)
        # (3) Take a random 224x224 crop of the scaled image
        # (4) Horizontally flip the image with probability 1/2
        # (5) Subtract the per-color mean `VGG_MEAN`
        # Note: we only subtract the mean and don't divide by a standard deviation,
        # as VGG was trained with mean subtraction only
        def training_preprocess(image, label):
            crop_image = tf.random_crop(image, [224, 224, 3])                       # (3)
            flip_image = tf.image.random_flip_left_right(crop_image)                # (4)

            means = tf.reshape(tf.constant(VGG_MEAN), [1, 1, 3])
            centered_image = flip_image - means                                     # (5)

            return centered_image, label

        # Preprocessing (for validation)
        # (3) Take a central 224x224 crop of the scaled image
        # (4) Subtract the per-color mean `VGG_MEAN`
        # Note: we only subtract the mean and don't divide by a standard deviation,
        # as VGG was trained with mean subtraction only
        def val_preprocess(image, label):
            crop_image = tf.image.resize_image_with_crop_or_pad(image, 224, 224)    # (3)

            means = tf.reshape(tf.constant(VGG_MEAN), [1, 1, 3])
            centered_image = crop_image - means                                     # (4)

            return centered_image, label

        # ----------------------------------------------------------------------
        # DATASET CREATION using tf.contrib.data.Dataset
        # https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/data

        # The tf.contrib.data.Dataset framework uses queues in the background to feed in
        # data to the model.
        # We initialize the dataset with a list of filenames and labels, and then apply
        # the preprocessing functions described above.
        # Behind the scenes, queues load the filenames, run the preprocessing in
        # parallel on multiple threads, and then batch the data.

        # Training dataset
        train_dataset = tf.contrib.data.Dataset.from_tensor_slices((train_filenames, train_labels))
        train_dataset = train_dataset.map(_parse_function,
            num_threads=args.num_workers, output_buffer_size=args.batch_size)
        train_dataset = train_dataset.map(training_preprocess,
            num_threads=args.num_workers, output_buffer_size=args.batch_size)
        train_dataset = train_dataset.shuffle(buffer_size=10000)  # don't forget to shuffle
        batched_train_dataset = train_dataset.batch(args.batch_size)

        # Validation dataset
        val_dataset = tf.contrib.data.Dataset.from_tensor_slices((val_filenames, val_labels))
        val_dataset = val_dataset.map(_parse_function,
            num_threads=args.num_workers, output_buffer_size=args.batch_size)
        val_dataset = val_dataset.map(val_preprocess,
            num_threads=args.num_workers, output_buffer_size=args.batch_size)
        batched_val_dataset = val_dataset.batch(args.batch_size)

        # Now we define an iterator that can operate on either dataset.
        # The iterator can be reinitialized by calling:
        #     - sess.run(train_init_op) for 1 epoch on the training set
        #     - sess.run(val_init_op)   for 1 epoch on the validation set
        # Once this is done, we don't need to feed any value for images and labels
        # as they are automatically pulled out from the iterator queues.

        # A reinitializable iterator is defined by its structure. We could use the
        # `output_types` and `output_shapes` properties of either `batched_train_dataset`
        # or `batched_val_dataset` here, because they are compatible.
        iterator = tf.contrib.data.Iterator.from_structure(batched_train_dataset.output_types,
                                                           batched_train_dataset.output_shapes)
        images, labels = iterator.get_next()

        train_init_op = iterator.make_initializer(batched_train_dataset)
        val_init_op = iterator.make_initializer(batched_val_dataset)

        # Indicates whether we are in training or in test mode
        is_training = tf.placeholder(tf.bool)

        # ---------------------------------------------------------------------
        # Now that we have set up the data, it's time to set up the model.
        # For this example, we'll use VGG-16 pretrained on ImageNet. We will remove the
        # last fully connected layer (fc8) and replace it with our own, with an
        # output size num_classes=8
        # We will first train the last layer for a few epochs.
        # Then we will train the entire model on our dataset for a few epochs.

        # Get the pretrained model, specifying the num_classes argument to create a new
        # fully connected layer replacing the last one, called "vgg_16/fc8".
        # Each model has a different architecture, so "vgg_16/fc8" will change in another model.
        # Here, logits gives us directly the predicted scores we wanted from the images.
        # `vgg_arg_scope` attaches an L2 weight-decay regularizer to the weights of every
        # conv and fully connected layer, including the new "vgg_16/fc8".
        vgg = tf.contrib.slim.nets.vgg
        with slim.arg_scope(vgg.vgg_arg_scope(weight_decay=args.weight_decay)):
            logits, _ = vgg.vgg_16(images, num_classes=num_classes, is_training=is_training,
                                   dropout_keep_prob=args.dropout_keep_prob)

        # Specify where the model checkpoint is (pretrained weights).
        model_path = args.model_path
        assert os.path.isfile(model_path)

        # Restore only the layers up to fc7 (included)
        # Calling function `init_fn(sess)` will load all the pretrained weights.
        variables_to_restore = tf.contrib.framework.get_variables_to_restore(exclude=['vgg_16/fc8'])
        init_fn = tf.contrib.framework.assign_from_checkpoint_fn(model_path, variables_to_restore)

        # Initialization operation from scratch for the new "fc8" layers
        # `get_variables` will only return the variables whose name starts with the given pattern
        fc8_variables = tf.contrib.framework.get_variables('vgg_16/fc8')
        fc8_init = tf.variables_initializer(fc8_variables)

        # ---------------------------------------------------------------------
        # Using tf.losses, any loss is added to the tf.GraphKeys.LOSSES collection.
        # `tf.losses.get_total_loss()` then returns the sum of these losses plus the
        # regularization losses (here, the weight decay).
        tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
        loss = tf.losses.get_total_loss()

        # First we want to train only the reinitialized last layer fc8 for a few epochs.
        # We minimize the loss only with respect to the fc8 variables (weight and bias).
        fc8_optimizer = tf.train.GradientDescentOptimizer(args.learning_rate1)
        fc8_train_op = fc8_optimizer.minimize(loss, var_list=fc8_variables)

        # Then we want to finetune the entire model for a few epochs.
        # We minimize the loss with respect to all the variables.
        full_optimizer = tf.train.GradientDescentOptimizer(args.learning_rate2)
        full_train_op = full_optimizer.minimize(loss)

        # Evaluation metrics
        prediction = tf.to_int32(tf.argmax(logits, 1))
        correct_prediction = tf.equal(prediction, labels)
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        tf.get_default_graph().finalize()

    # --------------------------------------------------------------------------
    # Now that we have built the graph and finalized it, we define the session.
    # The session is the interface to *run* the computational graph.
    # We can call our training operations with `sess.run(train_op)` for instance
    with tf.Session(graph=graph) as sess:
        init_fn(sess)  # load the pretrained weights
        sess.run(fc8_init)  # initialize the new fc8 layer

        # Update only the last layer for a few epochs.
        for epoch in range(args.num_epochs1):
            # Run an epoch over the training data.
            print('Starting epoch %d / %d' % (epoch + 1, args.num_epochs1))
            # Here we initialize the iterator with the training set.
            # This means that we can go through an entire epoch until the iterator becomes empty.
            sess.run(train_init_op)
            while True:
                try:
                    _ = sess.run(fc8_train_op, {is_training: True})
                except tf.errors.OutOfRangeError:
                    break

            # Check accuracy on the train and val sets every epoch.
            train_acc = check_accuracy(sess, correct_prediction, is_training, train_init_op)
            val_acc = check_accuracy(sess, correct_prediction, is_training, val_init_op)
            print('Train accuracy: %f' % train_acc)
            print('Val accuracy: %f\n' % val_acc)

        # Train the entire model for a few more epochs, continuing with the *same* weights.
        for epoch in range(args.num_epochs2):
            print('Starting epoch %d / %d' % (epoch + 1, args.num_epochs2))
            sess.run(train_init_op)
            while True:
                try:
                    _ = sess.run(full_train_op, {is_training: True})
                except tf.errors.OutOfRangeError:
                    break

            # Check accuracy on the train and val sets every epoch
            train_acc = check_accuracy(sess, correct_prediction, is_training, train_init_op)
            val_acc = check_accuracy(sess, correct_prediction, is_training, val_init_op)
            print('Train accuracy: %f' % train_acc)
            print('Val accuracy: %f\n' % val_acc)


if __name__ == '__main__':
    args = parser.parse_args()
    main(args)
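To make explicit what `tf.losses.get_total_loss()` optimizes in the script above, here is a NumPy sketch of the same objective: the mean sparse softmax cross-entropy over the batch plus the L2 weight-decay terms that `vgg_arg_scope` registers (slim's `l2_regularizer(scale)` contributes `scale * sum(w**2) / 2` per weight tensor; biases are not regularized). It assumes NumPy is installed and that `tf.losses` uses its default mean reduction for unit weights; the function names are hypothetical helpers, not TensorFlow APIs.

```python
import numpy as np


def sparse_softmax_cross_entropy(labels, logits):
    """Mean cross-entropy between integer labels [N] and logits [N, C]."""
    logits = np.asarray(logits, dtype=np.float64)
    # Subtract the row max for numerical stability before exponentiating
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()


def l2_weight_decay(weights, weight_decay):
    """Sum of weight_decay * ||w||^2 / 2 over all weight tensors (like slim.l2_regularizer)."""
    return sum(weight_decay * 0.5 * np.sum(np.square(w)) for w in weights)


def total_loss(labels, logits, weights, weight_decay=5e-4):
    """Data loss plus regularization, mirroring tf.losses.get_total_loss()."""
    return sparse_softmax_cross_entropy(labels, logits) + l2_weight_decay(weights, weight_decay)


# Uniform logits over 8 classes give a cross-entropy of ln(8) for any label
labels = np.array([0, 3, 7])
logits = np.zeros((3, 8))
print(total_loss(labels, logits, weights=[np.zeros((4, 8))]))  # ln(8) ≈ 2.0794
```

With `--weight_decay 5e-4` the regularization term is small next to the cross-entropy at the start of training, but it is part of what both `fc8_train_op` and `full_train_op` minimize.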
@IgorMihajlovic: I didn't try many hyperparameters. In this case, because ImageNet and the small animal dataset we use are very similar, fine-tuning the whole network might not be very useful, hence the small gain in accuracy. If the new dataset were much more different (e.g. medical images), fine-tuning might give a bigger boost in accuracy.
@eggie5: I wanted the code to be self-contained, whereas the vgg_preprocessing.py file already has more than 300 lines. The goal was also to show some techniques from tf.contrib.data.
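Since the script implements VGG preprocessing inline rather than importing vgg_preprocessing.py, here is a NumPy sketch of the same steps for checking shapes by hand: the aspect-preserving target size, a random crop plus horizontal flip for training, a central crop for validation, and per-channel mean subtraction. It assumes NumPy 1.17 or later (for `default_rng`); the helper names are hypothetical, and it only computes the target size instead of performing the actual resize that `tf.image.resize_images` does.

```python
import numpy as np

VGG_MEAN = np.array([123.68, 116.78, 103.94], dtype=np.float32)


def aspect_preserving_size(height, width, smallest_side=256):
    """New (height, width) so that the smaller side equals smallest_side."""
    # Integer arithmetic; the TF code computes this in float32 and truncates,
    # which can occasionally differ by one pixel.
    if height > width:
        return height * smallest_side // width, smallest_side
    return smallest_side, width * smallest_side // height


def center_crop(image, size=224):
    """Central size x size crop, like resize_image_with_crop_or_pad on a larger image."""
    h, w = image.shape[:2]
    top, left = (h - size) // 2, (w - size) // 2
    return image[top:top + size, left:left + size]


def random_crop_and_flip(image, rng, size=224):
    """Random size x size crop, then a horizontal flip with probability 1/2."""
    h, w = image.shape[:2]
    top = rng.integers(0, h - size + 1)
    left = rng.integers(0, w - size + 1)
    crop = image[top:top + size, left:left + size]
    if rng.random() < 0.5:
        crop = crop[:, ::-1]  # reverse the width axis
    return crop


def subtract_mean(image):
    """Subtract the per-channel VGG mean, broadcasting over [H, W, 3]."""
    return image.astype(np.float32) - VGG_MEAN


rng = np.random.default_rng(0)
# Stand-in for a 480x640 photo after resizing (pixel values left at zero)
resized = np.zeros((*aspect_preserving_size(480, 640), 3), dtype=np.float32)
train_input = subtract_mean(random_crop_and_flip(resized, rng))
val_input = subtract_mean(center_crop(resized))
print(resized.shape, train_input.shape, val_input.shape)  # (256, 341, 3) (224, 224, 3) (224, 224, 3)
```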
@simo23: check_accuracy is very slow because we iterate over the whole training set and the whole validation set at each epoch, which might not be necessary. I agree that using the predictions from training to keep a running average of the training accuracy is better. For the validation accuracy, I'm not sure you can have both datasets running at the same time. One idea could be to evaluate the validation accuracy more often (e.g. 10 times per epoch) on a smaller subset of the validation set.
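Following the running-average idea, here is a sketch of how per-batch `correct_prediction` arrays can be aggregated. The important detail is to sum correct counts and sample counts across batches, as `check_accuracy` does, instead of averaging per-batch accuracies, because the final batch is usually smaller. During training, these arrays could be fetched in the same call as the train op, e.g. `sess.run([fc8_train_op, correct_prediction], {is_training: True})`; that wiring is an assumption, and because dropout is active and the weights change during that pass, the running training accuracy will not match `check_accuracy` exactly. The helper is hypothetical and NumPy-only.

```python
import numpy as np


def accumulate_accuracy(batches):
    """Accuracy over a sequence of boolean `correct_prediction` batches.

    Counts are summed across batches (as check_accuracy does), so a smaller
    final batch is weighted by its size rather than counted as a full batch.
    """
    num_correct, num_samples = 0, 0
    for correct in batches:
        correct = np.asarray(correct, dtype=bool)
        num_correct += int(correct.sum())
        num_samples += correct.shape[0]
    return num_correct / num_samples


batches = [np.array([True, True, False, False]), np.array([True])]
print(accumulate_accuracy(batches))          # 0.6 (3 correct out of 5)
print(np.mean([b.mean() for b in batches]))  # 0.75, biased by the small last batch
```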
If you use momentum, you might need to re-initialize the optimizer's momentum (slot) variables before fine-tuning. Keeping the cached momentum values from step 1 might make the beginning of step 2 unstable.
@YuChunLOL: thanks, I corrected it.
Hi omoindrot, thanks for sharing this piece of code; I found it very helpful!
I have one little question though: when I printed out the labels for the train set and the val set, I found that the order of the labels is different. In my case, I have five categories A, B, C, D, E; the train labels are A:3, B:1, C:2, D:0, E:4, while the val labels are A:1, B:0, C:3, D:4, E:2. Is this expected?
Btw, I found this problem when I tried to plot some predictions: the performance was very poor on my test dataset (around 20% accuracy) even though my validation accuracy was around 75%. After I printed out the label mappings, I realized that if I use the train labels the test accuracy is close to 80%, but if I use the val labels the accuracy is only about 20%.
The script doesn't save the learned model to disk, right? Wouldn't it make sense to add it?
Hello @omoindrot, your code is very helpful! Thanks a lot for sharing it!
I agree with @jfilter that saving and restoring the model would be useful as well.
I added this code to save the model:
saver = tf.train.Saver()
And then after the training is over:
saver.save(sess, 'abc-model')
It does save the model. But when I restore the model in a separate script, I have trouble setting up iterators to feed test data to the restored model. If you could add model saving to your code, plus another script that restores it to use the model on a different dataset, that would be extremely helpful! Thanks!
It is fixed. Never mind.
@munir01: Hi,
I have also encountered the same problem. When I restore the saved model, the last layer's name is not in the restored weights.
Could you please share how you solved the problem?
@omoindrot: curious whether you've had luck running this on the slim Inception implementation?
@jinangela: fixed it with labels.sort()
@eggie5: I didn't try but it should work in the same way
Is there any way to get an implementation equivalent to allow_smaller_batch=True, so that we can get rid of the except part in the epoch loop?
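As a plain-Python analogue (not TensorFlow code), an epoch that yields a smaller final batch and needs no exception handling can be written as a generator consumed by a `for` loop. On the TensorFlow side, later releases added a `drop_remainder` argument to `Dataset.batch`, and eager mode lets a dataset be iterated with a `for` loop; neither applies to the v1.2 contrib API used in the script. The `batches` helper below is hypothetical.

```python
import math


def batches(items, batch_size, drop_remainder=False):
    """Yield consecutive batches; the last one may be smaller unless dropped."""
    for start in range(0, len(items), batch_size):
        batch = items[start:start + batch_size]
        if drop_remainder and len(batch) < batch_size:
            return  # stop the generator instead of yielding a partial batch
        yield batch


val_examples = list(range(200))  # e.g. 8 classes x 25 validation images
sizes = [len(b) for b in batches(val_examples, 32)]
print(len(sizes), sizes[-1])               # 7 8: seven batches, the last holds 8 examples
print(math.ceil(len(val_examples) / 32))   # 7 batches per epoch
```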
@munir01 Could you please post your solution?
@xiaofanglegoc did you manage to restore the saved model?
I found this
https://stackoverflow.com/questions/38947658/tensorflow-saving-into-loading-a-graph-from-a-file
It said that "To use them in another program, you must re-create the associated graph structure (e.g. by running code to build it again, or calling tf.import_graph_def()), which tells TensorFlow what to do with those weights." I think it may solve the problem of @dreamflasher.
But I still haven't found where to add "saver = tf.train.Saver()" and "saver.save(sess, 'abc-model')" to avoid mistakes...
I'm getting this error. Could anyone tell me how to fix it?

Traceback (most recent call last):
  File "/Users/erica/Downloads/tensorflow_finetune_org.py", line 341, in <module>
    main(args)
  File "/Users/erica/Downloads/tensorflow_finetune_org.py", line 202, in main
    train_dataset = tf.contrib.data.Dataset.from_tensor_slices((train_filenames, train_labels))
  File "/Users/erica/Workspace/cs231n/env/lib/python3.6/site-packages/tensorflow/contrib/data/python/ops/dataset_ops.py", line 473, in from_tensor_slices
    return TensorSliceDataset(tensors)
  File "/Users/erica/Workspace/cs231n/env/lib/python3.6/site-packages/tensorflow/contrib/data/python/ops/dataset_ops.py", line 896, in __init__
    batch_dim = flat_tensors[0].get_shape()[0]
  File "/Users/erica/Workspace/cs231n/env/lib/python3.6/site-packages/tensorflow/python/framework/tensor_shape.py", line 500, in __getitem__
    return self._dims[key]
IndexError: list index out of range

Process finished with exit code 1
@yangjingyi Thanks! I added "saver.save" after printing the accuracy scores and "saver = " after creating the session.
This is how you restore:
sess = tf.Session()
saver = tf.train.import_meta_graph('vgg-retrain.ckpt.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
@ZhenzhuZheng: were you able to solve the problem? I am getting the same issue. Please help if you have found a solution.
@rimitalahiri I solved the problem by wrapping train_filenames, train_labels, val_filenames, and val_labels in a tf.constant operation.
It may be wrong, but I no longer see the error.
How do I execute this script? How should I run it from the Windows command prompt? Thanks in advance.
Hello,
I am getting the following error. Does anyone have an idea what causes it?
train_dataset = tf.contrib.data.Dataset.from_tensor_slices((train_filenames, train_labels))
AttributeError: module 'tensorflow.contrib.data' has no attribute 'Dataset'
Thank you!
@Anilkumares That is a TensorFlow version problem. Try version 1.2.
@omoindrot: Hi, how would you explain to someone why training the last layer for a few epochs and then training the entire model for a few epochs is a good fine-tuning approach?
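Mechanically, the two phases differ only in which variables receive updates: `minimize(loss, var_list=fc8_variables)` applies gradient steps to the listed variables alone, while `minimize(loss)` updates every trainable variable. Here is a dependency-free sketch of that difference, with a hypothetical parameter dictionary standing in for the VGG variables and made-up gradient values; the learning rates are the script's defaults.

```python
def sgd_step(params, grads, learning_rate, var_list=None):
    """One plain gradient-descent step, restricted to var_list if given."""
    names = params.keys() if var_list is None else var_list
    updated = dict(params)  # variables outside var_list keep their values
    for name in names:
        updated[name] = params[name] - learning_rate * grads[name]
    return updated


# Hypothetical values standing in for a pretrained conv layer and the new fc8 layer
params = {'vgg_16/conv1_1/weights': 0.5, 'vgg_16/fc8/weights': 0.0}
grads = {'vgg_16/conv1_1/weights': 0.2, 'vgg_16/fc8/weights': 1.0}

# Phase 1: only the new fc8 layer moves (learning_rate1 = 1e-3)
phase1 = sgd_step(params, grads, 1e-3, var_list=['vgg_16/fc8/weights'])
# Phase 2: every variable moves, with the smaller learning_rate2 = 1e-5
phase2 = sgd_step(phase1, grads, 1e-5)
print(phase1)
print(phase2)
```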
Hi, has anyone tried to reconfigure the last 3 fc layers, not just the last one?
@omoindrot do you have any idea how to reconfigure the last 3 fc layers?
@omoindrot how can I restore only the weights of the first 3 conv layers from the checkpoint file and add new layers after them, then train the new layers? Kindly help me with this.
@omoindrot Thanks a lot for your script. I think you should sort the 'unique_labels' list as well, to make sure that the training and validation datasets have the same order of labels (so that a given label gets the same integer).
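To illustrate the fix suggested above: Python randomizes string hashing per process by default, so `enumerate(set(labels))` can assign different integers in two separate runs (for example, one that trains the model and one that evaluates it), and a set's iteration order is not guaranteed to be sorted even within one run. Sorting the unique names makes the mapping depend only on the folder names. The `label_mapping` helper and the temporary directory layout below are hypothetical, standard-library only.

```python
import os
import tempfile


def label_mapping(directory):
    """Map each class folder name to an integer, in sorted (deterministic) order."""
    names = sorted(
        name for name in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, name))
    )
    return {name: i for i, name in enumerate(names)}


with tempfile.TemporaryDirectory() as root:
    # Hypothetical layout mirroring coco-animals/train and coco-animals/val
    for split in ('train', 'val'):
        for name in ('zebra', 'bear', 'cat'):
            os.makedirs(os.path.join(root, split, name))
    train_map = label_mapping(os.path.join(root, 'train'))
    val_map = label_mapping(os.path.join(root, 'val'))

print(train_map)             # {'bear': 0, 'cat': 1, 'zebra': 2}
print(train_map == val_map)  # True
```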
Hi, does anyone else get "ResourceExhaustedError: OOM when allocating tensor with shape"? When I reduce the batch size, it sometimes works.
@omoindrot can you write a script to check the model's accuracy on test images downloaded from the internet? The input and output layers of the graph are ambiguous.
Hello,
has anyone succeeded in making a prediction on a single image after restoring the model?
If so, could someone help me? I am stuck.
Thank you!
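For the single-image case, three points matter: the image goes through the same validation preprocessing, it needs a leading batch dimension (shape [1, 224, 224, 3]), and the predicted index is decoded with the same sorted label order used in training. Here is a NumPy-only sketch of the input and output sides, leaving out the model call itself. In TensorFlow 1.x a fed value can replace any feedable tensor, so one option is `sess.run(logits, {images: batch, is_training: False})`, but that assumes you still have handles to `images`, `logits` and `is_training` in the restored graph. The helper names, `class_names` and the fake logits are hypothetical.

```python
import numpy as np

VGG_MEAN = np.array([123.68, 116.78, 103.94], dtype=np.float32)


def prepare_single_image(resized_image, size=224):
    """Central crop + mean subtraction (as in val_preprocess), then add a batch axis."""
    h, w = resized_image.shape[:2]
    top, left = (h - size) // 2, (w - size) // 2
    crop = resized_image[top:top + size, left:left + size].astype(np.float32)
    return (crop - VGG_MEAN)[np.newaxis]  # shape [1, size, size, 3]


def decode_prediction(logits, class_names):
    """Map the argmax of a [1, num_classes] logits array to a class name."""
    return sorted(class_names)[int(np.argmax(logits[0]))]


# Hypothetical inputs: an already-resized image and fake logits from the model
image = np.full((256, 341, 3), 128, dtype=np.uint8)
batch = prepare_single_image(image)
fake_logits = np.array([[0.1, 2.0, -1.0]])
print(batch.shape)                                               # (1, 224, 224, 3)
print(decode_prediction(fake_logits, ['zebra', 'bear', 'cat']))  # cat
```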
@edwardnguyen1705, I want to configure the last few layers too. Do you have any idea or solution to do it?
@VyBui After that, I switched to PyTorch.
@edwardnguyen1705, hehe, good for you!
To whom it may concern: after half a day, I finally found what I was looking for (I apologize in advance, @omoindrot, because I am going to post another GitHub account here):
https://github.com/machrisaa/tensorflow-vgg/blob/master/vgg16.py
If you only want to extract layers (conv4_3, conv4_2) like me, just use:
vgg = VGG16()
conv4_3 = sess.run(vgg.conv4_3, feed_dict=feed_dict)
If you just want to fine-tune:
Get the fully connected layer by its name ('fc7': vgg.fc7),
then fine-tune it yourself.
Cheers!
On line 322, it should be print('Starting epoch %d / %d' % (epoch + 1, args.num_epochs2)) (num_epochs1 changed to num_epochs2).