-
-
Save josef-pkt/1511969 to your computer and use it in GitHub Desktop.
Builds a regular grid and assigns each data point to a cell
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
Author: Adrien Gaidon
https://gist.github.com/1509853
License: BSD (as stated on the mailing list)
Modified by Josef Perktold for numpy 1.5 compatibility
'''
import numpy as np | |
# Compatibility helpers: np.ravel_multi_index only exists in numpy >= 1.6.
# Both branches expose the same two names with the same conventions:
#   ravel_multi_index(arr, shape): arr is (n_dims, n_points) -> (n_points,)
#   unravel_index(arr, shape):     arr is (n_points,) -> (n_points, n_dims)
if not hasattr(np, 'ravel_multi_index'):
    def ravel_multi_index(arr, shape):
        """Return C-order flat indices for the multi-indices in ``arr``.

        ``arr`` holds one row per dimension, one column per point.
        """
        # C order only: build the table of flat indices once and look the
        # requested cells up with fancy indexing.
        base_c = np.arange(np.prod(shape)).reshape(*shape)
        return base_c[tuple(np.asarray(arr).tolist())]

    def unravel_index(arr, shape):
        """Return an (n_points, n_dims) array of cell coordinates."""
        # Old np.unravel_index is only used on one scalar index at a time.
        return np.array([np.unravel_index(ai, shape) for ai in arr])
else:
    ravel_multi_index = np.ravel_multi_index

    def unravel_index(arr, shape):
        """Return an (n_points, n_dims) array of cell coordinates.

        np.unravel_index returns a tuple with one array per dimension;
        stack and transpose it so the layout matches the fallback above.
        """
        return np.array(np.unravel_index(arr, shape)).T
def _per_dim_array(value, n_dims, dtype, label):
    """Return ``value`` as an ``(n_dims,)`` array of ``dtype``.

    A scalar is repeated once per dimension. Anything that does not end up
    with shape ``(n_dims,)`` raises ``ValueError`` whose message starts with
    ``label``.
    """
    arr = np.asarray(value, dtype=dtype)
    if arr.ndim == 0:
        arr = np.repeat(arr, n_dims)
    if arr.shape != (n_dims,):
        raise ValueError("{0}: {1}".format(label, value))
    return arr


def regular_multidim_digitize(data, n_bins=3, lbes=None, rbes=None):
    """Build a regular grid and assign each data point to a cell.

    Parameters
    ----------
    data: (n_points, n_dims) array-like,
        the data that we wish to digitize
    n_bins: int or (n_dims, ) array-like, optional, default: 3,
        per-dimension number of bins (each must be >= 1)
    lbes: float or (n_dims, ) array-like, optional, default: None,
        per-dimension left-most bin edges
        by default, use the min along each dimension
    rbes: float or (n_dims, ) array-like, optional, default: None,
        per-dimension right-most bin edges
        by default, use the max along each dimension

    Returns
    -------
    assignments: (n_points, ) array of integers,
        the flat (C-order) cell number of each point
    dis: (n_points, n_dims) array of integers,
        the per-dimension cell index of each point
    n_bins: (n_dims, ) array of integers,
        the number of bins actually used along each dimension

    Raises
    ------
    ValueError
        if ``data`` is not 2-D, if ``n_bins``, ``lbes`` or ``rbes`` cannot be
        broadcast to shape ``(n_dims,)``, if a bin count is < 1, or if the
        given edges do not enclose the data.

    Notes
    -----
    "Regular" here means evenly-spaced along each dimension (possibly
    separately). Each cell has its unique integer index which is the cell
    number in C-order (last dimension varies fastest).

    The bin widths are widened by a relative factor of 1e-12 so that the
    maximum falls into the last bin; as a consequence a point lying exactly
    on an interior edge is typically assigned to the cell below that edge.

    To obtain the cell coordinates from the cell numbers do:
        np.column_stack(np.unravel_index(assignments, n_bins))
    """
    data = np.asarray(data)
    if data.ndim != 2:
        raise ValueError(
            "data must be 2-D (n_points, n_dims), got shape {0}".format(
                data.shape))
    n_points, n_dims = data.shape

    n_bins = _per_dim_array(n_bins, n_dims, int, "invalid n_bins")
    if np.any(n_bins < 1):
        raise ValueError("invalid n_bins: {0}".format(n_bins))

    if n_points == 0:
        # Nothing to assign; min/max of an empty array would raise.
        return (np.zeros(0, dtype=int),
                np.zeros((0, n_dims), dtype=int),
                n_bins)

    data_min = np.min(data, axis=0)
    data_max = np.max(data, axis=0)

    if lbes is None:
        lbes = data_min.astype(float)
    else:
        lbes = _per_dim_array(lbes, n_dims, float, "Invalid lbes")
        # every point must lie at or above the left-most edge
        if np.any(lbes > data_min):
            raise ValueError("lbes not low enough")

    if rbes is None:
        rbes = data_max.astype(float)
    else:
        rbes = _per_dim_array(rbes, n_dims, float, "Invalid rbes")
        # every point must lie at or below the right-most edge
        if np.any(rbes < data_max):
            raise ValueError("rbes not high enough")

    # Bin widths per dimension. The widening is relative (not an absolute
    # 1e-12) so it also works for data whose scale is far from 1. A
    # dimension with zero extent gets a dummy width of 1, which maps all of
    # its points to bin 0 instead of dividing by zero.
    spans = rbes - lbes
    bws = np.where(spans > 0, (spans / n_bins) * (1.0 + 1e-12), 1.0)

    # per-dimension bin index of each point
    dis = ((data - lbes[np.newaxis, :]) / bws[np.newaxis, :]).astype(int)
    # safety net against floating-point round-off at the right-most edge
    dis = np.minimum(dis, n_bins - 1)

    # index into the flattened grid (module-level helper, numpy < 1.6 safe)
    assignments = ravel_multi_index(dis.T, n_bins)
    return assignments, dis, n_bins
# Demo / self-check: digitize 10 random 3-D points on a 5x5x5 grid.
x = np.random.randn(10, 3)
ass, dis, n_bins = regular_multidim_digitize(x, n_bins=5)
# Round-trip check: unravelling the flat cell numbers should give back the
# per-dimension cell indices. print(...) with a single argument and np.all
# behave the same on Python 2 and 3 and on old and new numpy.
print(np.all(dis == unravel_index(ass, n_bins)))
Thanks, I don't see any disadvantage to your suggestion.
I've used the same trick as daien in the past but usually with integers in mind, even if it wasn't restricted to them.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This looks quite nice and useful. Just one little comment on the
"small shift":
There is nothing to say that rbes and lbes are of order(1). Someone
might call this routine for a dataset spanning 1e-22 to 5e-22 and then
this won't work so well. I think that multiplying by (1 + 1e-12), or
something along those lines, might be better.