@khanhnamle1994
Last active January 26, 2024 22:14
FCN - Full Code
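import tensorflow as tf  # Written against the TensorFlow 1.x API (tf.placeholder, tf.layers, tf.Session)
import helper  # helper.py from https://github.com/udacity/CarND-Semantic-Segmentation
#--------------------------
# USER-SPECIFIED DATA
#--------------------------
# Tune these parameters
num_classes = 2
image_shape = (160, 576)
EPOCHS = 40
BATCH_SIZE = 16
DROPOUT = 0.75  # Not read anywhere below; train_nn() hard-codes keep_prob_value = 0.5
# Specify these directory paths
data_dir = './data'
runs_dir = './runs'
training_dir = './data/data_road/training'
vgg_path = './data/vgg'
#--------------------------
# PLACEHOLDER TENSORS
#--------------------------
correct_label = tf.placeholder(tf.float32, [None, image_shape[0], image_shape[1], num_classes])
learning_rate = tf.placeholder(tf.float32)
keep_prob = tf.placeholder(tf.float32)  # Not used by run(), which feeds the 'keep_prob:0' tensor loaded from the VGG graph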
#--------------------------
# FUNCTIONS
#--------------------------
def load_vgg(sess, vgg_path):
    # Load the pretrained VGG16 SavedModel (graph and weights) into sess;
    # the returned MetaGraphDef is not needed, only the default graph
    tf.saved_model.loader.load(sess, ['vgg16'], vgg_path)
    # Get the tensors to be returned from the graph
    graph = tf.get_default_graph()
    image_input = graph.get_tensor_by_name('image_input:0')
    keep_prob = graph.get_tensor_by_name('keep_prob:0')
    layer3 = graph.get_tensor_by_name('layer3_out:0')
    layer4 = graph.get_tensor_by_name('layer4_out:0')
    layer7 = graph.get_tensor_by_name('layer7_out:0')
    return image_input, keep_prob, layer3, layer4, layer7
def layers(vgg_layer3_out, vgg_layer4_out, vgg_layer7_out, num_classes):
    # Use shorter variable names for simplicity
    layer3, layer4, layer7 = vgg_layer3_out, vgg_layer4_out, vgg_layer7_out
    # Apply a 1x1 convolution in place of the fully connected layer
    fcn8 = tf.layers.conv2d(layer7, filters=num_classes, kernel_size=1, name="fcn8")
    # Upsample fcn8 by 2x, with as many filters as layer 4 has channels,
    # so that its shape matches layer 4 for the skip connection
    fcn9 = tf.layers.conv2d_transpose(fcn8, filters=layer4.get_shape().as_list()[-1],
                                      kernel_size=4, strides=(2, 2), padding='SAME', name="fcn9")
    # Add a skip connection between the upsampled fcn9 and layer 4
    fcn9_skip_connected = tf.add(fcn9, layer4, name="fcn9_plus_vgg_layer4")
    # Upsample by 2x again to match layer 3
    fcn10 = tf.layers.conv2d_transpose(fcn9_skip_connected, filters=layer3.get_shape().as_list()[-1],
                                       kernel_size=4, strides=(2, 2), padding='SAME', name="fcn10_conv2d")
    # Add a skip connection with layer 3
    fcn10_skip_connected = tf.add(fcn10, layer3, name="fcn10_plus_vgg_layer3")
    # Upsample by 8x back to the input resolution, one output channel per class
    fcn11 = tf.layers.conv2d_transpose(fcn10_skip_connected, filters=num_classes,
                                       kernel_size=16, strides=(8, 8), padding='SAME', name="fcn11")
    return fcn11
def optimize(nn_last_layer, correct_label, learning_rate, num_classes):
    # Reshape 4D tensors to 2D: each row represents a pixel, each column a class
    logits = tf.reshape(nn_last_layer, (-1, num_classes), name="fcn_logits")
    correct_label_reshaped = tf.reshape(correct_label, (-1, num_classes))
    # Per-pixel softmax cross-entropy against the one-hot labels
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=correct_label_reshaped)
    # Take the mean over all pixels as the loss
    loss_op = tf.reduce_mean(cross_entropy, name="fcn_loss")
    # Adam minimizes the loss over all trainable variables in the graph (no var_list is given)
    train_op = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss_op, name="fcn_train_op")
    return logits, train_op, loss_op
def train_nn(sess, epochs, batch_size, get_batches_fn, train_op,
             cross_entropy_loss, input_image,
             correct_label, keep_prob, learning_rate):
    keep_prob_value = 0.5
    learning_rate_value = 0.001
    for epoch in range(epochs):
        total_loss = 0
        # get_batches_fn yields (images, one-hot labels) batches
        for X_batch, gt_batch in get_batches_fn(batch_size):
            loss, _ = sess.run([cross_entropy_loss, train_op],
                               feed_dict={input_image: X_batch, correct_label: gt_batch,
                                          keep_prob: keep_prob_value, learning_rate: learning_rate_value})
            total_loss += loss
        # total_loss is the sum of the per-batch mean losses over the epoch
        print("EPOCH {} ...".format(epoch + 1))
        print("Loss = {:.3f}".format(total_loss))
        print()
def run():
    # Download the pretrained VGG model
    helper.maybe_download_pretrained_vgg(data_dir)
    # A function to get batches
    get_batches_fn = helper.gen_batch_function(training_dir, image_shape)
    with tf.Session() as session:
        # Returns the input layer, keep probability and three layers from the VGG architecture
        image_input, keep_prob, layer3, layer4, layer7 = load_vgg(session, vgg_path)
        # The resulting network architecture from adding a decoder on top of the given VGG model
        model_output = layers(layer3, layer4, layer7, num_classes)
        # Returns the output logits, training operation and cost operation to be used
        # - logits: each row represents a pixel, each column a class
        # - train_op: operation that updates the parameters so the model labels pixels correctly
        # - cross_entropy_loss: the cost being minimized; lower cost should yield higher accuracy
        logits, train_op, cross_entropy_loss = optimize(model_output, correct_label, learning_rate, num_classes)
        # Initialize all variables
        session.run(tf.global_variables_initializer())
        session.run(tf.local_variables_initializer())
        print("Model build successful, starting training")
        # Train the neural network
        train_nn(session, EPOCHS, BATCH_SIZE, get_batches_fn,
                 train_op, cross_entropy_loss, image_input,
                 correct_label, keep_prob, learning_rate)
        # Run the model on the test images and save each painted output image (roads painted green)
        helper.save_inference_samples(runs_dir, data_dir, session, image_shape, logits, keep_prob, image_input)
        print("All done!")
#--------------------------
# MAIN
#--------------------------
if __name__ == '__main__':
    run()
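The optimize() step flattens the network output and the one-hot labels to one row per pixel and averages a per-pixel softmax cross-entropy. The NumPy sketch below mirrors that computation outside TensorFlow purely for illustration; it is not TensorFlow's implementation, and `pixel_softmax_cross_entropy` is a hypothetical name.

```python
import numpy as np

def pixel_softmax_cross_entropy(nn_output, labels):
    """Mean per-pixel softmax cross-entropy, mirroring optimize() in NumPy.

    nn_output, labels: arrays of shape (batch, height, width, num_classes);
    labels are one-hot along the last axis.
    """
    num_classes = nn_output.shape[-1]
    # Flatten to 2D: each row is a pixel, each column a class
    logits = nn_output.reshape(-1, num_classes)
    targets = labels.reshape(-1, num_classes)
    # Numerically stable log-softmax: subtract each row's max before exponentiating
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Cross-entropy per pixel, then the mean over all pixels
    per_pixel = -(targets * log_probs).sum(axis=1)
    return per_pixel.mean()

# Usage: uniform logits over 2 classes give a loss of ln(2) per pixel
out = np.zeros((1, 4, 4, 2))
lab = np.zeros((1, 4, 4, 2))
lab[..., 1] = 1.0
print(pixel_softmax_cross_entropy(out, lab))
```

Because the loss is a mean over pixels, a confident correct prediction drives it toward 0, while uniform scores over the two classes give about 0.693.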
@adekunleba

Hi Khanhnamle, the challenge I have with segmentation is representing the image data being used. What do the label sets look like? And if you want to prepare your own label data, what's the approach, and how does it fit into the FCN architecture?
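In this gist, `correct_label` expects one-hot labels of shape (batch, height, width, num_classes), with num_classes = 2. A minimal NumPy sketch of turning a color-coded ground-truth image into that layout follows. It assumes background pixels are painted pure red (255, 0, 0) and every other pixel counts as road; to my understanding the Udacity `helper.gen_batch_function` uses a similar convention for the KITTI road masks, but check the color against your own labels. `mask_to_one_hot` is a hypothetical name.

```python
import numpy as np

# Assumed background color in the ground-truth images (check against your data)
BACKGROUND_COLOR = np.array([255, 0, 0])

def mask_to_one_hot(gt_image, background_color=BACKGROUND_COLOR):
    """Convert an (H, W, 3) RGB mask into an (H, W, 2) one-hot label array.

    Channel 0 = background, channel 1 = everything else (e.g. road).
    """
    # True where all three RGB channels equal the background color
    is_background = np.all(gt_image == background_color, axis=2)
    # Stack background and its complement so each pixel has exactly one 1
    return np.stack([is_background, ~is_background], axis=-1).astype(np.float32)

# Usage: a 2x2 mask whose top row is background and bottom row is road
gt = np.array([[[255, 0, 0], [255, 0, 0]],
               [[255, 0, 255], [255, 0, 255]]], dtype=np.uint8)
labels = mask_to_one_hot(gt)
print(labels.shape)  # (2, 2, 2)
```

For your own data, each mask would be resized to `image_shape` and converted this way, and a batch of such arrays is what gets fed to `correct_label`.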

@jjasonhe

Hi! Thanks for the Medium write-up. I'm trying to build on this, but when I call tf.saved_model.loader.load(), it returns an error about a missing .pbtxt or .pb file.

What files did you download from http://www.cs.toronto.edu/~frossard/post/vgg16/?
I only found vgg.py and the weights in an npz.

@nourihilscher

@jjasonhe He used one of the Udacity car challenges as inspiration (he basically copied the code). The model used by Udacity, including your missing .pb file, can be downloaded here: Download vgg.zip

@crazysal

Can someone share the helper file? How do I plot the predicted masks during inference?

@aaronlelevier

@crazysal This Kaggle kernel has example code for plotting a predicted mask over an image:

https://www.kaggle.com/phoenigs/u-net-dropout-augmentation-stratification
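For this gist's two-class output, an overlay can also be sketched directly in NumPy: take the softmax of each pixel's logits, mark pixels whose road probability exceeds a threshold, and blend those pixels toward green. This is an illustrative sketch, not the code of `helper.save_inference_samples`; the name `overlay_road_mask`, the 0.5 threshold, the 50% blend, and treating class 1 as road are all assumptions. It expects logits of shape (height * width, 2), the layout `optimize()` returns for a single image.

```python
import numpy as np

def overlay_road_mask(image, logits, threshold=0.5, alpha=0.5):
    """Paint predicted road pixels green on an (H, W, 3) uint8 image.

    logits: (H * W, 2) per-pixel class scores; class 1 is assumed to be road.
    """
    height, width = image.shape[:2]
    # Softmax over the class axis (numerically stable form)
    shifted = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
    # Boolean road mask at image resolution (pixels are in row-major order)
    road = (probs[:, 1] > threshold).reshape(height, width)
    # Blend green into the masked pixels only
    out = image.astype(np.float32)
    green = np.array([0, 255, 0], dtype=np.float32)
    out[road] = (1 - alpha) * out[road] + alpha * green
    return out.astype(np.uint8)

# Usage: a 2x2 gray image where only the first pixel is predicted as road
img = np.full((2, 2, 3), 100, dtype=np.uint8)
scores = np.array([[0.0, 5.0], [5.0, 0.0], [5.0, 0.0], [5.0, 0.0]])
result = overlay_road_mask(img, scores)
print(result[0, 0])
```

If the network predicts background everywhere, no pixel passes the threshold and the output is identical to the input image.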

@zhengrui315

Do you also train the VGG variables?

@JhonatanEstabile

Could someone share the URL of the "helper" library that was used in this code?

@AnaRhisT94

AnaRhisT94 commented Sep 15, 2018

Edit the code like this: correct_label = tf.placeholder(tf.float32, [None, image_shape[0], image_shape[1], num_classes])

@chaitanyajalluri

During training, do the VGG variables also get trained?

@aymanhalabya

aymanhalabya commented Oct 8, 2018

Could someone share the URL of the "helper" library that was used in this code?

I know it's been a while since you asked this question, but I found the answer and thought I would share it here to help others.

He is solving the problem posed in this repo:
https://github.com/udacity/CarND-Semantic-Segmentation

Another solution is here:
https://gist.github.com/lianyi/a5ba8d84f5b68401c2313b05e020b062

@Anbehner

Great tutorial. However, we are facing a problem while opening a picture. It doesn’t happen for every picture, but it always happens for the same one.

In get_batches_fn()

image = scipy.misc.imresize(scipy.misc.imread(image_file), image_shape)

OSError: cannot identify image file ‘./drive/My Drive/…./file.png

Do you have any solution for this?
Thank you.
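`scipy.misc.imread` and `scipy.misc.imresize` have been removed from recent SciPy releases, so one option is to load and resize with Pillow instead and skip any file Pillow cannot decode, printing its path so the broken file can be inspected. This is a sketch, not a drop-in patch for the helper; `load_resized` is a hypothetical name. Note that Pillow's `resize` takes (width, height), while `image_shape` in this gist is (height, width).

```python
import numpy as np
from PIL import Image

def load_resized(image_file, image_shape):
    """Return an (H, W, 3) uint8 array, or None if the file cannot be decoded.

    image_shape is (height, width), matching this gist's image_shape.
    """
    try:
        with Image.open(image_file) as img:
            # Pillow expects (width, height), so swap the order
            resized = img.convert("RGB").resize((image_shape[1], image_shape[0]))
            return np.asarray(resized)
    except OSError as err:
        # Unreadable or corrupt images raise OSError (or a subclass) in Pillow
        print("Skipping {}: {}".format(image_file, err))
        return None
```

Inside a batch loop, a `None` result can simply be skipped with `continue`; re-saving or replacing the reported file is usually the longer-term fix.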

@infoweb-sd

Hi, thanks for this great tutorial. I'm trying to run main.py, and it gives me this error:
Traceback (most recent call last):
  File "main.py", line 156, in <module>
    run()
  File "main.py", line 117, in run
    helper.maybe_download_pretrained_vgg(data_dir)
NameError: name 'helper' is not defined

I need help, please!

How can I run this code? Can someone share the helper file?

@NourO93

NourO93 commented Jan 8, 2019

After training, my "run" file is empty, and I noticed that the loss in every epoch is 0. Can someone tell me why that might be?
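A zero loss can have several causes, and this thread does not establish which one applies here. One inexpensive check, offered as a suggestion rather than a diagnosis: softmax cross-entropy multiplies each class's log-probability by its label, so a pixel whose label row is all zeros contributes exactly 0 to the loss whatever the network predicts. The NumPy sketch below (the name `label_row_report` is hypothetical) summarizes a label batch such as one yielded by `get_batches_fn`.

```python
import numpy as np

def label_row_report(gt_batch):
    """Summarize a (batch, H, W, num_classes) label array.

    Returns the fraction of pixels whose label row sums to exactly 1
    (valid one-hot) and the fraction whose row is all zeros.
    """
    row_sums = gt_batch.reshape(-1, gt_batch.shape[-1]).sum(axis=1)
    return {
        "one_hot_fraction": float(np.mean(row_sums == 1)),
        "all_zero_fraction": float(np.mean(row_sums == 0)),
    }

# Usage: a batch where half the pixels were (wrongly) left unlabeled
gt = np.zeros((1, 2, 2, 2), dtype=np.float32)
gt[0, 0, :, 1] = 1.0
print(label_row_report(gt))  # {'one_hot_fraction': 0.5, 'all_zero_fraction': 0.5}
```

With correctly encoded labels, `one_hot_fraction` should be 1.0 for every batch.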

@getmlcode

correct_label = tf.placeholder(tf.float32, [None, IMAGE_SHAPE[0], IMAGE_SHAPE[1], NUMBER_OF_CLASSES])
I guess IMAGE_SHAPE is a typo for image_shape?

@paulgureghian

In 'load_vgg', a 'model' variable is defined but never used. Why?

@Amithsai

Hello,
I tried executing your code, but I'm getting an "allocation of X.. exceeded 10% of memory" error.
Please help.

@Madhivarman

@Amithsai This error appears because the data for a training batch does not fit in memory. You can resolve it by reducing the batch size used while training the network. Even with this error, does your network still run?
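To see why the batch size matters, it helps to estimate the size of one activation tensor: batch × height × width × channels × bytes per element. The sketch assumes float32 activations (4 bytes) and uses the 64 channels of VGG16's first convolution block, which runs at full input resolution. Training keeps many such tensors plus their gradients, so this is a lower bound for a single layer, not total usage. `activation_bytes` is a hypothetical helper.

```python
def activation_bytes(batch, height, width, channels, bytes_per_element=4):
    """Size in bytes of one (batch, height, width, channels) activation tensor."""
    return batch * height * width * channels * bytes_per_element

# VGG16 conv1 output at this gist's image_shape (160, 576) and BATCH_SIZE = 16, in MiB
print(activation_bytes(16, 160, 576, 64) / 2**20)  # 360.0
# Halving the batch halves every activation tensor
print(activation_bytes(8, 160, 576, 64) / 2**20)  # 180.0
```

Because every activation scales linearly with the batch dimension, reducing BATCH_SIZE is the most direct lever; reducing image_shape also helps, but changes what the network sees.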

@chunchet-ng

Hi there, why must the image shape be fixed at (160, 576)? I thought an FCN could accept variable-sized input.
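The decoder layers are fully convolutional, but the two skip connections add tensors element-wise, so each upsampled map must match layer 4 and layer 3 exactly. In addition, the `correct_label` placeholder and `helper.gen_batch_function` both use `image_shape`. The sketch below traces spatial sizes through the network under two assumptions: VGG16's five 2x poolings use SAME padding (each halves the size, rounding up), and a `conv2d_transpose` with stride s and `padding='SAME'` multiplies the size by s. Under those assumptions, heights and widths that are multiples of 32 line up, while other sizes break the skip additions. `pooled` and `fcn8_shapes` are hypothetical names.

```python
import math

def pooled(size, times):
    """Spatial size after `times` 2x poolings, assuming SAME padding (rounds up)."""
    for _ in range(times):
        size = math.ceil(size / 2)
    return size

def fcn8_shapes(height, width):
    """Trace (height, width) through the VGG16 encoder and this gist's decoder."""
    layer3 = (pooled(height, 3), pooled(width, 3))  # stride 8
    layer4 = (pooled(height, 4), pooled(width, 4))  # stride 16
    layer7 = (pooled(height, 5), pooled(width, 5))  # stride 32
    fcn9 = (2 * layer7[0], 2 * layer7[1])    # 2x transposed conv of fcn8
    fcn10 = (2 * layer4[0], 2 * layer4[1])   # 2x transposed conv of fcn9 + layer4
    fcn11 = (8 * layer3[0], 8 * layer3[1])   # 8x transposed conv of fcn10 + layer3
    return {
        "layer3": layer3, "layer4": layer4, "layer7": layer7,
        "fcn9": fcn9, "fcn10": fcn10, "fcn11": fcn11,
        # tf.add in layers() needs both pairs to match
        "skips_match": fcn9 == layer4 and fcn10 == layer3,
        "output_matches_input": fcn11 == (height, width),
    }

print(fcn8_shapes(160, 576)["skips_match"])  # True
print(fcn8_shapes(170, 576)["skips_match"])  # False
```

With 170 rows, layer 4 has 11 rows but the upsampled layer 7 has 12, so the first skip addition fails; (160, 576) divides evenly by 32 in both dimensions.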

@1n4001

1n4001 commented Jun 17, 2019

@tomzzh

tomzzh commented Jul 12, 2019

Are the VGG parameters also trained?

@Jeff1996

@chunchet-ng Did you solve your question?

@YCAyca

YCAyca commented Aug 30, 2021

Has anyone faced this problem? I train the model without any errors; the loss starts at nearly 68.000, then drops directly to 0 by the 2nd epoch. I train the model for 40 epochs, and at the test phase I don't see any segmentation result: the output images are the same as the test images, without any segmentation prediction.
