Skip to content

Instantly share code, notes, and snippets.

@dfm
Created March 30, 2012 14:45
Show Gist options
  • Save dfm/2252032 to your computer and use it in GitHub Desktop.
Save dfm/2252032 to your computer and use it in GitHub Desktop.
Triangle plots
__all__ = ["triplot"]
import itertools
import matplotlib.pyplot as pl
from matplotlib.ticker import MaxNLocator
import numpy as np
def triplot(samples, titles=None, sfactor=0.1, figsize=(30, 30)):
    """
    Draw a lower-triangle grid of plots for a set of samples: a histogram of
    each variable on the diagonal and a scatter plot of every pair of
    variables below it.

    ## Arguments

    * `samples` (numpy.ndarray): The samples (nsamples, nvars).
    * `titles` (list): A list of length nvars. Defaults to the column
      indices rendered as strings.
    * `sfactor` (float): Fraction of each variable's data span that is added
      as padding on both sides of its axis limits.
    * `figsize` (tuple): Figure size handed to `pl.figure`.

    ## Returns

    * `fig`: The figure created by `pl.figure` that holds all the subplots.

    ## Raises

    * `ValueError`: If `samples` is not 2-dimensional or `titles` does not
      have exactly one entry per variable.

    """
    # Validate before touching matplotlib so bad input fails fast and the
    # checks survive `python -O` (unlike `assert`).
    samples = np.asarray(samples)
    if samples.ndim != 2:
        raise ValueError(
            "samples must have shape (nsamples, nvars); got %d dimension(s)"
            % samples.ndim
        )
    nvars = samples.shape[-1]
    if titles is None:
        titles = [str(d) for d in range(nvars)]
    elif len(titles) != nvars:
        raise ValueError(
            "expected %d titles (one per variable), got %d"
            % (nvars, len(titles))
        )

    # Pad every axis outward by a fraction of the data span.  Scaling the
    # extremes themselves (min * (1 - sfactor)) moves the limit *inward* for
    # negative values and clips the data.
    lo = samples.min(axis=0)
    hi = samples.max(axis=0)
    span = hi - lo
    # A constant column has zero span; fall back to its magnitude (or 1.0)
    # so the axis limits do not collapse onto a single value.
    pad = sfactor * np.where(span > 0, span, np.maximum(np.abs(lo), 1.0))
    ranges = list(zip(lo - pad, hi + pad))

    fig = pl.figure(figsize=figsize)
    for yi in range(nvars):
        # Only the lower triangle (xi <= yi) is drawn.
        for xi in range(yi + 1):
            # Subplot numbers are 1-based and run row by row.
            ax = fig.add_subplot(nvars, nvars, yi * nvars + xi + 1)
            x = samples[:, xi]

            if xi == yi:
                # Diagonal: 1-D histogram; the count axis is left unlabelled.
                ax.hist(x, 50, histtype="step", color="k")
                ax.set_yticklabels([])
            else:
                ax.plot(x, samples[:, yi], ".k")
                ax.set_ylim(ranges[yi])
            ax.set_xlim(ranges[xi])

            ax.xaxis.set_major_locator(MaxNLocator(5))
            ax.yaxis.set_major_locator(MaxNLocator(5))

            # y tick labels / axis label only on the left-most column
            # (the top-left histogram gets no y label).
            if xi > 0:
                ax.set_yticklabels([])
            elif yi > 0:
                ax.set_ylabel(titles[yi])

            # x tick labels / axis label only on the bottom row.
            if yi < nvars - 1:
                ax.set_xticklabels([])
            else:
                ax.set_xlabel(titles[xi])
                for t in ax.get_xticklabels():
                    t.set_rotation(60)

    return fig
if __name__ == "__main__":
    # Demo: draw standard-normal samples, plot them with triplot and save
    # the resulting figure to disk.
    nsamples = 1000
    nvars = 10
    draws = np.random.randn(nsamples * nvars)
    samples = draws.reshape((nsamples, nvars))
    titles = ["Variable %d" % idx for idx in range(nvars)]
    triplot(samples, titles)
    pl.savefig("triplot.png")
@mfouesneau
Copy link

Hey Dan, maybe this will inspire you:

def plotCorr(l, pars, plotfunc=None, lbls=None, *args, **kwargs):
    """ Plot correlation matrix between variables

    Draws one panel for every pair (pars[j], pars[i]) with i > j on a
    (len(pars)-1) x (len(pars)-1) grid of subplots.

    inputs
        l        -- dictionary of variables (could be a Table)
        pars     -- parameters to use
        plotfunc -- optional callable, invoked for each panel as
                    plotfunc(x, y, *args, **kwargs); when None the points
                    are drawn with plot(x, y, ',', **kwargs)
        lbls     -- axis labels, one per entry of pars (defaults to pars)

        *args, **kwargs are forwarded to the plot function
        (the default plot call only receives **kwargs)
    """
    # NOTE(review): subplot, plot, gca, setp, xlabel, ylabel and numpy look
    # like pylab-style names, while theme and setMargins are not defined in
    # this snippet -- confirm where they come from.
    npars = len(pars)

    # Font size shrinks as the grid grows; grids larger than 5x5 use 3.
    fontmap = {1: 10, 2: 8, 3: 6, 4: 5, 5: 4}
    fontsize = 2 * fontmap.get(npars - 1, 3)

    if lbls is None:
        lbls = pars

    k = 1  # running 1-based subplot index
    axes = numpy.empty((npars, npars), dtype=object)
    for j in range(npars):
        newrow = True
        for i in range(npars):
            if i > j:
                # Share x with the panel above and y with the panel to the
                # left when one has been stored in `axes`.
                sharex = axes[j - 1, i] if j > 0 else None
                # i > j >= 0 here, so i - 1 is always a valid column.
                sharey = axes[j, i - 1]
                ax = subplot(npars - 1, npars - 1, k,
                             sharey=sharey, sharex=sharex)
                axes[j, i] = ax
                if plotfunc is None:
                    plot(l[pars[i]], l[pars[j]], ',', **kwargs)
                else:
                    plotfunc(l[pars[i]], l[pars[j]], *args, **kwargs)

                theme(ax=ax)
                tlabels = gca().get_xticklabels()
                setp(tlabels, 'fontsize', fontsize)
                tlabels = gca().get_yticklabels()
                setp(tlabels, 'fontsize', fontsize)
                if not newrow:
                    setp(ax.get_xticklabels() + ax.get_yticklabels(),
                         visible=False)
                else:
                    # Only the first panel drawn in each row is labelled.
                    xlabel(lbls[i], fontsize=fontsize)
                    ylabel(lbls[j], fontsize=fontsize)
                    newrow = False
            # The grid has no diagonal column, so only off-diagonal cells
            # advance the subplot index.
            if i != j:
                k += 1
    setMargins(hspace=0.0, wspace=0.0)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment