# Training loop for a three-layer neural network, written against the
# TensorFlow 1.x API (tf.Session, tf.placeholder, tf.train.AdamOptimizer).
# The helpers create_placeholders, initialize_parameters, forward_propagation,
# compute_cost and random_mini_batches are assumed to be defined elsewhere
# (e.g. in the accompanying course utilities).
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.python.framework import ops

def model(X_train, Y_train, X_test, Y_test, learning_rate=0.0001,
          num_epochs=1500, minibatch_size=32, print_cost=True):

    ops.reset_default_graph()   # to be able to rerun the model without overwriting tf variables
    tf.set_random_seed(1)       # to keep consistent results
    seed = 3                    # to keep consistent results
    (n_x, m) = X_train.shape    # n_x: input size, m: number of examples in the train set
    n_y = Y_train.shape[0]      # n_y: output size
    costs = []                  # to keep track of the cost

    # Create placeholders of shape (n_x, n_y)
    X, Y = create_placeholders(n_x, n_y)

    # Initialize parameters
    parameters = initialize_parameters()

    # Forward propagation
    Z3 = forward_propagation(X, parameters)

    # Cost function
    cost = compute_cost(Z3, Y)

    # Backpropagation
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

    # Initialize all the variables
    init = tf.global_variables_initializer()

    # Start the session
    with tf.Session() as sess:

        # Run the initialization
        sess.run(init)

        # Do the training loop
        for epoch in range(num_epochs):

            epoch_cost = 0.
            num_minibatches = int(m / minibatch_size)
            seed = seed + 1
            minibatches = random_mini_batches(X_train, Y_train, minibatch_size, seed)

            for minibatch in minibatches:

                # Select a minibatch
                (minibatch_X, minibatch_Y) = minibatch
                _, minibatch_cost = sess.run([optimizer, cost],
                                             feed_dict={X: minibatch_X, Y: minibatch_Y})
                epoch_cost += minibatch_cost / num_minibatches

            # Print the cost every 100 epochs, and record it every 5 for plotting
            if print_cost and epoch % 100 == 0:
                print("Cost after epoch %i: %f" % (epoch, epoch_cost))
            if print_cost and epoch % 5 == 0:
                costs.append(epoch_cost)

        # Plot the cost (one point per 5 epochs, matching the recording interval above)
        plt.plot(np.squeeze(costs))
        plt.ylabel('cost')
        plt.xlabel('iterations (per fives)')
        plt.title("Learning rate = " + str(learning_rate))
        plt.show()

        # Save the trained parameters in a variable
        parameters = sess.run(parameters)
        print("Parameters have been trained!")

        # Calculate the correct predictions (argmax over the class axis)
        correct_prediction = tf.equal(tf.argmax(Z3), tf.argmax(Y))

        # Calculate accuracy on the train and test sets
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
        print("Train Accuracy:", accuracy.eval({X: X_train, Y: Y_train}))
        print("Test Accuracy:", accuracy.eval({X: X_test, Y: Y_test}))

        return parameters
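A minimal usage sketch, assuming column-major data as the function expects: features of shape (n_x, m) and one-hot labels of shape (n_y, m). The dimensions and the random data below are hypothetical placeholders, not from the gist; in practice X_train/Y_train would come from a real dataset.

# Hypothetical usage example (random data, example dimensions only).
import numpy as np

n_x, n_y, m_train, m_test = 12288, 6, 1080, 120   # assumed dimensions for illustration
X_train = np.random.randn(n_x, m_train)
Y_train = np.eye(n_y)[:, np.random.randint(0, n_y, m_train)]   # one-hot columns
X_test = np.random.randn(n_x, m_test)
Y_test = np.eye(n_y)[:, np.random.randint(0, n_y, m_test)]

parameters = model(X_train, Y_train, X_test, Y_test, num_epochs=100)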