Skip to content

Instantly share code, notes, and snippets.

@denis-bz
Created January 21, 2014 17:21
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save denis-bz/8544202 to your computer and use it in GitHub Desktop.
Save denis-bz/8544202 to your computer and use it in GitHub Desktop.
nearsame.py: an easy and intuitive way to find outliers in data 2d to 20d, with kd-trees
#!/usr/bin/env python
""" nearsame.py: an easy and intuitive way to find outliers in data 2d to 20d.
-> e.g.
[50 24 15 10] % of the 38356 points have
[ 3 2 1 0] of the 3 nearest neighbors in the same class.
Here one could keep the 74 % (50 + 24) for which at least 2 of the 3 nearest
neighbors are in the same class, and call the rest outliers.
Intuitively, points in mixed neighborhoods, where classes overlap,
are harder to classify. (For "classes" think colors,
e.g. a Go board: black, white, or empty.)
A possible flow, using kd-trees:
1) build a tree from the reference or training data
2) trim outliers with `nearsame`, e.g.
keep = (ndiff <= 1)
X = X[keep]
y = y[keep]
3) build a new trimmed tree, and use it for classifying new data points
by majority vote of say 3 or 5 nearest neighbors.
If there's no clear majority, return "not sure".
(We have to build a new tree because most kd-tree implementations
cannot delete points easily.)
How coordinates are scaled -- the metric for "nearby" --
is of course important. Try normalizing all rows of X to unit length,
i.e. the cosine metric; then |x - y|^2 = 2 - 2 x . y
See also:
Hastie et al., Elements of Stat Learning p. 118: 11 vowels, 90 samples each
heed hid head had hard hud hod hoard hood who'd heard
keywords: multiclass high-dimensional outlier kd-tree
"""
# https://www.google.com/search?as_q=high-dimensional+outlier
# http://scholar.google.de/scholar?as_q=high-dimensional+outlier&as_occt=title&as_ylo=2010
from __future__ import division
import numpy as np
from scipy.spatial import cKDTree as KDTree
# http://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.KDTree.html
__version__ = "2014-01-21 Jan denis"
#...............................................................................
def _self_first( ix ):
    """ Reorder kd-tree query results so that column 0 is each point itself.

    in: ix: ndata x k indices from tree.query( X, k=k ), X the tree's own data
    out: ndata x k indices: ix_out[i,0] == i, then the other k-1 indices of
        row i in their original (nearest-first) order
    With exact duplicates in X, a zero-distance twin can come back before the
    point itself, or push it out of the k results entirely; in the latter case
    the last (farthest) column of that row is dropped to make room.
    """
    ndata, k = ix.shape
    self_ix = np.arange( ndata )
    is_self = (ix == self_ix[:, None])
    # a k-nearest query returns distinct indices per row, so each row has at
    # most one self; rows with none (more than k duplicates) drop their last column
    missing = ~is_self.any( axis=1 )
    is_self[missing, -1] = True
    others = ix[~is_self].reshape( ndata, k - 1 )  # row-major: rows stay in order
    return np.column_stack(( self_ix, others ))


def _print_summary( ndiff, y, nnear ):
    """ Print the % of points having nnear .. 0 same-class neighbors,
        then the average ndiff of each class present in y.
    """
    ndata = len(ndiff)
    # minlength: one count for every ndiff value 0 .. nnear, even absent ones,
    # so the percents line up with the "nnear .. 0" labels printed below them
    counts = np.bincount( ndiff, minlength=nnear + 1 )
    percents = (counts * 100.0 / counts.sum()).round().astype(int)
    print( "\n%s %% of the %d points have\n%s of the %d nearest neighbors in the same class. " % (
        percents, ndata, np.arange( nnear, -1, -1 ), nnear ))
    print( "# av ndiff of each class, high => harder to classify:" )  # v roughly
    for j in np.unique( y ):
        print( "%d %.2g " % (j, ndiff[y == j].mean()) )
    print( "\n" )


def nearsame( X, y, nnear=3, leafsize=10, p=2, verbose=1 ):
    """ For each point, count how many of its nnear nearest neighbors are in a different class.

    in: X: ndata x ndim reals, ndim 2 .. 20 (more works, but kd-trees get slow)
        y: ndata integer class labels, e.g. 0 1 2 ... for red green blue ...
        nnear: number of neighbors to look at, 1 <= nnear < ndata
        leafsize: leaf size of the kd-tree
        p: Minkowski p-norm for distances, e.g. 1, 2 or np.inf
        verbose: if true, print a summary table and the average ndiff per class
    out: ndiff, ix
        ndiff: ndata ints (int8 when nnear <= 127), the number of neighbors
            in a different class: 0 all same .. nnear all diff
        ix: ndata x nnear+1 indices 0 .. ndata-1 from the kd-tree;
            ix[:,0] is the point itself (also with duplicate points),
            ix[:,1:] its nnear nearest neighbors
    raises: TypeError if y is not integer;
        ValueError if X is not 2d, len(y) != ndata, or nnear is out of range
    """
    X = np.asanyarray( X )
    y = np.asanyarray( y )
    if X.ndim != 2:
        raise ValueError( "nearsame: X must be 2d, ndata x ndim, not %s" % str(X.shape) )
    ndata, ndim = X.shape
    if y.dtype.kind not in ("i", "u"):
        raise TypeError( "y must be int, not %s" % y.dtype )
    if y.shape != (ndata,):
        raise ValueError( "nearsame: y must have shape (%d,), not %s" % (ndata, str(y.shape)) )
    # k = nnear+1 must not exceed ndata, else the tree pads with the index ndata;
    # nnear == 0 would make query() return 1d arrays
    if not 1 <= nnear < ndata:
        raise ValueError( "nearsame: need 1 <= nnear < ndata, got nnear %d ndata %d" % (nnear, ndata) )
    if ndim > 20:
        print( "warning: nearsame %s may be slow, try KDTree cutoff" % str(X.shape) )
        # flann: 128d
    tree = KDTree( X, leafsize=leafsize )  # build the tree
    _, ix = tree.query( X, k=nnear + 1, p=p )  # distances unused (author's note: "idw nominals ?")
    ix = _self_first( ix )
    nearclasses = y[ix[:, 1:]]  # classes of the nnear nearest neighbors, ndata x nnear
    ndiff_dtype = np.int8 if nnear <= 127 else np.int32  # int8 would wrap past 127
    ndiff = (nearclasses != y[:, None]).sum( axis=1 ).astype( ndiff_dtype )  # 0 .. nnear
    if verbose:
        _print_summary( ndiff, y, nnear )
    return ndiff, ix
def subset_arrays( ndiff, adict, near=1 ):
    """ In place: X = X[ndiff <= near] for each array X in adict, e.g. dict( np.load(...) ).

    in: ndiff: ndata ints, e.g. from nearsame(), or the name of a text file
            that np.loadtxt can read
        adict: dict-like; each value with a nonempty .shape and len == ndata
            is replaced by its rows where ndiff <= near.  Other values (strings,
            scalars, 0-d arrays, arrays of other lengths) are left alone
        near: keep points with at most this many different-class neighbors
    out: None.  adict is modified in place, and adict["ndiff"] is set to the
        full (not subsetted) ndiff
    """
    try:
        string_types = basestring  # Python 2: str and unicode
    except NameError:
        string_types = str  # Python 3
    if isinstance( ndiff, string_types ):
        ndiff = np.loadtxt( ndiff, dtype=np.int8 )
    ndiff = np.asarray( ndiff )  # lists too: list <= int is a TypeError in Python 3
    Jnear = (ndiff <= near).nonzero()[0]
    subsetted = []
    # key=str: keys of mixed types cannot be compared directly in Python 3
    for k, v in sorted( adict.items(), key=lambda kv: str(kv[0]) ):
        # np.load can hold non-arrays too, e.g. 0-d arrays of "str"; their
        # empty shape () is falsy, so they are skipped
        if getattr( v, "shape", None ) and len(v) == len(ndiff):
            adict[k] = v[Jnear]
            subsetted.append( str(k) )
    print( "subset_arrays: %d of %d are <= %d: %s" % (
        len(Jnear), len(ndiff), near, " ".join( subsetted )))
    adict["ndiff"] = ndiff
#...............................................................................
if __name__ == "__main__":
    # Demo: find outliers in a labelled data set and subset it in place.
    import sys
    from bz.etc import dataxy  # NOTE(review): project-local data loader, not shown in this file
    # Defaults below may be overridden from the command line via the exec()
    # further down, so these module-level names are part of the script's interface.
    source = "statlearn/vowel*" # Hastie, Elements of Stat Learning p. 118
    centre = 4 # 2: -= mean /= sd, 4: row norms 1
    nnear = 3
    p = 2 # inf faster than 2
    # run this.py a=1 b=None c=[3] 'd = expr' ... in sh or ipython
    # Each command-line argument is executed as one line of Python at module
    # level, so only run this with trusted arguments.
    exec( "\n".join( sys.argv[1:] ))
    # first positional argument of set_printoptions is the float precision
    np.set_printoptions( 1, threshold=20, edgeitems=10, linewidth=200, suppress=True )
    #...............................................................................
    # presumably bag is dict-like with .X/.y attributes (it is passed to
    # subset_arrays as adict below) -- confirm against bz.etc.dataxy
    bag = dataxy.dataxy( source, centre=centre )
    X, y = bag.X, bag.y
    ndiff, ix = nearsame( X, y, nnear=nnear, p=p )
    # vowels av ndiff of each class:
    # 1 .033 2 .089 3 .14 4 .14 5 .39 6 .52 7 .49 8 .14 9 .43 10 .089 11 .2
    # heed hid head had hard hud hod hoard hood who'd heard
    subset_arrays( ndiff, bag ) # 936 of 990 are <= 1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment