Skip to content

Instantly share code, notes, and snippets.

@mwitiderrick
Created March 4, 2021 08:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mwitiderrick/fdfaf04db793349e48ff9e8ea3c79b6f to your computer and use it in GitHub Desktop.
Save mwitiderrick/fdfaf04db793349e48ff9e8ea3c79b6f to your computer and use it in GitHub Desktop.
PyCharm Scientific Mode
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.python.compiler.mlcompute import mlcompute
from tensorflow.python.framework.ops import disable_eager_execution
#%% disable eager execution and select the GPU.
# NOTE: disable_eager_execution() switches TensorFlow OFF eager mode (into
# graph mode); it does not activate eager execution. Presumably this is done
# so the ML Compute backend below can run the model on the GPU — confirm.
disable_eager_execution()
# Ask ML Compute to place work on the GPU. The `mlcompute` module is
# presumably only available in Apple's tensorflow_macos build, so this
# script is likely macOS-specific — verify for your TensorFlow install.
mlcompute.set_mlc_device(device_name='gpu')
#%% loading the data
# Keras returns MNIST already split into (images, labels) for train and test.
train_split, test_split = tf.keras.datasets.mnist.load_data()
X_train, y_train = train_split
X_test, y_test = test_split
#%% display one image.
# Quick visual sanity check on a single training example.
image = X_train[785]
plt.imshow(image)
plt.show()
#%% data pre-processing
# Divide every pixel by 255 so the 8-bit MNIST intensities land in [0, 1].
X_train, X_test = X_train / 255, X_test / 255
#%% model definition
# Small fully connected classifier: flatten each 28x28 image, one hidden
# ReLU layer regularised with dropout, then a 10-unit softmax output
# (one probability per digit class).
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax'),
])
#%% compiling the model
# The last layer already applies softmax, so its outputs are probabilities,
# not logits. from_logits must therefore be False; with True the loss would
# apply a second softmax on top of the probabilities, squashing them toward
# uniform and distorting both the loss value and its gradients.
# Sparse categorical cross-entropy takes the integer digit labels directly.
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['accuracy'])
model.summary()
#%% training the model
# NOTE(review): the test set is also used as the validation set here, so the
# "val_*" curves and the final evaluation report on the same data. Consider
# validation_split=... for a held-out split if that matters for your use case.
history = model.fit(X_train, y_train, epochs=10, validation_data=(X_test, y_test))
loss, accuracy = model.evaluate(X_test, y_test)
print('Accuracy on test dataset:', accuracy)
#%% running predictions
# Class probabilities for every test image: one row per image, 10 columns
# (one per output unit of the final Dense layer).
predictions = model.predict(X_test)
#%% testing on a single image
# Add a batch axis so predict() sees shape (1, 28, 28), take the most likely
# class, and print it next to the true label. A bare expression would be
# silently discarded when the file runs as a plain script.
sample_index = 60
predicted_label = np.argmax(
    model.predict(X_test[sample_index].reshape(1, 28, 28)), axis=-1
)[0]
print(f'Image {sample_index}: predicted {predicted_label}, '
      f'actual {y_test[sample_index]}')
#%% creating the metrics dataframe
# One row per training epoch, one column per metric recorded by model.fit
# (e.g. loss / accuracy and their val_ counterparts).
metrics_df = pd.DataFrame(history.history)
# Write relative to the working directory instead of a hard-coded,
# user-specific absolute path. Run from the project root, this resolves to the
# same project/data/metrics.csv location. DataFrame.to_csv does not create
# missing directories, so make sure the folder exists first.
metrics_path = Path("data") / "metrics.csv"
metrics_path.parent.mkdir(parents=True, exist_ok=True)
metrics_df.to_csv(metrics_path)
#%% Visualizing the training and validation loss and accuracy
# One figure per metric, each showing the training curve alongside its
# validation counterpart across epochs.
for columns in (["loss", "val_loss"], ["accuracy", "val_accuracy"]):
    metrics_df[columns].plot()
    plt.show()
#%% the end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment