Created
March 4, 2021 08:33
-
-
Save mwitiderrick/fdfaf04db793349e48ff9e8ea3c79b6f to your computer and use it in GitHub Desktop.
PyCharm Scientific Mode
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
from tensorflow.python.framework.ops import disable_eager_execution | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import pandas as pd | |
from tensorflow.python.compiler.mlcompute import mlcompute | |
#%% runtime setup: turn eager execution OFF and select the ML Compute GPU device.
# NOTE: disable_eager_execution() switches eager mode off (it does not
# activate it), so it is called before any model or tensor is created below.
disable_eager_execution()
mlcompute.set_mlc_device(device_name="gpu")
#%% loading the data
# load_data() returns a (train, test) pair, each being an (images, labels) pair.
mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()
X_train, y_train = mnist_train
X_test, y_test = mnist_test
#%% display one image.
# Quick visual sanity check on an arbitrary training sample.
image = X_train[785]
plt.imshow(image)
plt.show()
#%% data pre-processing
# Divide every pixel value by 255 so the network sees scaled float inputs.
X_train, X_test = X_train / 255, X_test / 255
#%% model definition
# Fully-connected classifier for 28x28 inputs: flatten, one 512-unit ReLU
# hidden layer with 20% dropout, then a 10-unit softmax output layer.
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax'),
])
#%% compiling the model
# The output layer already applies softmax, so the model emits probabilities
# rather than raw logits. The loss must therefore be built with
# from_logits=False; from_logits=True would treat those probabilities as
# logits and compute the wrong loss.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['accuracy'],
)
model.summary()
#%% training the model
# The test split doubles as the validation set here, so the per-epoch
# validation metrics and the final evaluation are computed on the same data.
history = model.fit(
    x=X_train,
    y=y_train,
    epochs=10,
    validation_data=(X_test, y_test),
)
loss, accuracy = model.evaluate(x=X_test, y=y_test)
print('Accuracy on test dataset:', accuracy)
#%% running predictions
# One row of 10 class scores per test image.
predictions = model.predict(x=X_test)
#%% testing on a single image
# predict() expects a batch, so the single 28x28 image is reshaped to
# (1, 28, 28); argmax over the last axis picks the highest-scoring class.
single_image_batch = X_test[60].reshape(1, 28, 28)
np.argmax(model.predict(x=single_image_batch), axis=-1)
#%% creating the metrics dataframe
# history.history maps each metric name to its recorded values; the dataframe
# gets one column per metric.
# NOTE(review): the output path is absolute and machine-specific; the
# directory must already exist for to_csv() to succeed.
metrics_csv_path = "/Users/derrickmwiti/PycharmProjects/Prediction/data/metrics.csv"
metrics_df = pd.DataFrame(data=history.history)
metrics_df.to_csv(metrics_csv_path)
#%% Visualizing the training and validation loss
# One figure per metric pair: training curve vs. validation curve.
for curve_columns in (["loss", "val_loss"], ["accuracy", "val_accuracy"]):
    metrics_df[curve_columns].plot()
    plt.show()
#%% the end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment