Created
March 30, 2013 21:50
-
-
Save anonymous/5278510 to your computer and use it in GitHub Desktop.
Matplotlib hexbin function modified to provide ability to scale hexes. Each hex is scaled by a factor returned by hexscale (callable) parameter which is provided with number of entries in this hex. If C is provided a separate function can be used to reduce C for scaling (eg. you might want color to represent average counts in bin and scale to re…
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pylab import * | |
import matplotlib.collections as mcoll | |
import matplotlib.colors as mcolors | |
def hexbin2(x, y, C = None, gridsize = 100, bins = None,
            xscale = 'linear', yscale = 'linear', extent = None,
            cmap=None, norm=None, vmin=None, vmax=None,
            alpha=None, linewidths=None, edgecolors='none',
            reduce_C_function = np.mean, mincnt=None, marginals=False,
            hexscale=None, hexscale_reduce_C_function = None,
            **kwargs):
    """Hexagonal binning plot (a modified matplotlib ``hexbin``) with
    optional per-hexagon size scaling.

    Bins the points *x*, *y* into hexagonal cells and draws one polygon per
    non-empty cell as a :class:`matplotlib.collections.PolyCollection`.
    Colors come from the per-cell count (or from *C* reduced with
    *reduce_C_function* when *C* is given).  In addition to the standard
    ``hexbin`` behavior, each hexagon may be shrunk/grown around its center
    by a factor ``hexscale(v)``, where ``v`` is the cell's count (when *C*
    is None) or ``hexscale_reduce_C_function`` applied to the cell's *C*
    values (defaults to *reduce_C_function*).

    Parameters (in addition to the usual ``hexbin`` ones):
      hexscale : callable or None
          Called once per drawn hexagon with the cell's scaling value;
          must return a scalar scale factor.  None disables scaling.
      hexscale_reduce_C_function : callable or None
          Reduction used to compute the scaling value from the cell's *C*
          entries; falls back to *reduce_C_function* when None.  Ignored
          when *C* is None (the raw count is used instead).

    Returns the PolyCollection added to the current axes.  When
    ``marginals`` is True, marginal bar collections are attached to it as
    ``collection.hbar`` / ``collection.vbar``.

    NOTE(review): this is Python 2 code — ``xrange`` and list-returning
    ``zip`` are used; it also relies on private matplotlib APIs
    (``ax._hold``, ``ax._process_unit_info``, ``ax._sci``) from the
    matplotlib 1.x era.  ``cbook``, ``iterable``, ``math`` and
    ``mtransforms`` are presumably pulled in by ``from pylab import *`` —
    verify against the matplotlib version in use.
    """
    # Draw on the current axes; mimic matplotlib's hold semantics.
    ax = gca()
    if not ax._hold: ax.cla()
    ax._process_unit_info(xdata=x, ydata=y, kwargs=kwargs)
    # Drop masked points consistently across x, y and (optionally) C.
    x, y, C = cbook.delete_masked_points(x, y, C)
    # Set the size of the hexagon grid: nx columns; ny derived so the
    # hexagons are regular unless both were given explicitly.
    if iterable(gridsize):
        nx, ny = gridsize
    else:
        nx = gridsize
        ny = int(nx/math.sqrt(3))
    # Count the number of data in each hexagon
    x = np.array(x, float)
    y = np.array(y, float)
    # Log-scaled axes: bin in log10 space, transform back before drawing.
    if xscale=='log':
        if np.any(x <= 0.0):
            raise ValueError("x contains non-positive values, so can not"
                             " be log-scaled")
        x = np.log10(x)
    if yscale=='log':
        if np.any(y <= 0.0):
            raise ValueError("y contains non-positive values, so can not"
                             " be log-scaled")
        y = np.log10(y)
    if extent is not None:
        xmin, xmax, ymin, ymax = extent
    else:
        xmin = np.amin(x)
        xmax = np.amax(x)
        ymin = np.amin(y)
        ymax = np.amax(y)
    # In the x-direction, the hexagons exactly cover the region from
    # xmin to xmax. Need some padding to avoid roundoff errors.
    padding = 1.e-9 * (xmax - xmin)
    xmin -= padding
    xmax += padding
    sx = (xmax-xmin) / nx
    sy = (ymax-ymin) / ny
    if marginals:
        # Keep untransformed copies for the marginal strips drawn later.
        xorig = x.copy()
        yorig = y.copy()
    # Map data to grid units.  Two interleaved rectangular lattices
    # together form the hexagonal grid: lattice 1 at integer offsets,
    # lattice 2 offset by half a cell in both directions.
    x = (x-xmin)/sx
    y = (y-ymin)/sy
    ix1 = np.round(x).astype(int)
    iy1 = np.round(y).astype(int)
    ix2 = np.floor(x).astype(int)
    iy2 = np.floor(y).astype(int)
    nx1 = nx + 1
    ny1 = ny + 1
    nx2 = nx
    ny2 = ny
    n = nx1*ny1+nx2*ny2
    # For each point, pick the nearer of the two candidate hex centers
    # (squared distances with the hexagonal 1:3 aspect weighting).
    d1 = (x-ix1)**2 + 3.0 * (y-iy1)**2
    d2 = (x-ix2-0.5)**2 + 3.0 * (y-iy2-0.5)**2
    bdist = (d1<d2)
    # total counts
    accum_counts = np.zeros(n)
    lattice1 = accum_counts[:nx1*ny1]
    lattice2 = accum_counts[nx1*ny1:]
    lattice1.shape = (nx1,ny1)
    lattice2.shape = (nx2,ny2)
    if C is None:
        # No C: color (and scale) by raw counts per cell.
        accum = np.zeros(n)
        # Create appropriate views into "accum" array.
        lattice1 = accum[:nx1*ny1]
        lattice2 = accum[nx1*ny1:]
        lattice1.shape = (nx1,ny1)
        lattice2.shape = (nx2,ny2)
        for i in xrange(len(x)):
            if bdist[i]:
                if ((ix1[i] >= 0) and (ix1[i] < nx1) and
                    (iy1[i] >= 0) and (iy1[i] < ny1)):
                    lattice1[ix1[i], iy1[i]]+=1
            else:
                if ((ix2[i] >= 0) and (ix2[i] < nx2) and
                    (iy2[i] >= 0) and (iy2[i] < ny2)):
                    lattice2[ix2[i], iy2[i]]+=1
        # threshold: cells below mincnt become NaN and are dropped below.
        if mincnt is not None:
            for i in xrange(nx1):
                for j in xrange(ny1):
                    if lattice1[i,j]<mincnt:
                        lattice1[i,j] = np.nan
            for i in xrange(nx2):
                for j in xrange(ny2):
                    if lattice2[i,j]<mincnt:
                        lattice2[i,j] = np.nan
        accum = np.hstack((
            lattice1.astype(float).ravel(), lattice2.astype(float).ravel()))
        good_idxs = ~np.isnan(accum)
        # With no C, the scaling value is just the count (alias of accum).
        accum_hexscale = accum
    else:
        # C given: gather C values per cell, then reduce.
        # NOTE: mincnt semantics differ slightly from the C-is-None branch
        # here (strict '>' below vs '<' above), matching matplotlib.
        if mincnt is None:
            mincnt = 0
        # create accumulation arrays
        lattice1 = np.empty((nx1,ny1),dtype=object)
        for i in xrange(nx1):
            for j in xrange(ny1):
                lattice1[i,j] = []
        lattice2 = np.empty((nx2,ny2),dtype=object)
        for i in xrange(nx2):
            for j in xrange(ny2):
                lattice2[i,j] = []
        for i in xrange(len(x)):
            if bdist[i]:
                if ((ix1[i] >= 0) and (ix1[i] < nx1) and
                    (iy1[i] >= 0) and (iy1[i] < ny1)):
                    lattice1[ix1[i], iy1[i]].append( C[i] )
            else:
                if ((ix2[i] >= 0) and (ix2[i] < nx2) and
                    (iy2[i] >= 0) and (iy2[i] < ny2)):
                    lattice2[ix2[i], iy2[i]].append( C[i] )
        # Reduce each cell's list of C values to a scalar (color value).
        for i in xrange(nx1):
            for j in xrange(ny1):
                vals = lattice1[i,j]
                if len(vals)>mincnt:
                    lattice1[i,j] = reduce_C_function( vals )
                else:
                    lattice1[i,j] = np.nan
        for i in xrange(nx2):
            for j in xrange(ny2):
                vals = lattice2[i,j]
                if len(vals)>mincnt:
                    lattice2[i,j] = reduce_C_function( vals )
                else:
                    lattice2[i,j] = np.nan
        accum = np.hstack((
            lattice1.astype(float).ravel(), lattice2.astype(float).ravel()))
        good_idxs = ~np.isnan(accum)
        if hexscale is not None:
            # Second pass over the same binning to build the scaling
            # values with a (possibly different) reduction function.
            # NOTE(review): accum_hexscale stays undefined here when
            # hexscale is None — later uses are guarded on hexscale.
            if hexscale_reduce_C_function is None:
                hexscale_reduce_C_function = reduce_C_function
            # create accumulation arrays
            lattice1 = np.empty((nx1,ny1),dtype=object)
            for i in xrange(nx1):
                for j in xrange(ny1):
                    lattice1[i,j] = []
            lattice2 = np.empty((nx2,ny2),dtype=object)
            for i in xrange(nx2):
                for j in xrange(ny2):
                    lattice2[i,j] = []
            for i in xrange(len(x)):
                if bdist[i]:
                    if ((ix1[i] >= 0) and (ix1[i] < nx1) and
                        (iy1[i] >= 0) and (iy1[i] < ny1)):
                        lattice1[ix1[i], iy1[i]].append( C[i] )
                else:
                    if ((ix2[i] >= 0) and (ix2[i] < nx2) and
                        (iy2[i] >= 0) and (iy2[i] < ny2)):
                        lattice2[ix2[i], iy2[i]].append( C[i] )
            for i in xrange(nx1):
                for j in xrange(ny1):
                    vals = lattice1[i,j]
                    if len(vals)>mincnt:
                        lattice1[i,j] = hexscale_reduce_C_function( vals )
                    else:
                        lattice1[i,j] = np.nan
            for i in xrange(nx2):
                for j in xrange(ny2):
                    vals = lattice2[i,j]
                    if len(vals)>mincnt:
                        lattice2[i,j] = hexscale_reduce_C_function( vals )
                    else:
                        lattice2[i,j] = np.nan
            accum_hexscale = np.hstack((
                lattice1.astype(float).ravel(), lattice2.astype(float).ravel()))
    # Unit hexagon vertex offsets (6 vertices), then one hexagon per cell
    # on each of the two interleaved lattices.
    px = xmin + sx * np.array([ 0.5, 0.5, 0.0, -0.5, -0.5, 0.0])
    py = ymin + sy * np.array([-0.5, 0.5, 1.0, 0.5, -0.5, -1.0]) / 3.0
    polygons = np.zeros((6, n, 2), float)
    polygons[:,:nx1*ny1,0] = np.repeat(np.arange(nx1), ny1)
    polygons[:,:nx1*ny1,1] = np.tile(np.arange(ny1), nx1)
    polygons[:,nx1*ny1:,0] = np.repeat(np.arange(nx2) + 0.5, ny2)
    polygons[:,nx1*ny1:,1] = np.tile(np.arange(ny2), nx2) + 0.5
    # remove accumulation bins with no data
    polygons = polygons[:,good_idxs,:]
    accum = accum[good_idxs]
    if hexscale is not None:
        accum_hexscale = accum_hexscale[good_idxs]
    # Reorder to (n_hex, 6, 2) and map grid units back to data coords.
    polygons = np.transpose(polygons, axes=[1,0,2])
    polygons[:,:,0] *= sx
    polygons[:,:,1] *= sy
    polygons[:,:,0] += px
    polygons[:,:,1] += py
    if xscale=='log':
        polygons[:,:,0] = 10**(polygons[:,:,0])
        xmin = 10**xmin
        xmax = 10**xmax
        ax.set_xscale('log')
    if yscale=='log':
        polygons[:,:,1] = 10**(polygons[:,:,1])
        ymin = 10**ymin
        ymax = 10**ymax
        ax.set_yscale('log')
    if edgecolors=='none':
        edgecolors = 'face'
    if hexscale is not None:
        # scale all polygons: shrink/grow each hexagon about its centroid
        # by the factor returned from hexscale for this cell's value.
        # NOTE(review): zip(xs, ys) relies on Python 2's list-returning
        # zip; under Python 3 PolyCollection would receive iterators.
        new_polygons = []
        for vs, cnts in zip(polygons, accum_hexscale):
            xs, ys = vs.T
            mx = mean(xs)
            my = mean(ys)
            sc = hexscale(cnts)
            xs = (xs-mx)*sc+mx
            ys = (ys-my)*sc+my
            new_polygons.append(zip(xs,ys))
        polygons = new_polygons
    collection = mcoll.PolyCollection(
        polygons,
        sizes = None,
        edgecolors = edgecolors,
        linewidths = linewidths,
        # transOffset = ax.transData
        )
    if isinstance(norm, mcolors.LogNorm):
        if (accum==0).any():
            # make sure we have not zeros — LogNorm cannot handle 0.
            # NOTE(review): this mutates accum in place, which also
            # shifts accum_hexscale in the C-is-None branch (alias), but
            # scaling has already been applied above.
            accum += 1
    # autoscale the norm with curren accum values if it hasn't
    # been set
    if norm is not None:
        if norm.vmin is None and norm.vmax is None:
            norm.autoscale(accum)
    # Transform accum if needed
    if bins=='log':
        accum = np.log10(accum+1)
    elif bins!=None:
        if not iterable(bins):
            minimum, maximum = min(accum), max(accum)
            bins-=1 # one less edge than bins
            bins = minimum + (maximum-minimum)*np.arange(bins)/bins
        bins = np.sort(bins)
        accum = bins.searchsorted(accum)
    if norm is not None: assert(isinstance(norm, mcolors.Normalize))
    collection.set_array(accum)
    collection.set_cmap(cmap)
    collection.set_norm(norm)
    collection.set_alpha(alpha)
    collection.update(kwargs)
    if vmin is not None or vmax is not None:
        collection.set_clim(vmin, vmax)
    else:
        collection.autoscale_None()
    corners = ((xmin, ymin), (xmax, ymax))
    ax.update_datalim( corners)
    ax.autoscale_view(tight=True)
    # add the collection last
    ax.add_collection(collection)
    if not marginals:
        ax._sci(collection)
        return collection
    # --- marginals: thin bars along the x and y axes showing the
    # coarse-binned reduction of C (counts when C was None). ---
    if C is None:
        C = np.ones(len(x))
    def coarse_bin(x, y, coarse):
        # Reduce y over the coarse bin each x falls in.
        ind = coarse.searchsorted(x).clip(0, len(coarse)-1)
        mus = np.zeros(len(coarse))
        for i in range(len(coarse)):
            mu = reduce_C_function(y[ind==i])
            mus[i] = mu
        return mus
    coarse = np.linspace(xmin, xmax, gridsize)
    xcoarse = coarse_bin(xorig, C, coarse)
    valid = ~np.isnan(xcoarse)
    verts, values = [], []
    for i,val in enumerate(xcoarse):
        thismin = coarse[i]
        if i<len(coarse)-1:
            thismax = coarse[i+1]
        else:
            thismax = thismin + np.diff(coarse)[-1]
        if not valid[i]: continue
        # Bar of height 0.05 in axes coordinates along the bottom edge.
        verts.append([(thismin, 0), (thismin, 0.05), (thismax, 0.05), (thismax, 0)])
        values.append(val)
    values = np.array(values)
    # Blend: x in data coords, y in axes coords.
    trans = mtransforms.blended_transform_factory(
        ax.transData, ax.transAxes)
    hbar = mcoll.PolyCollection(verts, transform=trans, edgecolors='face')
    hbar.set_array(values)
    hbar.set_cmap(cmap)
    hbar.set_norm(norm)
    hbar.set_alpha(alpha)
    hbar.update(kwargs)
    ax.add_collection(hbar)
    coarse = np.linspace(ymin, ymax, gridsize)
    ycoarse = coarse_bin(yorig, C, coarse)
    valid = ~np.isnan(ycoarse)
    verts, values = [], []
    for i,val in enumerate(ycoarse):
        thismin = coarse[i]
        if i<len(coarse)-1:
            thismax = coarse[i+1]
        else:
            thismax = thismin + np.diff(coarse)[-1]
        if not valid[i]: continue
        # Bar of width 0.05 in axes coordinates along the left edge.
        verts.append([(0, thismin), (0.0, thismax), (0.05, thismax), (0.05, thismin)])
        values.append(val)
    values = np.array(values)
    trans = mtransforms.blended_transform_factory(
        ax.transAxes, ax.transData)
    vbar = mcoll.PolyCollection(verts, transform=trans, edgecolors='face')
    vbar.set_array(values)
    vbar.set_cmap(cmap)
    vbar.set_norm(norm)
    vbar.set_alpha(alpha)
    vbar.update(kwargs)
    ax.add_collection(vbar)
    # Keep marginal bars in sync with the main collection's colormap/clim.
    collection.hbar = hbar
    collection.vbar = vbar
    def on_changed(collection):
        hbar.set_cmap(collection.get_cmap())
        hbar.set_clim(collection.get_clim())
        vbar.set_cmap(collection.get_cmap())
        vbar.set_clim(collection.get_clim())
    collection.callbacksSM.connect('changed', on_changed)
    ax._sci(collection)
    return collection
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment