Skip to content

Instantly share code, notes, and snippets.

@bearpaw
Created March 11, 2018 01:51
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bearpaw/d82ad81cade2bba4003afd4130d4b7b7 to your computer and use it in GitHub Desktop.
Save bearpaw/d82ad81cade2bba4003afd4130d4b7b7 to your computer and use it in GitHub Desktop.
Draw TensorFlow logs with Matplotlib
import numpy as np
from tensorboard.backend.event_processing import event_accumulator as ea
import matplotlib as mpl
import matplotlib.pyplot as plt
def plot_tensorflow_log(path, train_tag="test/Episode_Length_19",
                        val_tag="test/Episode_Length_19"):
    """Plot two scalar series from a TensorBoard event log with Matplotlib.

    Args:
        path: Path handed directly to ``EventAccumulator`` (an event file
            such as ``events.out.tfevents.*``).
        train_tag: Scalar tag drawn as the "training accuracy" curve.
        val_tag: Scalar tag drawn as the "validation accuracy" curve.

    Each series is plotted against its own logged step numbers, so the two
    tags may contain different numbers of events.  If a tag is missing,
    ``EventAccumulator.Scalars`` raises (presumably ``KeyError`` in current
    tensorboard releases -- confirm against the installed version).
    """
    # Loading too much data is slow, so cap every category except scalars.
    # Use the module's own key constants: hand-typed keys that do not match
    # the installed tensorboard version are silently ignored.
    size_guidance = {
        ea.COMPRESSED_HISTOGRAMS: 10,
        ea.IMAGES: 0,
        ea.SCALARS: 10 * 10**6,
        ea.HISTOGRAMS: 1,
    }
    event_acc = ea.EventAccumulator(path, size_guidance)
    event_acc.Reload()
    # Tip: event_acc.Tags() lists every tag available in the log.

    # NOTE(review): both default tags are identical (carried over from the
    # original script), so the two curves coincide unless distinct tags are
    # passed -- confirm which tag is meant to be the validation series.
    series = (
        (train_tag, "training accuracy"),
        (val_tag, "validation accuracy"),
    )
    for tag, label in series:
        events = event_acc.Scalars(tag)
        # ScalarEvent fields are (wall_time, step, value); plotting against
        # the logged step keeps series of different lengths aligned.
        steps = [event.step for event in events]
        values = [event.value for event in events]
        plt.plot(steps, values, label=label)

    plt.xlabel("Steps")
    plt.ylabel("Accuracy")
    plt.title("Training Progress")
    plt.legend(loc='upper right', frameon=True)
    plt.show()
if __name__ == '__main__':
    import sys

    # Accept the event-file path as the first CLI argument; fall back to the
    # original author's hard-coded log so the script still runs unchanged.
    default_log = "/home/wyang/code/rl/zeroshot/zeroshot/train_4act_450x300_kitchen_acnetv2/log/events.out.tfevents.1520655158.cdc-43"
    log_file = sys.argv[1] if len(sys.argv) > 1 else default_log
    plot_tensorflow_log(log_file)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment