Skip to content

Instantly share code, notes, and snippets.

@alstat alstat/low-tf-15.py Secret
Created Oct 31, 2019

Embed
What would you like to do?
The following code generates the plots in this article.
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
# Global seaborn/matplotlib theme applied to every figure created afterwards:
# dark figure and axes background, dotted grid, one shared grey for all text,
# tick and axis-label colours, and all four axes spines switched off.
_background = "#191c20"
_muted = "#838587"

_theme = {
    "axes.facecolor": _background,
    "figure.facecolor": _background,
    "grid.color": "#4c4d4f",
    "grid.linestyle": "dotted",
}
# Every textual element uses the same muted grey.
_theme.update(
    dict.fromkeys(
        ("text.color", "xtick.color", "ytick.color", "axes.labelcolor"),
        _muted,
    )
)
# Hide the frame on all sides.
_theme.update(
    dict.fromkeys(
        (
            "axes.spines.bottom",
            "axes.spines.left",
            "axes.spines.right",
            "axes.spines.top",
        ),
        False,
    )
)

sns.set(rc=_theme)
from matplotlib.lines import Line2D
# Load the two result tables that both figures compare.
nml1 = pd.read_csv("../tf2_output_normal_initializer_batch_size_1.csv")
nml2 = pd.read_csv("../tf2_output_normal_initializer_batch_size_3.csv")

# Figure 1: training vs. testing loss, one panel per CSV, sharing the y-axis.
f, a = plt.subplots(1, 2, sharey=True, figsize=(10, 5))

# Shift the "Unnamed: 0" column by one ONCE per table; the original rebuilt
# the shifted frame (column-wise apply + lambda) for every lineplot call.
# NOTE(review): presumably a 0-based epoch index saved with the CSV - confirm.
loss_one = nml1.assign(**{"Unnamed: 0": nml1["Unnamed: 0"] + 1})
loss_three = nml2.assign(**{"Unnamed: 0": nml2["Unnamed: 0"] + 1})

# Left panel: first CSV.
sns.lineplot(x="Unnamed: 0", y="trn_loss", color="#FF6F01", data=loss_one, ax=a[0])
sns.lineplot(x="Unnamed: 0", y="tst_loss", color="#FFB204", data=loss_one, ax=a[0])
a[0].set(xlabel="\nEpoch")
a[0].set(title="One Batch\n")
a[0].set(ylabel="Loss\n")
a[0].set(ylim=[-0.2, 6.2])
a[0].set_yticks([0, 2, 4, 6])
a[0].set_xticks([0, 250, 500])

# Right panel: second CSV (y-axis is shared with the left panel).
sns.lineplot(x="Unnamed: 0", y="trn_loss", color="#FF6F01", data=loss_three, ax=a[1])
sns.lineplot(x="Unnamed: 0", y="tst_loss", color="#FFB204", data=loss_three, ax=a[1])
a[1].set(xlabel="\nEpoch")
a[1].set(title="Three Minibatches\n")
a[1].set_xticks([0, 250, 500])

# Legend built from proxy lines so the entries read "Training"/"Testing".
# Target the right-hand axes explicitly instead of relying on pyplot's
# implicit "current axes" (which is the last subplot created, i.e. a[1]).
a[1].legend(
    handles=[
        Line2D([1], [1], color="#FF6F01", lw=2, label="Training"),
        Line2D([1], [1], color="#FFB204", lw=2, label="Testing"),
    ]
)
# Annotation in data coordinates of the right-hand panel.
a[1].text(
    380,
    0.6,
    "Stop\nTraining\nHere!",
    horizontalalignment="center",
    size="small",
    color="#838587",
)
plt.tight_layout()
plt.show()
# Figure 2: training vs. testing accuracy, one panel per CSV, sharing the
# y-axis. Reuses the nml1 / nml2 tables loaded earlier in the script.
f, a = plt.subplots(1, 2, sharey=True, figsize=(10, 5))

# Shift the "Unnamed: 0" column by one ONCE per table; the original rebuilt
# the shifted frame (column-wise apply + lambda) for every lineplot call.
# NOTE(review): presumably a 0-based epoch index saved with the CSV - confirm.
accy_one = nml1.assign(**{"Unnamed: 0": nml1["Unnamed: 0"] + 1})
accy_three = nml2.assign(**{"Unnamed: 0": nml2["Unnamed: 0"] + 1})

# Left panel: first CSV.
sns.lineplot(x="Unnamed: 0", y="trn_accy", color="#FF6F01", data=accy_one, ax=a[0])
sns.lineplot(x="Unnamed: 0", y="tst_accy", color="#FFB204", data=accy_one, ax=a[0])
a[0].set(xlabel="\nEpoch")
a[0].set(title="One Batch\n")
a[0].set(ylabel="Accuracy\n")
a[0].set(ylim=[0.19, 1.01])
a[0].set_yticks([0.2, 0.4, 0.6, 0.8, 1.0])
a[0].set_xticks([0, 250, 500])

# Right panel: second CSV (y-axis is shared with the left panel).
sns.lineplot(x="Unnamed: 0", y="trn_accy", color="#FF6F01", data=accy_three, ax=a[1])
sns.lineplot(x="Unnamed: 0", y="tst_accy", color="#FFB204", data=accy_three, ax=a[1])
a[1].set(xlabel="\nEpoch")
a[1].set(title="Three Minibatches\n")
a[1].set_xticks([0, 250, 500])

# Legend built from proxy lines so the entries read "Training"/"Testing".
# Target the right-hand axes explicitly instead of relying on pyplot's
# implicit "current axes" (which is the last subplot created, i.e. a[1]).
a[1].legend(
    handles=[
        Line2D([1], [1], color="#FF6F01", lw=2, label="Training"),
        Line2D([1], [1], color="#FFB204", lw=2, label="Testing"),
    ]
)
plt.tight_layout()
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.