Skip to content

Instantly share code, notes, and snippets.

@daien
Created December 22, 2011 10:34
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save daien/1509853 to your computer and use it in GitHub Desktop.
Save daien/1509853 to your computer and use it in GitHub Desktop.
Build a regular grid and assign each data point to a cell
import numpy as np
def _as_per_dim(value, n_dims, name, dtype):
    """Return `value` as a 1-D array of length `n_dims` with the given dtype.

    A scalar is repeated for every dimension; any other input must already
    have shape (n_dims,), otherwise a ValueError is raised.
    """
    arr = np.asarray(value, dtype=dtype)
    if arr.ndim == 0:
        return np.full(n_dims, arr, dtype=dtype)
    if arr.shape != (n_dims,):
        raise ValueError("Invalid {0}: {1}".format(name, value))
    return arr


def regular_multidim_digitize(data, n_bins=3, lbes=None, rbes=None):
    """Build a regular grid and assign each data point to a cell.

    Parameters
    ----------
    data: (n_points, n_dims) array-like,
        the data that we wish to digitize
    n_bins: int or (n_dims, ) array-like, optional, default: 3,
        per-dimension number of bins (each must be >= 1)
    lbes: number or (n_dims, ) array-like, optional, default: None,
        per-dimension left-most bin edges;
        by default, use the min along each dimension
    rbes: number or (n_dims, ) array-like, optional, default: None,
        per-dimension right-most bin edges;
        by default, use the max along each dimension

    Returns
    -------
    assignments: (n_points, ) array of integers,
        the cell index of each point

    Raises
    ------
    ValueError
        if `data` is not 2-D or contains non-finite values, if `n_bins`,
        `lbes` or `rbes` have the wrong shape, if some `n_bins` is < 1,
        or if the [lbes, rbes] box does not contain every point.

    Notes
    -----
    "Regular" here means evenly-spaced along each dimension (possibly with a
    different spacing per dimension). Along each dimension the bins are
    half-open, [edge_k, edge_k+1), except the last one which also includes
    the right-most edge, so the maximum always falls in the last bin
    (same convention as `np.histogram`, up to floating-point rounding).
    A dimension with zero extent (lbes == rbes) puts every point in bin 0.

    Each cell has its unique integer index which is the cell number in
    C-order (last dimension varies fastest).
    To obtain the per-dimension cell coordinates from the cell numbers do::

        np.column_stack(np.unravel_index(assignments, per_dim_n_bins))

    where `per_dim_n_bins` is the tuple of per-dimension bin counts.
    """
    data = np.asarray(data)
    if data.ndim != 2:
        raise ValueError(
            "data must be a 2-D (n_points, n_dims) array, got shape {0}".format(data.shape))
    n_points, n_dims = data.shape

    n_bins = _as_per_dim(n_bins, n_dims, "n_bins", int)
    if np.any(n_bins < 1):
        raise ValueError("invalid n_bins: {0}".format(n_bins))

    # min/max are undefined on empty input; there is nothing to assign anyway
    if n_points == 0:
        return np.zeros(0, dtype=np.intp)

    data_min = data.min(axis=0)
    data_max = data.max(axis=0)
    # NaN propagates through min/max and inf shows up in them, so checking
    # the extrema is enough to reject non-finite data
    if not (np.all(np.isfinite(data_min)) and np.all(np.isfinite(data_max))):
        raise ValueError("data must contain only finite values")

    if lbes is None:
        lbes = data_min.astype(float)
    else:
        lbes = _as_per_dim(lbes, n_dims, "lbes", float)
        # the grid must cover every point, otherwise indices would overflow
        if not np.all(lbes <= data_min):
            raise ValueError("lbes not low enough")

    if rbes is None:
        rbes = data_max.astype(float)
    else:
        rbes = _as_per_dim(rbes, n_dims, "rbes", float)
        # the grid must cover every point, otherwise indices would overflow
        if not np.all(rbes >= data_max):
            raise ValueError("rbes not high enough")

    # per-dimension bin widths; a zero-extent dimension would divide by zero,
    # but all its points equal lbes there, so any non-zero width yields bin 0
    widths = (rbes - lbes) / n_bins
    widths[widths == 0] = 1.0

    # per-dimension cell coordinates of each point; data >= lbes so the
    # quotient is non-negative, and points lying on rbes (index == n_bins)
    # are folded into the last bin by the clip
    coords = np.floor((data - lbes) / widths).astype(np.intp)
    coords = np.clip(coords, 0, n_bins - 1)

    # flatten the multi-dimensional cell coordinates (C-order)
    return np.ravel_multi_index(coords.T, n_bins)
@josef-pkt
Copy link

taldcroft left a comment on my fork that you might also be interested in.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment