Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save jakelevi1996/7bd6db101c1de80d32d4c4a0365df210 to your computer and use it in GitHub Desktop.
Save jakelevi1996/7bd6db101c1de80d32d4c4a0365df210 to your computer and use it in GitHub Desktop.
Self-contained, single-script example of training a neural network in TensorFlow

Self-contained, single-script example of training a neural network in TensorFlow

Shown below is the code for a self-contained, single-script example of a TensorFlow program; the steps it goes through are as follows:

  1. create_data_set():
    1. Create a randomly generated training set and a uniform grid for evaluating the model
  2. train_and_eval():
    1. Define a model
    2. Train the model on the training data (while keeping a log of summary operations)
    3. Save the trained model
    4. Restore the trained model, and evaluate it on the evaluation data
  3. plot_results():
    1. Plot the results
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

def classification_function(x):
    """This is the function we are going to learn.

    In this case, it is a 2D circle: points strictly inside the unit circle
    are labelled 1 and all other points are labelled 0.

    Args:
        x: NumPy array with at least two columns; only columns 0 and 1 are
            read.

    Returns:
        Integer array of shape (len(x), 1) holding the 0/1 labels.
    """
    squared_radius = x[:, 0] ** 2 + x[:, 1] ** 2
    # Compare first, then multiply by 1 to turn the boolean mask into 0/1
    # integers. Writing ``1 * r2 < 1`` instead would multiply first (the
    # multiplication binds tighter than the comparison) and return booleans.
    inside = squared_radius < 1
    return (1 * inside).reshape(-1, 1)

def create_data_set(n_points=1000, grid_limit=4, n_grid_points=100):
    """Create a randomly generated training set,
    and a uniform grid for evaluating the model.

    Args:
        n_points: number of training points; their coordinates are drawn
            with np.random.randn, i.e. from a standard normal distribution.
        grid_limit: the evaluation grid spans [-grid_limit, grid_limit]
            along both axes.
        n_grid_points: number of grid points along each axis.

    Returns:
        x_train: array of shape (n_points, 2) of training inputs.
        y_train: labels for x_train, as returned by classification_function.
        x_eval: array of shape (n_grid_points ** 2, 2), one grid point
            per row.
        xx0, xx1: meshgrid coordinate arrays of shape
            (n_grid_points, n_grid_points).
    """
    # Create training data set
    x_train = np.random.randn(n_points, 2)
    y_train = classification_function(x_train)

    # Create grid for evaluation
    x_array = np.linspace(-grid_limit, grid_limit, n_grid_points)
    xx0, xx1 = np.meshgrid(x_array, x_array)
    # Flatten the grid in row-major order into one (x0, x1) row per point,
    # so predictions on x_eval can be reshaped back to xx0.shape
    x_eval = np.concatenate(
        (xx0.reshape(-1, 1), xx1.reshape(-1, 1)), axis=1
    )

    return x_train, y_train, x_eval, xx0, xx1

def display_loss_val(epoch, loss_val):
    """Write a single progress line (epoch number and loss) to stdout.

    The epoch is left-aligned in an 8-character field and the loss is shown
    with six decimal places.
    """
    message = "Epoch: {ep:<8} | Loss: {loss:<.6f}".format(
        ep=epoch, loss=loss_val
    )
    print(message)

def train_and_eval(
    x_train, y_train, x_eval,
    num_epochs=3000,
    learning_rate=0.01,
    num_hidden_units=5,
    print_every=1000,
    logdir="summaries",
    savedir="models/saved_model"
):
    """Define a model, train it on some data, save the trained model,
    restore it, and evaluate it on some new data.

    The model is a fully-connected network with one tanh hidden layer and a
    single sigmoid output. It is trained full-batch with Adam on the sigmoid
    cross-entropy loss and is built in the current default TensorFlow graph.

    Args:
        x_train: training inputs, fed to a float32 placeholder of shape
            (None, 2).
        y_train: training labels, matching the (len(x_train), 1) logits.
            They are embedded in the loss op directly rather than fed, so
            every evaluation of the loss must feed x_train itself.
        x_eval: inputs (None, 2) on which to evaluate the restored model.
        num_epochs: number of full-batch training steps.
        learning_rate: Adam learning rate.
        num_hidden_units: width of the hidden layer.
        print_every: print the loss every this many epochs.
        logdir: directory to which TensorBoard summaries are written.
        savedir: checkpoint path prefix passed to tf.train.Saver.save.

    Returns:
        The restored model's sigmoid outputs on x_eval, as a NumPy array
        with one row per row of x_eval.
    """

    # Define network
    x = tf.placeholder(dtype=tf.float32, shape=(None, 2))
    hidden_layer = tf.layers.dense(
        inputs=x, units=num_hidden_units,
        activation=tf.tanh, name="Hidden_layer"
    )
    logits = tf.layers.dense(
        inputs=hidden_layer, units=1, name="Output_layer"
    )
    y = tf.sigmoid(logits)

    # Define loss and optimiser. The gradients are computed once and shared
    # by the training op and the gradient summaries below, instead of
    # building a second copy of the backward pass with another
    # compute_gradients() call alongside minimize().
    loss_op = tf.losses.sigmoid_cross_entropy(y_train, logits)
    optimiser = tf.train.AdamOptimizer(learning_rate)
    grads_and_vars = optimiser.compute_gradients(loss_op)
    train_op = optimiser.apply_gradients(grads_and_vars)

    # Create Saver object for saving; created after apply_gradients so that
    # the optimiser's own variables are included in the checkpoint too
    saver = tf.train.Saver()

    # Create Operation for initialising variables
    init_op = tf.global_variables_initializer()

    # Create summaries, for visualising in Tensorboard. merge_all() merges
    # every summary registered in the default graph.
    tf.summary.scalar("Loss", loss_op)
    tf.summary.histogram("Logit values", logits)
    for grad, var in grads_and_vars:
        tf.summary.histogram("Variables/" + var.name, var)
        tf.summary.histogram("Gradients/" + var.name, grad)
    summary_op = tf.summary.merge_all()

    # Train and save the model
    print("Starting TensorFlow Session...")
    with tf.Session() as sess:
        writer = tf.summary.FileWriter(logdir, sess.graph)
        try:
            sess.run(init_op)
            # Training loop:
            for epoch in range(num_epochs):
                # Run the graph, summaries and training op
                loss_val, summary_val, _ = sess.run(
                    (loss_op, summary_op, train_op), feed_dict={x: x_train}
                )
                # Add summary to Tensorboard
                writer.add_summary(summary_val, epoch)
                # Display progress every few epochs
                if epoch % print_every == 0:
                    display_loss_val(epoch, loss_val)
        finally:
            # Flush buffered events to disk and release the event file;
            # otherwise the last summaries may never be written out
            writer.close()

        # Evaluate final loss
        loss_val = sess.run(loss_op, feed_dict={x: x_train})
        display_loss_val(num_epochs, loss_val)

        # Save model
        save_path = saver.save(sess, savedir)

    # Restore and evaluate the model in a fresh session
    with tf.Session() as eval_sess:
        print("Restoring model...")
        saver.restore(eval_sess, save_path)
        y_eval = eval_sess.run(y, feed_dict={x: x_eval})

    return y_eval

def plot_results(
    x_train, y_train, xx0, xx1, yy, filename="classification results.png",
    levels=(.2, .4, .6, .8), show=False
):
    """Plot the training data and the model's output contours, and save them.

    Args:
        x_train: training inputs; column 0 is plotted on the horizontal axis
            and column 1 on the vertical axis.
        y_train: training labels with one column; points labelled 0 are drawn
            in blue and points labelled 1 in red.
        xx0, xx1: meshgrid coordinate arrays of the evaluation grid.
        yy: model outputs on the grid, with the same shape as xx0.
        filename: path of the image file to save.
        levels: output values at which contour lines are drawn.
        show: if True, also display the figure interactively before closing
            it.
    """
    # Draw on a fresh figure so that repeated calls do not overlay each other
    fig = plt.figure()
    is_negative = y_train[:, 0] == 0
    is_positive = y_train[:, 0] == 1
    # Plot training data as a 2D scatter plot
    plt.plot(
        x_train[is_negative, 0], x_train[is_negative, 1], 'bo',
        x_train[is_positive, 0], x_train[is_positive, 1], 'ro',
        alpha=.1
    )
    # Plot model evaluations as a contour plot
    plt.contour(xx0, xx1, yy, list(levels), cmap='bwr')
    # Neaten up the plot, save it, optionally show it, then release it
    plt.grid(True)
    plt.axis('equal')
    fig.savefig(filename)
    if show:
        plt.show()
    plt.close(fig)

if __name__ == "__main__":
    # Seed NumPy's and then TensorFlow's random number generators, so that
    # repeated runs reproduce the same data set and the same training run
    random_seed = 0
    for set_seed in (np.random.seed, tf.set_random_seed):
        set_seed(random_seed)

    # Generate the training data and the evaluation grid
    data = create_data_set()
    x_train, y_train, x_eval, xx0, xx1 = data

    # Fit the network, then lay its grid predictions out in the grid's shape
    predictions = train_and_eval(x_train, y_train, x_eval)
    y_grid = predictions.reshape(xx0.shape)

    # Save a figure of the data and the learned decision boundary
    plot_results(x_train, y_train, xx0, xx1, y_grid)

The classification boundaries produced by running the program with the given random seed are shown in the image results.png. A screenshot from TensorBoard is shown in the image tensorboard.png. To run TensorBoard, open a terminal window in the summaries folder, run the command tensorboard --logdir ., and then open the URL it displays in a web browser. The corresponding console output is shown below:

Starting TensorFlow Session...
Epoch: 0        | Loss: 0.698640
Epoch: 1000     | Loss: 0.086445
Epoch: 2000     | Loss: 0.057700
Epoch: 3000     | Loss: 0.043505
Restoring model...
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment