Last active
April 24, 2018 03:21
-
-
Save khanhnamle1994/556348f3dc56d6f4d4a06a6b60756c32 to your computer and use it in GitHub Desktop.
FCN - Train a model and save the output images produced by feeding the test images through the trained model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def run():
    """Build an FCN decoder on top of a pretrained VGG, train it, and save test outputs.

    Steps performed:
      1. Fetch the pretrained VGG model into ``data_dir`` via
         ``helper.maybe_download_pretrained_vgg``.
      2. Create a batch generator over ``training_dir`` for images of
         ``image_shape``.
      3. Inside a single ``tf.Session``: load the VGG graph, add the decoder
         (``layers``), build the loss/optimizer (``optimize``), and initialize
         global and local variables.
      4. Train with ``train_nn`` for ``EPOCHS`` epochs using ``BATCH_SIZE``.
      5. Run the test images through the trained model and save the painted
         results under ``runs_dir`` via ``helper.save_inference_samples``.

    NOTE(review): this function takes no arguments and relies entirely on
    module-level names defined elsewhere in the file (``data_dir``,
    ``training_dir``, ``runs_dir``, ``vgg_path``, ``image_shape``,
    ``num_classes``, ``EPOCHS``, ``BATCH_SIZE``, and presumably the TF
    placeholders ``correct_label`` and ``learning_rate``). Confirm they are
    defined before calling.
    """
    # Download the pretrained VGG model (the helper presumably skips this if
    # it is already present — verify in helper.py).
    helper.maybe_download_pretrained_vgg(data_dir)

    # Callable that yields training batches from training_dir.
    get_batches_fn = helper.gen_batch_function(training_dir, image_shape)

    with tf.Session() as session:
        # Input tensor, dropout keep-probability, and the three VGG layers the
        # decoder taps into (layer3, layer4, layer7).
        image_input, keep_prob, layer3, layer4, layer7 = load_vgg(session, vgg_path)

        # Decoder built on top of the VGG encoder layers.
        model_output = layers(layer3, layer4, layer7, num_classes)

        # optimize() returns, in order:
        #   - logits: presumably reshaped to (pixels, num_classes) — confirm in optimize()
        #   - train_op: the op that applies one optimizer step when run
        #   - cross_entropy_loss: the loss being minimized
        logits, train_op, cross_entropy_loss = optimize(
            model_output, correct_label, learning_rate, num_classes
        )

        # Initialize every variable, including any created by the decoder and
        # optimizer above.
        session.run(tf.global_variables_initializer())
        session.run(tf.local_variables_initializer())

        print("Model build successful, starting training")

        train_nn(
            session, EPOCHS, BATCH_SIZE, get_batches_fn,
            train_op, cross_entropy_loss, image_input,
            correct_label, keep_prob, learning_rate,
        )

        # Run the trained model over the test images and save each painted
        # output image (roads painted green) under runs_dir.
        helper.save_inference_samples(
            runs_dir, data_dir, session, image_shape, logits, keep_prob, image_input
        )

        print("All done!")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment