Skip to content

Instantly share code, notes, and snippets.

@adrn
Created April 3, 2013 13:25
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save adrn/5301190 to your computer and use it in GitHub Desktop.
Save adrn/5301190 to your computer and use it in GitHub Desktop.
Make a scatter-plot matrix with matplotlib
# coding: utf-8
""" Create a scatter-plot matrix using Matplotlib. """
from __future__ import division, print_function
__author__ = "adrn <adrn@astro.columbia.edu>"
# Standard library
import os, sys
# Third-party
import numpy as np
import matplotlib.pyplot as plt
import astropy.units as u
def scatter_plot_matrix(data, labels=None, axes=None, subplots_kwargs=None,
                        scatter_kwargs=None):
    """ Create a scatter plot matrix from the given data.

    Panel ``(i, j)`` plots dimension ``j`` on the x-axis against dimension
    ``i`` on the y-axis.

    Parameters
    ----------
    data : array_like
        The scatter data to plot, with shape (M, N): M dimensions, each with
        N data points. Anything accepted by ``numpy.asanyarray`` works
        (ndarray, masked array, nested lists). M must not exceed N; this
        guards against accidentally passing transposed data.
    labels : sequence (optional)
        A sequence (list, tuple or numpy array) of length M containing the
        axis labels. Labels are drawn on the first column (y) and the last
        row (x) only.
    axes : numpy.ndarray of matplotlib Axes (optional)
        An (M, M) array of existing axes to plot into; it must support
        ``axes[i, j]`` indexing. If omitted, a new figure is created.
    subplots_kwargs : dict (optional)
        Keyword arguments passed to ``matplotlib.pyplot.subplots``. Only
        relevant if ``axes`` is None. ``sharex`` and ``sharey`` default to
        True. ``squeeze`` is always set to False so the returned axes are
        2-D even when M == 1. The caller's dict is not modified.
    scatter_kwargs : dict (optional)
        Keyword arguments passed to each ``Axes.scatter`` call. Defaults are
        ``edgecolor="none"``, ``c="k"`` and ``s=10``. The caller's dict is
        not modified.

    Returns
    -------
    fig : matplotlib Figure
        The figure that owns ``axes[0, 0]``.
    axes : numpy.ndarray
        The (M, M) array of axes that were drawn into.

    Raises
    ------
    ValueError
        If ``data`` is not 2-D, or has more dimensions (rows) than points.
    """
    # asanyarray (not asarray) keeps subclasses such as masked arrays intact
    # for scatter, while letting plain lists through.
    data = np.asanyarray(data)
    if data.ndim != 2 or data.shape[0] > data.shape[1]:
        raise ValueError("Invalid data shape {0}. You must pass in an array of "
                         "shape (M, N) where N >= M.".format(data.shape))
    M = data.shape[0]

    # Use `is None`: `== None` on a numpy array is evaluated element-wise and
    # the resulting array cannot be used as an `if` condition.
    if labels is None:
        labels = [None] * M

    if axes is None:
        skwargs = dict(subplots_kwargs or {})
        skwargs.setdefault("sharex", True)
        skwargs.setdefault("sharey", True)
        # With the default squeeze=True, subplots(1, 1) returns a bare Axes,
        # which would break the axes[ii, jj] indexing below.
        skwargs["squeeze"] = False
        axes = plt.subplots(M, M, **skwargs)[1]

    sc_kwargs = dict(scatter_kwargs or {})
    sc_kwargs.setdefault("edgecolor", "none")
    sc_kwargs.setdefault("c", "k")
    sc_kwargs.setdefault("s", 10)

    xticks = yticks = None
    for ii in range(M):
        for jj in range(M):
            ax = axes[ii, jj]
            ax.scatter(data[jj], data[ii], **sc_kwargs)

            # Tick locations are captured once, from the first panel drawn,
            # with the outermost ticks dropped. Because panels are packed with
            # zero spacing (see subplots_adjust below), this hack keeps the
            # tick labels of neighbouring panels from overlapping.
            if yticks is None:
                yticks = ax.get_yticks()[1:-1]
            if xticks is None:
                xticks = ax.get_xticks()[1:-1]

            # first column: y label and ticks
            if jj == 0:
                ax.set_ylabel(labels[ii])
                ax.yaxis.set_ticks(yticks)

            # last row: x label and ticks
            if ii == M - 1:
                ax.set_xlabel(labels[jj])
                ax.xaxis.set_ticks(xticks)

    fig = axes[0, 0].figure
    fig.subplots_adjust(hspace=0.0, wspace=0.0, left=0.08, bottom=0.08,
                        top=0.9, right=0.9)
    return fig, axes
@keflavich
Copy link

Add some examples! Also, are you using astropy.units anywhere?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment