Matplotlib hexbin function modified to provide the ability to scale hexes. Each hex is scaled by the factor returned by the hexscale (callable) parameter, which is called with the number of entries in that hex. If C is provided, a separate function can be used to reduce C for scaling (e.g. you might want color to represent the average count in a bin and scale to re…
from pylab import *
import math

import numpy as np
import matplotlib.cbook as cbook
import matplotlib.collections as mcoll
import matplotlib.colors as mcolors
import matplotlib.transforms as mtransforms


def hexbin2(x, y, C = None, gridsize = 100, bins = None,
            xscale = 'linear', yscale = 'linear', extent = None,
            cmap=None, norm=None, vmin=None, vmax=None,
            alpha=None, linewidths=None, edgecolors='none',
            reduce_C_function = np.mean, mincnt=None, marginals=False,
            hexscale=None, hexscale_reduce_C_function = None,
            **kwargs):
    """Matplotlib's hexbin, modified so that each hexagon can be scaled.

    hexscale (callable) receives the reduced value of each hexagon (the bin
    count when C is None) and returns a size factor for that hexagon.
    hexscale_reduce_C_function optionally reduces C for scaling separately
    from reduce_C_function; it defaults to reduce_C_function.
    """
    ax = gca()
    if not ax._hold: ax.cla()
    ax._process_unit_info(xdata=x, ydata=y, kwargs=kwargs)
    x, y, C = cbook.delete_masked_points(x, y, C)

    # Set the size of the hexagon grid
    if iterable(gridsize):
        nx, ny = gridsize
    else:
        nx = gridsize
        ny = int(nx/math.sqrt(3))

    # Count the number of data in each hexagon
    x = np.array(x, float)
    y = np.array(y, float)
    if xscale=='log':
        if np.any(x <= 0.0):
            raise ValueError("x contains non-positive values, so can not"
                             " be log-scaled")
        x = np.log10(x)
    if yscale=='log':
        if np.any(y <= 0.0):
            raise ValueError("y contains non-positive values, so can not"
                             " be log-scaled")
        y = np.log10(y)

    if extent is not None:
        xmin, xmax, ymin, ymax = extent
    else:
        xmin = np.amin(x)
        xmax = np.amax(x)
        ymin = np.amin(y)
        ymax = np.amax(y)

    # In the x-direction, the hexagons exactly cover the region from
    # xmin to xmax. Need some padding to avoid roundoff errors.
    padding = 1.e-9 * (xmax - xmin)
    xmin -= padding
    xmax += padding
    sx = (xmax-xmin) / nx
    sy = (ymax-ymin) / ny
    if marginals:
        xorig = x.copy()
        yorig = y.copy()

    x = (x-xmin)/sx
    y = (y-ymin)/sy
    ix1 = np.round(x).astype(int)
    iy1 = np.round(y).astype(int)
    ix2 = np.floor(x).astype(int)
    iy2 = np.floor(y).astype(int)

    nx1 = nx + 1
    ny1 = ny + 1
    nx2 = nx
    ny2 = ny
    n = nx1*ny1+nx2*ny2

    d1 = (x-ix1)**2 + 3.0 * (y-iy1)**2
    d2 = (x-ix2-0.5)**2 + 3.0 * (y-iy2-0.5)**2
    bdist = (d1<d2)

    # total counts
    accum_counts = np.zeros(n)
    lattice1 = accum_counts[:nx1*ny1]
    lattice2 = accum_counts[nx1*ny1:]
    lattice1.shape = (nx1,ny1)
    lattice2.shape = (nx2,ny2)
    if C is None:
        accum = np.zeros(n)
        # Create appropriate views into "accum" array.
        lattice1 = accum[:nx1*ny1]
        lattice2 = accum[nx1*ny1:]
        lattice1.shape = (nx1,ny1)
        lattice2.shape = (nx2,ny2)

        for i in xrange(len(x)):
            if bdist[i]:
                if ((ix1[i] >= 0) and (ix1[i] < nx1) and
                    (iy1[i] >= 0) and (iy1[i] < ny1)):
                    lattice1[ix1[i], iy1[i]]+=1
            else:
                if ((ix2[i] >= 0) and (ix2[i] < nx2) and
                    (iy2[i] >= 0) and (iy2[i] < ny2)):
                    lattice2[ix2[i], iy2[i]]+=1

        # threshold
        if mincnt is not None:
            for i in xrange(nx1):
                for j in xrange(ny1):
                    if lattice1[i,j]<mincnt:
                        lattice1[i,j] = np.nan
            for i in xrange(nx2):
                for j in xrange(ny2):
                    if lattice2[i,j]<mincnt:
                        lattice2[i,j] = np.nan

        accum = np.hstack((
            lattice1.astype(float).ravel(), lattice2.astype(float).ravel()))
        good_idxs = ~np.isnan(accum)
        accum_hexscale = accum
    else:
        if mincnt is None:
            mincnt = 0

        # create accumulation arrays
        lattice1 = np.empty((nx1,ny1),dtype=object)
        for i in xrange(nx1):
            for j in xrange(ny1):
                lattice1[i,j] = []
        lattice2 = np.empty((nx2,ny2),dtype=object)
        for i in xrange(nx2):
            for j in xrange(ny2):
                lattice2[i,j] = []

        for i in xrange(len(x)):
            if bdist[i]:
                if ((ix1[i] >= 0) and (ix1[i] < nx1) and
                    (iy1[i] >= 0) and (iy1[i] < ny1)):
                    lattice1[ix1[i], iy1[i]].append( C[i] )
            else:
                if ((ix2[i] >= 0) and (ix2[i] < nx2) and
                    (iy2[i] >= 0) and (iy2[i] < ny2)):
                    lattice2[ix2[i], iy2[i]].append( C[i] )

        for i in xrange(nx1):
            for j in xrange(ny1):
                vals = lattice1[i,j]
                if len(vals)>mincnt:
                    lattice1[i,j] = reduce_C_function( vals )
                else:
                    lattice1[i,j] = np.nan
        for i in xrange(nx2):
            for j in xrange(ny2):
                vals = lattice2[i,j]
                if len(vals)>mincnt:
                    lattice2[i,j] = reduce_C_function( vals )
                else:
                    lattice2[i,j] = np.nan

        accum = np.hstack((
            lattice1.astype(float).ravel(), lattice2.astype(float).ravel()))
        good_idxs = ~np.isnan(accum)
        if hexscale is not None:
            if hexscale_reduce_C_function is None:
                hexscale_reduce_C_function = reduce_C_function

            # create accumulation arrays for the scaling values
            lattice1 = np.empty((nx1,ny1),dtype=object)
            for i in xrange(nx1):
                for j in xrange(ny1):
                    lattice1[i,j] = []
            lattice2 = np.empty((nx2,ny2),dtype=object)
            for i in xrange(nx2):
                for j in xrange(ny2):
                    lattice2[i,j] = []

            for i in xrange(len(x)):
                if bdist[i]:
                    if ((ix1[i] >= 0) and (ix1[i] < nx1) and
                        (iy1[i] >= 0) and (iy1[i] < ny1)):
                        lattice1[ix1[i], iy1[i]].append( C[i] )
                else:
                    if ((ix2[i] >= 0) and (ix2[i] < nx2) and
                        (iy2[i] >= 0) and (iy2[i] < ny2)):
                        lattice2[ix2[i], iy2[i]].append( C[i] )

            for i in xrange(nx1):
                for j in xrange(ny1):
                    vals = lattice1[i,j]
                    if len(vals)>mincnt:
                        lattice1[i,j] = hexscale_reduce_C_function( vals )
                    else:
                        lattice1[i,j] = np.nan
            for i in xrange(nx2):
                for j in xrange(ny2):
                    vals = lattice2[i,j]
                    if len(vals)>mincnt:
                        lattice2[i,j] = hexscale_reduce_C_function( vals )
                    else:
                        lattice2[i,j] = np.nan

            accum_hexscale = np.hstack((
                lattice1.astype(float).ravel(), lattice2.astype(float).ravel()))
    px = xmin + sx * np.array([ 0.5, 0.5, 0.0, -0.5, -0.5, 0.0])
    py = ymin + sy * np.array([-0.5, 0.5, 1.0, 0.5, -0.5, -1.0]) / 3.0

    polygons = np.zeros((6, n, 2), float)
    polygons[:,:nx1*ny1,0] = np.repeat(np.arange(nx1), ny1)
    polygons[:,:nx1*ny1,1] = np.tile(np.arange(ny1), nx1)
    polygons[:,nx1*ny1:,0] = np.repeat(np.arange(nx2) + 0.5, ny2)
    polygons[:,nx1*ny1:,1] = np.tile(np.arange(ny2), nx2) + 0.5

    # remove accumulation bins with no data
    polygons = polygons[:,good_idxs,:]
    accum = accum[good_idxs]
    if hexscale is not None:
        accum_hexscale = accum_hexscale[good_idxs]

    polygons = np.transpose(polygons, axes=[1,0,2])
    polygons[:,:,0] *= sx
    polygons[:,:,1] *= sy
    polygons[:,:,0] += px
    polygons[:,:,1] += py

    if xscale=='log':
        polygons[:,:,0] = 10**(polygons[:,:,0])
        xmin = 10**xmin
        xmax = 10**xmax
        ax.set_xscale('log')
    if yscale=='log':
        polygons[:,:,1] = 10**(polygons[:,:,1])
        ymin = 10**ymin
        ymax = 10**ymax
        ax.set_yscale('log')

    if edgecolors=='none':
        edgecolors = 'face'
    if hexscale is not None:
        # scale all polygons
        new_polygons = []
        for vs, cnts in zip(polygons, accum_hexscale):
            xs, ys = vs.T
            mx = mean(xs)
            my = mean(ys)
            sc = hexscale(cnts)
            xs = (xs-mx)*sc+mx
            ys = (ys-my)*sc+my
            new_polygons.append(zip(xs,ys))
        polygons = new_polygons

    collection = mcoll.PolyCollection(
        polygons,
        sizes = None,
        edgecolors = edgecolors,
        linewidths = linewidths,
        # transOffset = ax.transData
        )
    if isinstance(norm, mcolors.LogNorm):
        if (accum==0).any():
            # make sure we have no zeros
            accum += 1

    # autoscale the norm with current accum values if it hasn't
    # been set
    if norm is not None:
        if norm.vmin is None and norm.vmax is None:
            norm.autoscale(accum)

    # Transform accum if needed
    if bins=='log':
        accum = np.log10(accum+1)
    elif bins!=None:
        if not iterable(bins):
            minimum, maximum = min(accum), max(accum)
            bins-=1 # one less edge than bins
            bins = minimum + (maximum-minimum)*np.arange(bins)/bins
        bins = np.sort(bins)
        accum = bins.searchsorted(accum)
    if norm is not None: assert(isinstance(norm, mcolors.Normalize))
    collection.set_array(accum)
    collection.set_cmap(cmap)
    collection.set_norm(norm)
    collection.set_alpha(alpha)
    collection.update(kwargs)

    if vmin is not None or vmax is not None:
        collection.set_clim(vmin, vmax)
    else:
        collection.autoscale_None()

    corners = ((xmin, ymin), (xmax, ymax))
    ax.update_datalim( corners)
    ax.autoscale_view(tight=True)

    # add the collection last
    ax.add_collection(collection)
    if not marginals:
        ax._sci(collection)
        return collection
    if C is None:
        C = np.ones(len(x))

    def coarse_bin(x, y, coarse):
        ind = coarse.searchsorted(x).clip(0, len(coarse)-1)
        mus = np.zeros(len(coarse))
        for i in range(len(coarse)):
            mu = reduce_C_function(y[ind==i])
            mus[i] = mu
        return mus

    coarse = np.linspace(xmin, xmax, gridsize)

    xcoarse = coarse_bin(xorig, C, coarse)
    valid = ~np.isnan(xcoarse)
    verts, values = [], []
    for i,val in enumerate(xcoarse):
        thismin = coarse[i]
        if i<len(coarse)-1:
            thismax = coarse[i+1]
        else:
            thismax = thismin + np.diff(coarse)[-1]

        if not valid[i]: continue

        verts.append([(thismin, 0), (thismin, 0.05), (thismax, 0.05), (thismax, 0)])
        values.append(val)

    values = np.array(values)
    trans = mtransforms.blended_transform_factory(
        ax.transData, ax.transAxes)

    hbar = mcoll.PolyCollection(verts, transform=trans, edgecolors='face')

    hbar.set_array(values)
    hbar.set_cmap(cmap)
    hbar.set_norm(norm)
    hbar.set_alpha(alpha)
    hbar.update(kwargs)
    ax.add_collection(hbar)
    coarse = np.linspace(ymin, ymax, gridsize)
    ycoarse = coarse_bin(yorig, C, coarse)
    valid = ~np.isnan(ycoarse)
    verts, values = [], []
    for i,val in enumerate(ycoarse):
        thismin = coarse[i]
        if i<len(coarse)-1:
            thismax = coarse[i+1]
        else:
            thismax = thismin + np.diff(coarse)[-1]
        if not valid[i]: continue
        verts.append([(0, thismin), (0.0, thismax), (0.05, thismax), (0.05, thismin)])
        values.append(val)

    values = np.array(values)
    trans = mtransforms.blended_transform_factory(
        ax.transAxes, ax.transData)

    vbar = mcoll.PolyCollection(verts, transform=trans, edgecolors='face')

    vbar.set_array(values)
    vbar.set_cmap(cmap)
    vbar.set_norm(norm)
    vbar.set_alpha(alpha)
    vbar.update(kwargs)
    ax.add_collection(vbar)

    collection.hbar = hbar
    collection.vbar = vbar

    def on_changed(collection):
        hbar.set_cmap(collection.get_cmap())
        hbar.set_clim(collection.get_clim())
        vbar.set_cmap(collection.get_cmap())
        vbar.set_clim(collection.get_clim())

    collection.callbacksSM.connect('changed', on_changed)

    ax._sci(collection)
    return collection
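

A minimal usage sketch, not part of the original gist: the random data, the gridsize, the __main__ guard, and the scale_by_count helper are illustrative assumptions. Color encodes the per-hex mean of C via reduce_C_function, while hexscale_reduce_C_function=len reduces each hexagon's C values to a count that scale_by_count turns into a size factor.


if __name__ == '__main__':
    np.random.seed(0)
    x = np.random.standard_normal(10000)
    y = np.random.standard_normal(10000)
    c = np.random.uniform(size=10000)

    # Hypothetical scaling rule: shrink sparse hexes, let dense ones approach
    # full size (factor capped at 1.0).
    def scale_by_count(cnt):
        return 0.3 + 0.7 * min(1.0, np.log10(cnt + 1.0) / 3.0)

    hexbin2(x, y, C=c, gridsize=30,
            reduce_C_function=np.mean,          # color = mean of C per hex
            hexscale=scale_by_count,            # size  = function of per-hex count
            hexscale_reduce_C_function=len)     # reduce C to a count for scaling
    colorbar()
    show()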