Skip to content

Instantly share code, notes, and snippets.

@dwf
Created February 1, 2010 20:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dwf/292022 to your computer and use it in GitHub Desktop.
Save dwf/292022 to your computer and use it in GitHub Desktop.
Singhal & Wu's four regions benchmark for nonlinear classification algorithms, in Python/NumPy.
#!/usr/bin/env python
"""
NumPy implementation of the classic 'four regions' benchmark.
By David Warde-Farley -- user AT cs dot toronto dot edu (user = dwf)
Redistributable under the terms of the 3-clause BSD license
(see http://www.opensource.org/licenses/bsd-license.php for details)
"""
import numpy as np
def fourregions_labels(points):
    """
    Returns labels for points from the "four regions" benchmark
    classification task first described by Singhal and Wu.

    The benchmark is defined on the square [-1,1]^2, which is split into
    concentric bands around the origin (radii 1/6, 1/2 and 5/6):

      * radius <= 1/6 (the "center nut")             -> 1
      * 1/6 < radius <= 1/2: right half (x > 0) -> 2, left half -> 1
      * 1/2 < radius <= 5/6: right half (x > 0) -> 1, left half -> 2
      * radius > 5/6:        top half (y > 0)   -> 3, bottom half -> 4

    The bands are half-open, so every point receives exactly one label,
    including points lying exactly on one of the circles. Points outside
    [-1,1]^2 are not rejected: they always have radius > 1, so they fall
    in the outer band and are labelled 3 or 4.

    'points' is an Nx2 array-like (ndarray, numpy.matrix or nested
    sequence). Returns a length-N float64 array with values in {1,2,3,4}.

    Raises ValueError if 'points' is not Nx2 or contains NaN.

    For more information see:
    S. Singhal and L. Wu, "Training multilayer perceptrons with the
    extended Kalman algorithm". Advances in Neural Information
    Processing Systems, Proceedings of the 1988 Conference, pp133-140.
    http://books.nips.cc/papers/files/nips01/0133.pdf
    """
    # asarray (not asanyarray) turns a numpy.matrix into a plain ndarray,
    # so '**' below is elementwise and the column slices are 1-D masks.
    points = np.asarray(points, dtype=float)
    if points.ndim != 2 or points.shape[1] != 2:
        raise ValueError("points must be an Nx2 array, got shape %r"
                         % (points.shape,))
    if np.isnan(points).any():
        # NaN distances fail every comparison and would get no label.
        raise ValueError("points must not contain NaN")

    tophalf = points[:, 1] > 0
    righthalf = points[:, 0] > 0
    lefthalf = ~righthalf  # includes the x == 0 axis
    dists = np.sqrt(np.sum(points ** 2, axis=1))

    # Half-open radial bands that together cover [0, inf) with no gaps.
    center = dists <= 1. / 6.
    firstring = (dists > 1. / 6.) & (dists <= 1. / 2.)
    secondring = (dists > 1. / 2.) & (dists <= 5. / 6.)
    outer = dists > 5. / 6.

    region = np.zeros(points.shape[0])
    # Region 1 -- center nut, left inner ring and right outer ring.
    region[center] = 1
    region[firstring & lefthalf] = 1
    region[secondring & righthalf] = 1
    # Region 2 -- right inner ring and left outer ring.
    region[firstring & righthalf] = 2
    region[secondring & lefthalf] = 2
    # Regions 3 and 4 -- top and bottom of the outer shelf.
    region[outer & tophalf] = 3
    region[outer & ~tophalf] = 4
    return region
def demo(n_samples=90000):
    """
    Plot four-regions labels for uniformly random points.

    Draws 'n_samples' points uniformly from [-1,1)^2, colours each one by
    its label on a grayscale colormap and shows the figure. Requires
    matplotlib; if it cannot be imported, the ImportError propagates to
    the caller.
    """
    import matplotlib.pyplot as plt
    # rand() samples [0, 1); scale and shift onto [-1, 1).
    x = np.random.rand(n_samples, 2) * 2 - 1
    plt.scatter(x[:, 0], x[:, 1], 10, fourregions_labels(x), cmap='gray')
    plt.axis('equal')
    plt.title('%d samples from the four regions benchmark' % n_samples)
    plt.show()
if __name__ == "__main__":
    import sys
    try:
        demo()
    except ImportError as exc:
        # The demo is optional: report the missing plotting dependency
        # instead of dumping a traceback (or silently swallowing errors).
        sys.exit("The demo requires matplotlib (%s)." % exc)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment