Skip to content

Instantly share code, notes, and snippets.

@wwoods
Last active August 29, 2015 14:13
Show Gist options
  • Save wwoods/66acb7c5c01a5ccab126 to your computer and use it in GitHub Desktop.
Save wwoods/66acb7c5c01a5ccab126 to your computer and use it in GitHub Desktop.
"""A figure generator for slicing and displaying data. Example usage:
import pandas
data = pandas.DataFrame()
data["a"] = None
data["b"] = None
data.loc[0] = [ 1, 1 ]
data.loc[1] = [ 2, 2 ]
data.loc[2] = [ 3, 2 ]
maker = FigureMaker(data, "images/", imageExtension = "png")
maker.setLabel("a", "First column")
maker.setLabel("b", "Second column")
maker.setScale("b", "log")
with maker.new("a vs b") as fig:
fig.plotValue("a and b", "a", "b")
"""
import errno
import math
import os

import matplotlib.image
from matplotlib import pyplot
import PIL.Image
import numpy as np
# Should come from python-slugify!
from slugify import slugify
class _Figure(object):
    """One figure under construction, as returned by ``FigureMaker.new()``.

    Use it as a context manager and call the plot* methods inside the
    ``with`` block. On a clean exit the legend is drawn, the figure is saved
    to ``maker.imageDestFolder`` (when set), and only then is the title added.
    As a result, saved files carry no title but interactive output does.
    """

    # Line styles and markers cycle with each plotted series.
    styles = [ '-', '--', ':' ]
    markers = [ ' ', 'd', '*', 'o', 's', 'v', '<' ]
    # Colour of the secondary y axis label and tick labels.
    SECOND_Y_COLOR = 'r'

    def __init__(self, maker, figureName, axes, legendAnchor):
        """Wrap *axes* as the figure *figureName* owned by *maker*.

        legendAnchor: None for no legend; a matplotlib legend ``loc`` value;
            or a (loc, x, y) tuple/list whose (x, y) part is used as the
            legend's ``bbox_to_anchor``.
        """
        self.maker = maker
        self.axes = axes
        self.axes.locator_params(tight = True)
        self._legendAnchor = legendAnchor
        # Size of the plotting area (not the whole figure), in inches.
        figSize = axes.figure.get_size_inches()
        position = axes.get_position()
        self.width = figSize[0] * position.width
        self.height = figSize[1] * position.height
        self._lineCount = 0
        self._figureName = figureName
        # Labels claimed by the x axis and the primary / secondary y axes;
        # None until the first series using that axis is plotted.
        self._xlabel = None
        self._ylabel = None
        self._ylabel2 = None
        # twinx() axes created when a second y label is needed.
        self._yaxis2 = None

    def __enter__(self):
        return self

    def __exit__(self, type, value, tb):
        """Finish the figure unless the ``with`` body raised.

        Raises ValueError if a file with this figure's slugged name already
        exists in the destination folder. Exceptions from the body propagate.
        """
        if type is not None:
            return
        extraArtists = []
        anchor = self._legendAnchor
        if anchor is not None:
            if isinstance(anchor, (tuple, list)):
                legend = self.axes.legend(loc = anchor[0], bbox_to_anchor = anchor[1:])
            else:
                legend = self.axes.legend(loc = anchor)
            # Make the tight bounding box account for the legend, which may
            # sit outside the axes.
            extraArtists.append(legend)
        if self._yaxis2 is not None:
            for tickLabel in self._yaxis2.get_yticklabels():
                tickLabel.set_color(self.SECOND_Y_COLOR)
        # Render to file first, then add the title for interactive display.
        if self.maker.imageDestFolder is not None:
            slugName = slugify(self._figureName)
            fname = os.path.join(self.maker.imageDestFolder,
                    slugName + '.' + self.maker.imageExtension)
            if os.path.lexists(fname):
                raise ValueError("Cannot name two figures {} (slugged {})".format(
                        self._figureName, slugName))
            # Save *this* figure; pyplot's "current" figure may be another one
            # if a figure was created in the meantime.
            self.axes.figure.savefig(fname, dpi = 300, bbox_inches = 'tight',
                    bbox_extra_artists = extraArtists)
        self.axes.set_title(self._figureName, fontsize = 24)

    def plotImage(self, x, img, **kwargs):
        """Plot a column of image *paths* as staggered images along x.

        x: column whose value positions each image horizontally.
        img: column holding one image file path per row (loaded with PIL).
        kwargs: row filters, as accepted by ``_callPlot``.

        All images share one scale factor. The factor shrinks by 10% until
        every image can be stacked top-down without overlapping another.
        Each image gets an arrow pointing at its x value on the bottom edge.
        This disables the figure's legend.

        Raises ValueError if no usable scale can be found.
        """
        def innerPlot(axes, dataSet, options):
            xMin = dataSet[x].min()
            xMax = dataSet[x].max()

            # Use a twin axis for the images, and the primary x axis for ticks
            # (and possibly log scaling).
            imAxes = axes.twiny()
            # Original axes must be set AFTER twinning (true for y, anyway).
            axes.set_xlim(xMin, xMax)
            axes.set_ylim(0.0, 1.0)
            axes.set_ymargin(0.0)
            axes.get_yaxis().set_ticks([])
            # Invisible line spanning [xMin, xMax] x [0, 1], so the axes' data
            # limits cover the range that _callPlot later reads back.
            axes.plot([ xMin, xMax ], [ 0.0, 1.0 ], linestyle = ' ')
            self._legendAnchor = None

            # The image axes span [0, 1] in x (and share y with the data axes).
            imAxes.set_xscale("linear")
            imAxes.set_xlim(0.0, 1.0)
            imAxes.set_xmargin(0.0)
            imAxes.get_xaxis().set_ticks([])

            # Gap between stacked images, and the lowest allowed image bottom
            # (leaves a gutter for the arrows).
            imMargin = 0.01
            imLowestY = imMargin

            # mapX: data x value -> imAxes x coordinate in [0, 1].
            if xMax == xMin:
                def mapX(value):
                    # Every image shares one x value; centre them.
                    return 0.5
            elif axes.get_xscale() == "log":
                logMin = math.log(xMin)
                logSpan = math.log(xMax) - logMin
                def mapX(value):
                    return (math.log(value) - logMin) / logSpan
            else:
                def mapX(value):
                    # "* 1.0" forces true division for integer columns on Python 2.
                    return (value - xMin) * 1.0 / (xMax - xMin)

            def calcBounds(center, width):
                """imAxes (left, right) of an image of *width* anchored at *center*.

                The anchor sits at the same fraction of the image's width as
                *center* does of the axes. For center in [0, 1] and width <= 1,
                the image therefore stays within [0, 1].
                """
                return (center - width * center, center + width * (1.0 - center))

            # imAxes span [0, 1] in both x and y, so the physical aspect of the
            # plotting area converts an image's width into its height.
            imAxesAspect = self.height / self.width
            # Entries: [ pixel data, aspect in imAxes units, data x, top y, width ]
            images = []
            for coord, imPath in zip(dataSet[x].values, dataSet[img].values):
                # Load and ensure RGBA.
                image = PIL.Image.open(imPath).convert("RGBA")
                # Aspect ratio (height / width) of the image itself.
                nativeAspect = image.size[1] * 1.0 / image.size[0]
                axesAspect = nativeAspect / imAxesAspect
                # Widest width whose height still fits below the top edge.
                width = min(1.0, (1.0 - imMargin) / axesAspect)
                images.append([ np.asarray(image), axesAspect, coord, 1, width ])

            def spanY(image, top, scale):
                """imAxes (bottom, top) of *image* with its top edge at *top*."""
                return (top - image[4] * scale * image[1], top)

            def overlaps(a, b):
                """True if the closed intervals *a* and *b* intersect."""
                return a[0] <= b[1] and a[1] >= b[0]

            def tryLayout(scale):
                """Assign every image's top y at *scale*; False if one won't fit.

                Each image starts at the top and moves down below any earlier
                image it overlaps. Moving down can create an overlap with an
                earlier image that was already checked, so the scan repeats
                until nothing moves. Each earlier image can push this one at
                most once, because the image then sits entirely below it.
                """
                for i, image in enumerate(images):
                    spanX = calcBounds(mapX(image[2]), image[4] * scale)
                    safeY = 1.0
                    moved = True
                    while moved:
                        moved = False
                        for other in images[:i]:
                            otherSpanX = calcBounds(mapX(other[2]), other[4] * scale)
                            otherSpanY = spanY(other, other[3], scale)
                            if (overlaps(spanX, otherSpanX)
                                    and overlaps(spanY(image, safeY, scale), otherSpanY)):
                                # Overlap! We can only move down, since we start at the top.
                                safeY = otherSpanY[0] - imMargin
                                moved = True
                                if spanY(image, safeY, scale)[0] < imLowestY:
                                    return False
                    image[3] = safeY
                return True

            # Shrink all images together until they stack without overlap.
            baseScale = 1.0
            while not tryLayout(baseScale):
                baseScale *= 0.9
                if baseScale < 1e-6:
                    raise ValueError("Could not place {} images without overlap".format(
                            len(images)))

            # Now that the layout is known, render everything.
            for imData, axesAspect, dataX, imAxesY, imAxesWidth in images:
                imX = mapX(dataX)
                imAxesWidth *= baseScale
                imYBottom = imAxesY - axesAspect * imAxesWidth
                bounds = calcBounds(imX, imAxesWidth)
                imAxes.imshow(imData, interpolation = 'nearest', aspect = 'auto',
                        extent = (bounds[0], bounds[1], imYBottom, imAxesY))
                # Arrow from the image's bottom centre to its x on the bottom edge.
                arrowX = (bounds[0] + bounds[1]) * 0.5
                imAxes.annotate("",
                        xy = (imX, 0), xycoords = 'data',
                        xytext = (arrowX, imYBottom), textcoords = 'data',
                        arrowprops = dict(arrowstyle = "->", connectionstyle = "arc3"))

        self._callPlot(x, img, innerPlot, label = None, **kwargs)

    def plotError(self, label, x, y, y2, **kwargs):
        """Plot column *y* against *x* with error bars of size column *y2*.

        kwargs are row filters, as accepted by ``_callPlot``.
        """
        self._callPlot(
                x, y,
                lambda axes, dataSet, options: axes.errorbar(dataSet[x], dataSet[y],
                        dataSet[y2], capsize = 6, **options),
                label, **kwargs)

    def plotValue(self, label, x, y, **kwargs):
        """Plot column *y* against column *x* as a line labelled *label*.

        kwargs are row filters, as accepted by ``_callPlot``.
        """
        self._callPlot(x, y,
                lambda axes, dataSet, options: axes.plot(dataSet[x], dataSet[y], **options),
                label, **kwargs)

    @staticmethod
    def _logPaddedRange(low, high, margin):
        """Return (low, high) widened on both sides by *margin* of its log span."""
        logMin = math.log(low)
        logMax = math.log(high)
        diff = logMax - logMin
        return (math.exp(logMin - diff * margin), math.exp(logMax + diff * margin))

    def _callPlot(self, xDataName, yDataName, plotMethod, label, **kwargs):
        """Filter the maker's data, pick the target axes, and draw one series.

        xDataName, yDataName: columns for the x and y axes. Their labels,
            scales, limits and margins come from the maker.
        plotMethod: callable(axes, dataSet, plotOptions) that does the drawing.
        label: legend label for the series (None for no entry).
        kwargs: row filters. ``col=v`` keeps rows where col == v. The suffixes
            ``_lt``, ``_lte``, ``_gt`` and ``_gte`` select <, <=, > and >=.

        Nothing is drawn when no rows survive the filters. A second distinct y
        label is plotted on a twin y axis. Raises ValueError when a second x
        label or a third y label would be needed.
        """
        dataSet = self.maker.dataSet
        for key, value in kwargs.items():
            if key.endswith("_lte"):
                dataSet = dataSet.loc[dataSet[key[:-4]] <= value]
            elif key.endswith("_lt"):
                dataSet = dataSet.loc[dataSet[key[:-3]] < value]
            elif key.endswith("_gte"):
                dataSet = dataSet.loc[dataSet[key[:-4]] >= value]
            elif key.endswith("_gt"):
                dataSet = dataSet.loc[dataSet[key[:-3]] > value]
            else:
                dataSet = dataSet.loc[dataSet[key] == value]
        if len(dataSet) == 0:
            return

        # Successive lines use sparser markers, scaled to the plot width.
        markEvery = 5.0 / self.width * (3. + 1 * self._lineCount)
        plotOptions = {
            'label': label,
            'linestyle': self.styles[self._lineCount % len(self.styles)],
            'marker': self.markers[self._lineCount % len(self.markers)],
            'markeredgecolor': 'none',
            'markersize': 12,
            'markevery': (markEvery, markEvery),
        }
        if plotOptions['marker'] == '*':
            # Stars are small.
            plotOptions['markersize'] *= 1.5

        axes = self.axes
        # Enforce X and Y axis properties BEFORE rendering, since e.g. image
        # rendering needs to know whether an axis is log scaled.
        xLabel = self.maker._labels.get(xDataName, xDataName)
        if self._xlabel is None:
            self._xlabel = xLabel
            axes.set_xlabel(self._xlabel)
            axes.set_xscale(self.maker._scales.get(xDataName, "linear"))
        elif self._xlabel != xLabel:
            raise ValueError("Using multiple x axes in one plot? Unwise!")

        yLabel = self.maker._labels.get(yDataName, yDataName)
        if self._ylabel is None:
            self._ylabel = yLabel
            axes.set_ylabel(self._ylabel)
            axes.set_yscale(self.maker._scales.get(yDataName, "linear"))
        elif self._ylabel != yLabel:
            if self._ylabel2 == yLabel:
                # Plot to the existing second y axis.
                axes = self._yaxis2
            elif self._ylabel2 is None:
                # Create a second y axis and plot to that.
                self._ylabel2 = yLabel
                self._yaxis2 = axes = axes.twinx()
                axes.set_ylabel(self._ylabel2, color = self.SECOND_Y_COLOR)
                axes.set_yscale(self.maker._scales.get(yDataName, "linear"))
            else:
                raise ValueError("Using multiple y axes in one plot? Unwise!")

        plotMethod(axes, dataSet, plotOptions)

        # --- x limits and margin
        limits = self.maker._limits.get(xDataName, (None, None))
        if limits[0] is None or limits[1] is None:
            # A single None limit doesn't work well with matplotlib, so take
            # the open end from matplotlib's automatic bounds.
            axes.set_xlim(auto = True)
            xMin, xMax = axes.get_xbound()
            if limits[0] is not None:
                xMin = limits[0]
            if limits[1] is not None:
                xMax = limits[1]
            axes.set_xlim(xMin, xMax)
        else:
            axes.set_xlim(limits)
        xMargin = self.maker.getMargin(xDataName)
        if axes.get_xscale() != "log":
            axes.set_xmargin(xMargin)
        else:
            # Axis margins are unreliable on log scales; pad in log space by hand.
            axes.set_xlim(auto = True)
            xMin, xMax = axes.get_xbound()
            axes.set_xlim(*self._logPaddedRange(xMin, xMax, xMargin))

        # --- y limits and margin: user limits where given, else the data range.
        dataLimits = (axes.dataLim.ymin, axes.dataLim.ymax)
        limits = self.maker._limits.get(yDataName, (None, None))
        yMin = limits[0] if limits[0] is not None else dataLimits[0]
        yMax = limits[1] if limits[1] is not None else dataLimits[1]
        axes.set_ylim(yMin, yMax)
        yMargin = self.maker.getMargin(yDataName)
        if axes.get_yscale() != "log":
            axes.set_ymargin(yMargin)
        else:
            # Pad the resolved limits (so user limits are honoured) by the
            # y axis' own margin.
            axes.set_ylim(*self._logPaddedRange(yMin, yMax, yMargin))
        self._lineCount += 1
class FigureMaker(object):
    """Creates consistently formatted figures from one pandas DataFrame.

    Per-column display settings (label, limits, margin, scale) are registered
    once with the set* methods and applied to every figure made by ``new()``.
    """

    def __init__(self, dataSet, imageDestFolder = None,
            imageExtension = 'png', defaultHeightAspect = 0.618):
        """
        dataSet: DataFrame whose columns are plotted.
        imageDestFolder: folder each finished figure is saved into, or None
            to skip saving. The folder is created if missing. Existing
            ``*.<imageExtension>`` files in it are deleted, since a figure
            refuses to overwrite an existing file. Other files and subfolders
            are left alone.
        imageExtension: extension (and therefore format) of saved images.
        defaultHeightAspect: height / width ratio used by ``new()`` when it is
            not given one.
        """
        self.dataSet = dataSet
        self.imageDestFolder = imageDestFolder
        self.imageExtension = imageExtension
        if imageDestFolder is not None:
            try:
                os.makedirs(imageDestFolder)
            except OSError as e:
                if e.errno != errno.EEXIST:
                    raise
            # Clear out old images of our extension from a previous run.
            suffix = '.' + imageExtension
            for fn in os.listdir(imageDestFolder):
                path = os.path.join(imageDestFolder, fn)
                if fn.endswith(suffix) and not os.path.isdir(path):
                    os.remove(path)
        self._defaults = { 'heightAspect': defaultHeightAspect }
        # Per-column settings, keyed by column name.
        self._labels = {}
        self._limits = {}
        self._margins = {}
        self._scales = {}
        self._marginDefault = 0.00
        # Used to number figures created without a name.
        self._unnamedCount = 0

    def getMargin(self, dataName):
        """Margin registered for column *dataName*, or the default (0.0)."""
        return self._margins.get(dataName, self._marginDefault)

    def new(self, figureName = None, heightAspect = None, scale = 1.0, legendLoc = 0,
            legendAnchor = None):
        """Create a new figure; use the result as a context manager.

        figureName: title and (slugged) file name; defaults to "Unnamed N".
        heightAspect: height / width ratio; defaults to the maker's default.
        scale: multiplier on the 10 inch base figure width.
        legendLoc: the ``loc`` argument for a pyplot legend. 0 / 'best' is the
            default; 1 is upper right, 2 upper left, 3 lower left and 4 lower
            right. Strings such as 'upper right' also work. Pass None to
            disable the legend.
        legendAnchor: overrides legendLoc. A 3-tuple of (corner, x, y); e.g.
            ('upper right', 0, 0) puts the legend's upper right corner at
            (0, 0) in axes coordinates.
        """
        baseSize = 10.0 * scale
        if figureName is None:
            self._unnamedCount += 1
            figureName = "Unnamed {}".format(self._unnamedCount)
        if heightAspect is None:
            heightAspect = self._defaults['heightAspect']
        fig, axes = pyplot.subplots(1, 1,
                figsize = (baseSize, baseSize * heightAspect))
        if legendAnchor is None:
            legendAnchor = legendLoc
        return _Figure(self, figureName, axes, legendAnchor)

    def setLabel(self, dataName, labelName):
        """Display *labelName* as the axis label for column *dataName*."""
        self._labels[dataName] = labelName

    def setLimits(self, dataName, limits):
        """Set (min, max) axis limits for *dataName*; either may be None."""
        self._limits[dataName] = limits

    def setMargin(self, dataName, margin):
        """Set the axis margin (fraction of the range) for *dataName*."""
        self._margins[dataName] = margin

    def setScale(self, dataName, scaleName):
        """Set the matplotlib axis scale (e.g. "log") for *dataName*."""
        self._scales[dataName] = scaleName
# Set all plot default parameters.
# 'legend.fontsize' is 'medium', i.e. the same as font.size. None is rejected
# by matplotlib's font-size validator.
pyplot.rcParams.update({
    'figure.dpi': 600,
    'figure.figsize': (8, 4),
    'font.size': 16,
    'legend.fontsize': 'medium',
})
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment