Skip to content

Instantly share code, notes, and snippets.

@janmtl
Created March 26, 2016 09:53
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save janmtl/5f0ce81982f6faf362dd to your computer and use it in GitHub Desktop.
Save janmtl/5f0ce81982f6faf362dd to your computer and use it in GitHub Desktop.
def _as_name_list(names):
    """Normalise a column spec to a list of column labels.

    ``None`` becomes ``[]``; a list or tuple is copied into a new list; any
    other value (a string, an int, ...) is treated as a single label.
    """
    if names is None:
        return []
    if isinstance(names, (list, tuple)):
        return list(names)
    # A bare label such as "region" must become ["region"]; list("region")
    # would split it into single characters.
    return [names]


def dfgridplot(data, x, y, plotfn, aggfuncs,
               rownames=None, colnames=None, linenames=None, **kwargs):
    """Draw a grid of subplots, one per (row group, column group) of ``data``.

    For every entry in ``aggfuncs`` the frame is pivoted with
    ``data.pivot_table(index=rownames + colnames + linenames, columns=x,
    values=y, aggfunc=func)``. A ``len(rows) x len(cols)`` figure is then
    created with ``plt.subplots`` and, for each panel, the slice of every
    pivot table belonging to that panel is handed to ``plotfn``.

    Parameters
    ----------
    data : pandas.DataFrame
        Source data.
    x :
        Column label(s) passed as ``columns=`` to ``pivot_table``.
    y :
        Column label(s) passed as ``values=`` to ``pivot_table``. When it is
        a list or tuple, pandas adds a leading column level holding the value
        name(s); that level is dropped so the pivot columns are the ``x``
        values.
    plotfn : callable
        Called once per panel as ``plotfn(ax, plotdata, rowkey, colkey)``.
        ``plotdata`` maps each ``aggfuncs`` key to the panel's slice of that
        pivot table. ``rowkey`` / ``colkey`` are tuples of the panel's
        ``rownames`` / ``colnames`` values (``()`` when that axis is not
        grouped).
    aggfuncs : dict
        Maps a name to an ``aggfunc`` accepted by ``pivot_table``.
    rownames, colnames, linenames : label or list of labels, optional
        Columns whose unique value combinations define the subplot rows and
        subplot columns, plus extra index levels kept inside each panel's
        data. A single label is accepted in place of a list.
    **kwargs
        Forwarded to ``plt.subplots``.

    Returns
    -------
    (fig, axs)
        Exactly what ``plt.subplots`` returned.
    """
    import numpy as np  # only used to normalise the axes grid

    rownames = _as_name_list(rownames)
    colnames = _as_name_list(colnames)
    linenames = _as_name_list(linenames)

    # Panel keys are tuples so they can be concatenated into one .loc key;
    # an ungrouped axis contributes a single empty key.
    if rownames:
        rows = list(data[rownames].drop_duplicates()
                    .itertuples(index=False, name=None))
    else:
        rows = [()]
    if colnames:
        cols = list(data[colnames].drop_duplicates()
                    .itertuples(index=False, name=None))
    else:
        cols = [()]

    aggsindex = (rownames + colnames + linenames) or None
    aggs = {}
    for key, func in aggfuncs.items():
        table = data.pivot_table(index=aggsindex, columns=x, values=y,
                                 aggfunc=func)
        if isinstance(y, (list, tuple)):
            # List-valued ``values`` yields a (value, x) column MultiIndex.
            table.columns = table.columns.droplevel(0)
        aggs[key] = table

    fig, axs = plt.subplots(len(rows), len(cols), **kwargs)
    # plt.subplots may squeeze length-1 dimensions away (returning a single
    # Axes or a 1-D array), so force a 2-D grid before indexing by position.
    axgrid = np.asarray(axs).reshape(len(rows), len(cols))

    ngroup = len(rownames) + len(colnames)
    line_slices = (slice(None),) * len(linenames)
    for i, rowkey in enumerate(rows):
        for j, colkey in enumerate(cols):
            selector = rowkey + colkey + line_slices
            if len(selector) == 1:
                # A one-column pivot index is a plain Index, not a
                # MultiIndex, so it must be addressed with a bare label.
                selector = selector[0]
            plotdata = {}
            for key, table in aggs.items():
                panel = table.loc[selector, :]
                if ngroup and linenames:
                    # Row/column levels are constant within a panel; drop
                    # them so only the linenames levels remain.
                    panel.index = panel.index.droplevel(list(range(ngroup)))
                plotdata[key] = panel
            plotfn(axgrid[i, j], plotdata, rowkey, colkey)
    return fig, axs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment