PyCharm Scientific Mode
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf
from tensorflow.python.framework.ops import disable_eager_execution
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tensorflow.python.compiler.mlcompute import mlcompute

#%% disable eager execution and select the compute device.
# NOTE: this call *disables* eager mode, so everything below runs in graph mode.
disable_eager_execution()
mlcompute.set_mlc_device(device_name='gpu')

#%% loading the data
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()

#%% display one image.
image = X_train[785]
plt.imshow(image)
plt.show()

#%% data pre-processing
# Rescale the raw pixel intensities by 1/255 before feeding them to the network.
X_train = X_train / 255
X_test = X_test / 255

#%% model definition
# Flatten each 28x28 image, pass it through one hidden layer with dropout,
# and emit a 10-way softmax, i.e. the model outputs class probabilities.
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax'),
])

#%% compiling the model
# The output layer already applies softmax, so the loss must be told it is
# receiving probabilities (from_logits=False). Passing from_logits=True here
# would apply softmax a second time and distort the loss and its gradients.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['accuracy'],
)
model.summary()

#%% training the model
# The test split doubles as validation data, so history also records
# val_loss / val_accuracy for every epoch.
history = model.fit(X_train, y_train, epochs=10, validation_data=(X_test, y_test))
loss, accuracy = model.evaluate(X_test, y_test)
print('Accuracy on test dataset:', accuracy)

#%% running predictions
predictions = model.predict(X_test)

#%% testing on a single image
# Reshape to a batch of one image; argmax over the last axis gives the
# predicted class index. The value is left as a bare expression so that it
# is echoed when the cell is run interactively.
np.argmax(model.predict(X_test[60].reshape(1, 28, 28)), axis=-1)

#%% creating the metrics dataframe
metrics_df = pd.DataFrame(history.history)
# NOTE(review): absolute, machine-specific path; the directory must already exist.
metrics_df.to_csv("/Users/derrickmwiti/PycharmProjects/Prediction/data/metrics.csv")

#%% Visualizing the training and validation loss
metrics_df[["loss", "val_loss"]].plot()
plt.show()
metrics_df[["accuracy", "val_accuracy"]].plot()
plt.show()

#%% the end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment