-
-
Save josef-pkt/1511969 to your computer and use it in GitHub Desktop.
Builds a regular grid and assigns each data point to a cell
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
Author: Adrien Gaidon
https://gist.github.com/1509853
License: BSD (stated on the mailing list)
Modified by Josef Perktold for numpy 1.5 compatibility
'''
import numpy as np

# Compatibility shim for numpy < 1.6, which lacks ``np.ravel_multi_index``.
# Both branches expose the same two module-level helpers:
#   ravel_multi_index(arr, shape) -> flat C-order indices
#   unravel_index(arr, shape)     -> (n_points, n_dims) array of coordinates
if not hasattr(np, 'ravel_multi_index'):
    def ravel_multi_index(arr, shape):
        """Convert (n_dims, n_points) coordinates to flat C-order indices."""
        # C order only: look each coordinate up in a C-ordered grid that
        # stores the flat index of every cell.
        base_c = np.arange(np.prod(shape)).reshape(*shape)
        return base_c[tuple(arr.tolist())]

    def unravel_index(arr, shape):
        """Convert flat indices to an (n_points, n_dims) coordinate array."""
        # Unravel one index at a time and stack the resulting tuples as rows.
        return np.array([np.unravel_index(ai, shape) for ai in arr])
else:
    ravel_multi_index = np.ravel_multi_index

    def unravel_index(arr, shape):
        """Convert flat indices to an (n_points, n_dims) coordinate array."""
        # np.unravel_index returns one array per dimension; transpose so the
        # layout (one row per point) matches the fallback implementation.
        return np.transpose(np.unravel_index(arr, shape))
def _per_dim_array(value, n_dims, dtype, message):
    """Return `value` as a (n_dims,) array, broadcasting a scalar if needed."""
    arr = np.asarray(value, dtype=dtype)
    if arr.ndim == 0:
        # a single value applies to every dimension
        arr = np.repeat(arr, n_dims)
    assert arr.shape == (n_dims,), message.format(arr)
    return arr


def regular_multidim_digitize(data, n_bins=3, lbes=None, rbes=None):
    """ Build a regular grid and assign each data point to a cell

    Parameters
    ----------
    data: (n_points, n_dims) array,
        the data that we wish to digitize

    n_bins: int or (n_dims, ) array-like, optional, default: 3,
        per-dimension number of bins

    lbes: float or (n_dims, ) array-like, optional, default: None,
        per-dimension left-most bin edges
        by default, use the min along each dimension

    rbes: float or (n_dims, ) array-like, optional, default: None,
        per-dimension right-most bin edges
        by default, use the max along each dimension

    Returns
    -------
    assignments: (n_points, ) array of integers,
        the flat (C-order) cell index of each point

    dis: (n_points, n_dims) array of integers,
        the per-dimension bin index of each point

    n_bins: (n_dims, ) array of integers,
        the per-dimension number of bins actually used

    Raises
    ------
    AssertionError
        if n_bins, lbes or rbes do not have one entry per dimension, or if
        the given edges do not enclose the data

    Notes
    -----
    "Regular" here means evenly-spaced along each dimension (possibly separately)

    Each cell has its unique integer index which is the cell number in C-order
    (last dimension varies fastest).

    To obtain the cell coordinates from the cell numbers do:

        np.column_stack(np.unravel_index(assignments, n_bins))

    """
    data = np.asarray(data)
    _, n_dims = data.shape  # also enforces that data is 2-dimensional

    n_bins = _per_dim_array(n_bins, n_dims, int, "invalid n_bins: {0}")

    data_min = np.min(data, axis=0)
    data_max = np.max(data, axis=0)

    if lbes is None:
        lbes = data_min.astype(float)
    else:
        lbes = _per_dim_array(lbes, n_dims, float, "Invalid lbes: {0}")
        # every point must lie on or to the right of the left-most edge
        assert np.all(lbes <= data_min), "lbes not low enough"

    if rbes is None:
        rbes = data_max.astype(float)
    else:
        rbes = _per_dim_array(rbes, n_dims, float, "Invalid rbes: {0}")
        # every point must lie on or to the left of the right-most edge
        assert np.all(rbes >= data_max), "rbes not high enough"

    # get the bin-widths per dimension
    bws = 1e-12 + (rbes - lbes) / n_bins  # add small shift to have max in last bin

    # get the per-dim bin multi-dim index of each point
    dis = ((data - lbes[np.newaxis, :]) / bws[np.newaxis, :]).astype(int)
    # The 1e-12 shift is lost to rounding when the bin width is large, which
    # would put points on the right-most edge in bin `n_bins`; clip them back
    # into the last valid bin.
    dis = np.minimum(dis, n_bins - 1)

    # get index of the flattened grid
    assignments = ravel_multi_index(dis.T, n_bins)

    return assignments, dis, n_bins
# Demo / smoke test: digitize random 3-D points and check that unravelling the
# flat cell indices reproduces the per-dimension bin indices.
x = np.random.randn(10, 3)
ass, dis, n_bins = regular_multidim_digitize(x, n_bins=5)
cell_coords = np.asarray(unravel_index(ass, n_bins))
if cell_coords.shape != dis.shape:
    # accept the one-row-per-dimension layout as well as one-row-per-point
    cell_coords = cell_coords.T
print(np.all(dis == cell_coords))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Thanks, I don't see any disadvantage to your suggestion.
I've used the same trick as daien in the past but usually with integers in mind, even if it wasn't restricted to them.