# Source: GitHub gist wwoods/66acb7c5c01a5ccab126 (last active August 29, 2015 14:13).
# (GitHub page notice: this file may contain bidirectional Unicode text; review it
# in an editor that reveals hidden Unicode characters.)
"""A figure generator for slicing and displaying data.

Example usage::

    import pandas
    data = pandas.DataFrame()
    data["a"] = None
    data["b"] = None
    data.loc[0] = [ 1, 1 ]
    data.loc[1] = [ 2, 2 ]
    data.loc[2] = [ 3, 2 ]
    maker = FigureMaker(data, "images/", imageExtension = "png")
    maker.setLabel("a", "First column")
    maker.setLabel("b", "Second column")
    maker.setScale("b", "log")
    with maker.new("a vs b") as fig:
        fig.plotValue("a and b", "a", "b")
"""
import errno
import math
import os

import matplotlib.image
from matplotlib import pyplot
import numpy as np
import PIL.Image

# Should come from python-slugify!
from slugify import slugify
class _Figure(object):
    """A single figure built by a FigureMaker; use it as a context manager.

    Plotting methods filter the maker's shared DataFrame with keyword
    arguments (see _filterDataSet) and draw onto ``self.axes``.  On a clean
    exit from the ``with`` block the legend is drawn, the figure is saved into
    the maker's image folder (when one is configured), and the title is set.
    """

    # Line styles and markers, cycled per plotted line.
    styles = [ '-', '--', ':' ]
    markers = [ ' ', 'd', '*', 'o', 's', 'v', '<' ]
    # Color for the second (twinned) y axis label and tick labels.
    SECOND_Y_COLOR = 'r'
    # plotImage gives up (instead of looping forever) below this image scale.
    _MIN_IMAGE_SCALE = 1e-6

    def __init__(self, maker, figureName, axes, legendAnchor):
        """Wrap ``axes`` as the figure named ``figureName``.

        legendAnchor - a legend ``loc`` value, a sequence ``(loc, x, y)``
            whose tail is used as ``bbox_to_anchor``, or None for no legend.
        """
        self.maker = maker
        self.axes = axes
        self.axes.locator_params(tight = True)
        self._legendAnchor = legendAnchor
        # Size of the axes area in inches: figure size scaled by the axes'
        # fractional position within the figure.
        figWidth, figHeight = axes.figure.get_size_inches()
        position = axes.get_position()
        self.width = figWidth * position.width
        self.height = figHeight * position.height
        self._lineCount = 0
        self._figureName = figureName
        self._xlabel = None
        self._ylabel = None
        self._ylabel2 = None
        self._yaxis2 = None

    def __enter__(self):
        return self

    def __exit__(self, excType, excValue, tb):
        """On a clean exit, finish the figure; exceptions propagate."""
        if excType is not None:
            return
        extraArtists = self._drawLegend()
        # Color the second y axis' tick labels to match its label.
        if self._yaxis2 is not None:
            for tick in self._yaxis2.get_yticklabels():
                tick.set_color(self.SECOND_Y_COLOR)
        # Finish rendering to file, THEN render to IPython with the title
        if self.maker.imageDestFolder is not None:
            slugName = slugify(self._figureName)
            fname = os.path.join(self.maker.imageDestFolder,
                    slugName + '.' + self.maker.imageExtension)
            if os.path.lexists(fname):
                raise ValueError("Cannot name two figures {} (slugged {})".format(
                        self._figureName, slugName))
            # Save THIS figure, not whichever figure pyplot considers current.
            self.axes.figure.savefig(fname, dpi = 300, bbox_inches = 'tight',
                    bbox_extra_artists = extraArtists)
        self.axes.set_title(self._figureName, fontsize = 24)

    def _drawLegend(self):
        """Draw the legend per self._legendAnchor; return the artists drawn.

        Handles from the second y axis (if any) are included, so lines plotted
        against it also appear in the legend.
        """
        if self._legendAnchor is None:
            return []
        handles, labels = self.axes.get_legend_handles_labels()
        if self._yaxis2 is not None:
            handles2, labels2 = self._yaxis2.get_legend_handles_labels()
            handles = handles + handles2
            labels = labels + labels2
        if isinstance(self._legendAnchor, (tuple, list)):
            # Anchor!
            legend = self.axes.legend(handles, labels,
                    loc = self._legendAnchor[0],
                    bbox_to_anchor = self._legendAnchor[1:])
        else:
            legend = self.axes.legend(handles, labels, loc = self._legendAnchor)
        return [legend]

    def plotImage(self, x, img, **kwargs):
        """Plot a column of image _paths_ as staggered images in a plot.

        Each image (column ``img``) is drawn above its x coordinate (column
        ``x``) and linked to it by an arrow.  Images are stacked downward and
        scaled uniformly until none overlap.  kwargs are row filters (see
        _filterDataSet).  The legend is disabled for this figure.
        """
        def innerPlot(axes, dataSet, options):
            xMin = dataSet[x].min()
            xMax = dataSet[x].max()
            xSpan = xMax - xMin

            def mapX(value):
                """Return the imAxes x coordinate for a data (axes) x."""
                if xSpan == 0:
                    # A single distinct x: center everything.
                    return 0.5
                # float() avoids integer (floor) division under Python 2.
                return float(value - xMin) / xSpan

            def calcBounds(center, width):
                """Return imAxes (left, right) for an image of ``width``.

                ``center`` sits at the same relative position inside the
                image as it does inside [0, 1], keeping edge images in view.
                """
                return (center - width * center, center + width * (1.0 - center))

            # Use the twin axis for our images, and the primary x axis for
            # ticks (and possibly log scaling).
            imAxes = axes.twiny()
            # Original axes must be set AFTER twinning (true for y, anyway)
            axes.set_xlim(xMin, xMax)
            axes.set_ylim(0.0, 1.0)
            axes.set_ymargin(0.0)
            axes.get_yaxis().set_ticks([])
            axes.plot([ xMin, xMax ], [ 0.0, 1.0 ], linestyle = ' ')
            self._legendAnchor = None

            # Now set up our twin axes: x spans [0, 1] with no ticks.
            imAxes.set_xscale("linear")
            imAxes.set_xlim(0.0, 1.0)
            imAxes.set_xmargin(0.0)
            imAxes.get_xaxis().set_ticks([])

            # Gutter margin for arrows
            imMargin = 0.01
            imLowestY = imMargin

            if axes.get_xscale() == "log":
                logMin = math.log(xMin)
                logSpan = math.log(xMax) - logMin

                def mapLogX(value):
                    """Return the imAxes x coordinate for a log-scaled x."""
                    if logSpan == 0:
                        return 0.5
                    return (math.log(value) - logMin) / logSpan
                mapX = mapLogX

            # imAxes are bound on X between [0, 1] and Y between [0, 1]. So,
            # aspect is same as figure width and height.
            imAxesAspect = self.height / self.width
            # Each entry: [ imData, axesAspect, dataX, imAxesTopY, imAxesWidth ]
            images = []
            for coord, imPath in zip(dataSet[x].values, dataSet[img].values):
                # Load and ensure RGBA; the context manager closes the file.
                with PIL.Image.open(imPath) as source:
                    image = source.convert("RGBA")
                # nativeAspect is the aspect ratio (height / width) of the image
                nativeAspect = image.size[1] * 1.0 / image.size[0]
                axesAspect = nativeAspect / imAxesAspect
                images.append([ np.asarray(image), axesAspect, coord, 1,
                        min(1.0, (1.0 - imMargin) / axesAspect) ])

            # Now figure out positions. This is slow, but we want them all at
            # the same scale, so shrink the scale until the stack fits.
            baseScale = 1.0
            while not self._stackImages(images, baseScale, mapX, calcBounds,
                    imMargin, imLowestY):
                baseScale *= 0.9
                if baseScale < self._MIN_IMAGE_SCALE:
                    raise ValueError(
                            "Cannot lay out {} images without overlap".format(
                            len(images)))

            # Now that we've calculated the display, render everything
            for imData, axesAspect, dataX, imAxesY, imAxesWidth in images:
                imX = mapX(dataX)
                # Apply our new scaling
                imAxesWidth *= baseScale
                imYBottom = imAxesY - axesAspect * imAxesWidth
                left, right = calcBounds(imX, imAxesWidth)
                imAxes.imshow(imData, interpolation = 'nearest', aspect = 'auto',
                        extent = (left, right, imYBottom, imAxesY))
                # Link the image's bottom center to its x position at y = 0.
                imAxes.annotate("",
                        xy = (imX, 0), xycoords = 'data',
                        xytext = ((left + right) * 0.5, imYBottom),
                        textcoords = 'data',
                        arrowprops = dict(arrowstyle = "->",
                            connectionstyle = "arc3"))
        self._callPlot(x, img, innerPlot, label = None, **kwargs)

    @staticmethod
    def _stackImages(images, scale, mapX, calcBounds, margin, lowestY):
        """Assign each image's top y (entry index 3) at ``scale``.

        Images are placed in order; each starts at the top (y = 1) and moves
        down below any earlier image it overlaps.  Returns False as soon as an
        image would extend below ``lowestY``, True if all images fit.
        """
        for i, current in enumerate(images):
            width = current[4] * scale
            height = width * current[1]
            myBoundsX = calcBounds(mapX(current[2]), width)
            safeY = 1.0
            for placed in images[:i]:
                placedWidth = placed[4] * scale
                boundsX = calcBounds(mapX(placed[2]), placedWidth)
                placedBottom = placed[3] - placedWidth * placed[1]
                overlapsX = (myBoundsX[0] <= boundsX[1]
                        and myBoundsX[1] >= boundsX[0])
                overlapsY = (safeY - height <= placed[3]
                        and safeY >= placedBottom)
                if overlapsX and overlapsY:
                    # Overlap! We can only move down, since we start at the top
                    safeY = placedBottom - margin
                    if safeY - height < lowestY:
                        return False
            current[3] = safeY
        return True

    def plotError(self, label, x, y, y2, **kwargs):
        """Plot column ``y`` against ``x`` with error bars from column ``y2``.

        kwargs are row filters (see _filterDataSet).
        """
        def drawErrors(axes, dataSet, options):
            return axes.errorbar(dataSet[x], dataSet[y], dataSet[y2],
                    capsize = 6, **options)
        self._callPlot(x, y, drawErrors, label, **kwargs)

    def plotValue(self, label, x, y, **kwargs):
        """Plot column ``y`` against column ``x`` as a line named ``label``.

        kwargs are row filters (see _filterDataSet).
        """
        def drawLine(axes, dataSet, options):
            return axes.plot(dataSet[x], dataSet[y], **options)
        self._callPlot(x, y, drawLine, label, **kwargs)

    def _filterDataSet(self, filters):
        """Return the maker's rows matching every ``column[_suffix]: value``.

        Suffixes ``_lt``, ``_lte``, ``_gt`` and ``_gte`` compare with <, <=, >
        and >=; a key without one of these suffixes requires equality.
        """
        dataSet = self.maker.dataSet
        for key, value in filters.items():
            if key.endswith("_lte"):
                mask = dataSet[key[:-4]] <= value
            elif key.endswith("_gte"):
                mask = dataSet[key[:-4]] >= value
            elif key.endswith("_lt"):
                mask = dataSet[key[:-3]] < value
            elif key.endswith("_gt"):
                mask = dataSet[key[:-3]] > value
            else:
                mask = dataSet[key] == value
            dataSet = dataSet.loc[mask]
        return dataSet

    def _callPlot(self, xDataName, yDataName, plotMethod, label, **kwargs):
        """Filter rows, choose axes, draw with plotMethod, then fit limits.

        plotMethod(axes, dataSet, plotOptions) performs the drawing.  If the
        filters leave no rows nothing is drawn and the line style does not
        advance.
        """
        dataSet = self._filterDataSet(kwargs)
        if len(dataSet) == 0:
            return

        # Marker spacing grows with each successive line (presumably to
        # stagger markers between lines); passed as a (start, step) pair.
        markEvery = 5.0 / self.width * (3. + 1 * self._lineCount)
        plotOptions = {
            'label': label,
            'linestyle': self.styles[self._lineCount % len(self.styles)],
            'marker': self.markers[self._lineCount % len(self.markers)],
            'markeredgecolor': 'none',
            'markersize': 12,
            'markevery': (markEvery, markEvery),
        }
        if plotOptions['marker'] == '*':
            # stars are small
            plotOptions['markersize'] *= 1.5

        axes = self._prepareAxes(xDataName, yDataName)
        plotMethod(axes, dataSet, plotOptions)
        self._updateXLimits(axes, xDataName)
        self._updateYLimits(axes, yDataName)
        self._lineCount += 1

    def _prepareAxes(self, xDataName, yDataName):
        """Set labels/scales on first use and return the axes to draw on.

        Enforced BEFORE rendering, since e.g. image rendering needs to know if
        something is a log distribution.  A second distinct y label gets a
        twinned y axis; a third raises ValueError, as does a second x label.
        """
        axes = self.axes
        xLabel = self.maker._labels.get(xDataName, xDataName)
        if self._xlabel is None:
            self._xlabel = xLabel
            axes.set_xlabel(self._xlabel)
            axes.set_xscale(self.maker._scales.get(xDataName, "linear"))
        elif self._xlabel != xLabel:
            raise ValueError("Using multiple x axes in one plot? Unwise!")

        yLabel = self.maker._labels.get(yDataName, yDataName)
        if self._ylabel is None:
            self._ylabel = yLabel
            axes.set_ylabel(self._ylabel)
            axes.set_yscale(self.maker._scales.get(yDataName, "linear"))
        elif self._ylabel != yLabel:
            if self._ylabel2 == yLabel:
                # Ok, plot to second
                axes = self._yaxis2
            elif self._ylabel2 is None:
                # Create a second axis and plot to that
                self._ylabel2 = yLabel
                self._yaxis2 = axes = axes.twinx()
                axes.set_ylabel(self._ylabel2, color = self.SECOND_Y_COLOR)
                axes.set_yscale(self.maker._scales.get(yDataName, "linear"))
            else:
                raise ValueError("Using multiple y axes in one plot? Unwise!")
        return axes

    @staticmethod
    def _expandLogRange(low, high, margin):
        """Return (low, high) widened by ``margin`` of its span in log space."""
        logLow = math.log(low)
        logHigh = math.log(high)
        diff = logHigh - logLow
        return (math.exp(logLow - diff * margin),
                math.exp(logHigh + diff * margin))

    def _updateXLimits(self, axes, xDataName):
        """Apply configured x limits (either may be None) and the x margin."""
        # Fully specified is OK, but a single None doesn't really work with
        # matplotlib, so fill the missing side from the current bounds.
        limits = self.maker._limits.get(xDataName, (None, None))
        if limits[0] is None or limits[1] is None:
            axes.set_xlim(auto = True)
            xMin, xMax = axes.get_xbound()
            if limits[0] is not None:
                xMin = limits[0]
            if limits[1] is not None:
                xMax = limits[1]
            axes.set_xlim(xMin, xMax)
        else:
            axes.set_xlim(limits)
        # Log scale margins are broken :( so apply them in log space by hand.
        xMargin = self.maker.getMargin(xDataName)
        if axes.get_xscale() != "log":
            axes.set_xmargin(xMargin)
        else:
            axes.set_xlim(auto = True)
            xMin, xMax = axes.get_xbound()
            low, high = self._expandLogRange(xMin, xMax, xMargin)
            axes.set_xlim(low, high)

    def _updateYLimits(self, axes, yDataName):
        """Apply configured y limits (default: data range) and the y margin."""
        limits = self.maker._limits.get(yDataName, (None, None))
        yLow = limits[0] if limits[0] is not None else axes.dataLim.ymin
        yHigh = limits[1] if limits[1] is not None else axes.dataLim.ymax
        axes.set_ylim(yLow, yHigh)
        # Log scale margins are broken :( :( so apply them in log space by
        # hand, around the (possibly user-limited) range just set.
        yMargin = self.maker.getMargin(yDataName)
        if axes.get_yscale() != "log":
            axes.set_ymargin(yMargin)
        else:
            low, high = self._expandLogRange(yLow, yHigh, yMargin)
            axes.set_ylim(low, high)
class FigureMaker(object):
    """Creates figures that plot slices of one shared DataFrame.

    Per-column display settings (labels, limits, margins, axis scales) are
    registered with the set* methods and applied by every figure made via
    new().
    """

    def __init__(self, dataSet, imageDestFolder = None,
            imageExtension = 'png', defaultHeightAspect = 0.618):
        """
        dataSet - the DataFrame that figures filter and plot.
        imageDestFolder - folder where finished figures are saved, or None to
            skip saving. It is created if missing, and existing
            ``*.<imageExtension>`` files in it are removed so a new run does
            not collide with figures saved by a previous one.
        imageExtension - file extension used for saved figures.
        defaultHeightAspect - default height / width ratio for new figures.
        """
        self.dataSet = dataSet
        self.imageDestFolder = imageDestFolder
        self.imageExtension = imageExtension
        if imageDestFolder is not None:
            self._clearImageFolder()
        self._defaults = { 'heightAspect': defaultHeightAspect }
        self._labels = {}
        self._limits = {}
        self._margins = {}
        self._scales = {}
        self._marginDefault = 0.00
        self._unnamedCount = 0

    def _clearImageFolder(self):
        """Create imageDestFolder if needed and delete old images from it."""
        try:
            os.makedirs(self.imageDestFolder)
        except OSError as e:
            # Already existing is fine; anything else is a real error.
            if e.errno != errno.EEXIST:
                raise
        suffix = '.' + self.imageExtension
        for fn in os.listdir(self.imageDestFolder):
            path = os.path.join(self.imageDestFolder, fn)
            # Only remove files this maker could have written (and that would
            # otherwise collide); leave other files and directories alone.
            if fn.endswith(suffix) and not os.path.isdir(path):
                os.remove(path)

    def getMargin(self, dataName):
        """Return the margin for column ``dataName`` (default 0)."""
        return self._margins.get(dataName, self._marginDefault)

    def new(self, figureName = None, heightAspect = None, scale = 1.0, legendLoc = 0,
            legendAnchor = None):
        """Create a new figure; use the result as a context manager.

        figureName - title of the figure, also slugified into its file name.
            Defaults to "Unnamed N".
        heightAspect - height / width ratio; defaults to defaultHeightAspect.
        scale - multiplies the base figure width of 10 inches.
        legendLoc - The loc argument for a legend in pyplot. 1 is upper right,
            2 is upper left, etc. Defaults to 0 - best. Can also be strings
            'upper right', etc. Pass legendLoc = None to disable the legend.
        legendAnchor - Overrides legendLoc. 3-tuple of (corner, x, y). E.g.
            ('upper right', 0, 0) will position the legend's upper right
            corner at 0, 0.
        """
        baseSize = 10.0 * scale
        if figureName is None:
            self._unnamedCount += 1
            figureName = "Unnamed {}".format(self._unnamedCount)
        if heightAspect is None:
            heightAspect = self._defaults['heightAspect']
        _figure, axes = pyplot.subplots(1, 1,
                figsize = (baseSize, baseSize * heightAspect))
        if legendAnchor is None:
            legendAnchor = legendLoc
        return _Figure(self, figureName, axes, legendAnchor)

    def setLabel(self, dataName, labelName):
        """Use ``labelName`` as the axis label for column ``dataName``."""
        self._labels[dataName] = labelName

    def setLimits(self, dataName, limits):
        """Set axis limits (low, high) for ``dataName``; either may be None."""
        self._limits[dataName] = limits

    def setMargin(self, dataName, margin):
        """Set the axis margin used when ``dataName`` is plotted."""
        self._margins[dataName] = margin

    def setScale(self, dataName, scaleName):
        """Set the axis scale (e.g. "linear", "log") for ``dataName``."""
        self._scales[dataName] = scaleName
# Set all plot default parameters (applied when this module is imported).
pyplot.rcParams.update(dict([
    ('figure.dpi', 600),
    ('figure.figsize', (8, 4)),
    ('font.size', 16),
    ('legend.fontsize', None),
]))
# (GitHub page footer: "Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.")