Skip to content

Instantly share code, notes, and snippets.

@gaphex
Last active October 17, 2016 16:28
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save gaphex/6b2062b995939280efd19c6f91edd8a5 to your computer and use it in GitHub Desktop.
Save gaphex/6b2062b995939280efd19c6f91edd8a5 to your computer and use it in GitHub Desktop.
import random
import matplotlib.pyplot as plt
from IPython import display
"""
IPython Display rc0
Try:
dsp = IDisplay()
dsp.test()
"""
class IDisplay(object):
    '''
    Live-updating line plot for IPython/Jupyter.

    Each call to display() clears the calling cell's output and redraws one
    line per label. Every label keeps the same random colour across calls
    until regenerate_colors() is invoked.
    '''

    def __init__(self,
                 vsize=8, hsize=12,
                 title='Training in progress',
                 xlabel='epoch',
                 ylabel='score'):
        '''
        Sets plot parameters such as:
        hsize: horizontal plot size (passed to matplotlib figsize)
        vsize: vertical plot size (passed to matplotlib figsize)
        xlabel: x axis label
        ylabel: y axis label
        title: plot title
        '''
        # label -> {'r': float, 'g': float, 'b': float}, each channel in [0, 1]
        self.colors = {}
        self.xlabel = xlabel
        self.ylabel = ylabel
        self.title = title
        self.hsize = hsize
        self.vsize = vsize

    def _generate_color(self, label):
        '''
        Draws a random colour for label and stores it in self.colors,
        overwriting any colour previously stored for that label.

        Each channel is a random integer in [0, 255] scaled into [0, 1].
        '''
        # randint is inclusive on both ends; randrange(0, 255) would never
        # yield 255, so no channel could ever reach full intensity.
        self.colors[label] = {channel: random.randint(0, 255) / 255.0
                              for channel in ('r', 'g', 'b')}

    def _color_tuple(self, label):
        '''Returns the stored colour of label as an (r, g, b) tuple.'''
        rgb = self.colors[label]
        return (rgb['r'], rgb['g'], rgb['b'])

    def regenerate_colors(self):
        '''
        Draws fresh random colours for every label seen so far,
        if you don't like the current ones.
        '''
        # Iterate over a snapshot of the keys so the dict is never iterated
        # while it is being written to.
        for label in list(self.colors):
            self._generate_color(label)

    def display(self, data_dict):
        '''
        Visualises 1-D data with line plots.
        Every call to this method replaces the output of the calling cell.

        data_dict maps each label to a sequence of y values:
            {'label_1': [],
             'label_2': [],
             ...}
        Labels are plotted in sorted order. A label seen for the first time
        is assigned a new random colour.
        '''
        # Inside this method, `display` is the module-level IPython.display
        # module, not this method.
        display.clear_output(wait=True)
        fig = plt.figure(figsize=(self.hsize, self.vsize))
        plt.xlabel(self.xlabel)
        plt.ylabel(self.ylabel)
        plt.title(self.title)
        for label in sorted(data_dict):
            if label not in self.colors:
                self._generate_color(label)
            plt.plot(data_dict[label], label=label,
                     linewidth=4,
                     color=self._color_tuple(label))
        if data_dict:
            # Only request a legend when there are labelled lines; otherwise
            # matplotlib emits a "no artists with labels" warning.
            plt.legend(loc='upper left', shadow=True)
        plt.show()
        # A new figure is created on every call. Close it so that backends
        # which do not auto-close after show() don't pile up open figures.
        plt.close(fig)

    def test(self, epochs=16):
        '''
        Performs a dummy run with randomly generated data.

        For each of `epochs` steps, appends one random integer to each of
        two series and redraws the plot.
        '''
        data = {'popugai_1': [], 'popugai_2': []}
        for _ in range(epochs):
            for i, key in enumerate(data):
                data[key].append(random.randint((i + 1) * 4, (i + 1) * 16))
            self.display(data)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment