Created February 1, 2010 20:57
-
-
Save dwf/292022 to your computer and use it in GitHub Desktop.
Singhal & Wu's four regions benchmark for nonlinear classification algorithms, in Python/NumPy.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
""" | |
NumPy implementation of the classic 'four regions' benchmark. | |
By David Warde-Farley -- user AT cs dot toronto dot edu (user = dwf) | |
Redistributable under the terms of the 3-clause BSD license | |
(see http://www.opensource.org/licenses/bsd-license.php for details) | |
""" | |
import numpy as np | |
def fourregions_labels(points):
    """
    Returns labels for points in [-1,1]^2 from the "four regions" benchmark
    classification task first described by Singhal and Wu.

    'points' is an Nx2 array-like (ndarray, numpy.matrix or nested
    sequence); it is converted with numpy.asarray before use. Column 0 is
    the x coordinate and column 1 is the y coordinate.

    Returns a length-N float array with values in {1, 2, 3, 4}, where r is
    the distance from the origin:
      * 3 / 4 -- r > 5/6, in the top (y > 0) / bottom (y <= 0) half.
      * 2 -- right half (x > 0) with 1/6 < r <= 1/2, or left half
        (x <= 0) with 1/2 < r <= 5/6.
      * 1 -- right half with r <= 1/6 or 1/2 < r <= 5/6, and left half
        with r <= 1/2.
    Points outside [-1,1]^2 are not rejected; they follow the same rules.

    For more information see:
    S. Singhal and L. Wu, "Training multilayer perceptrons with the
    extended Kalman algorithm". Advances in Neural Information
    Processing Systems, Proceedings of the 1988 Conference, pp133-140.
    http://books.nips.cc/papers/files/nips01/0133.pdf
    """
    # On a numpy.matrix, ``**`` is a matrix power and column slices stay
    # 2-D, which breaks the element-wise computation below; asarray gives
    # a plain ndarray (and also accepts nested lists).
    points = np.asarray(points)
    region = np.zeros(points.shape[0])
    tophalf = points[:, 1] > 0
    righthalf = points[:, 0] > 0
    lefthalf = np.logical_not(righthalf)
    dists = np.sqrt(np.sum(points ** 2, axis=1))

    # The easy ones -- the outer shelf, split by the x axis.
    outer = dists > 5. / 6.
    region[np.logical_and(tophalf, outer)] = 3
    region[np.logical_and(np.logical_not(tophalf), outer)] = 4

    # Bands are closed at the top, (lo, hi], and the nut is r <= 1/6, so
    # every radius in [0, 5/6] falls into exactly one band (no unlabelled
    # points on the r == 1/6 or r == 1/2 circles).
    nut = dists <= 1. / 6.
    firstring = np.logical_and(dists > 1. / 6., dists <= 1. / 2.)
    secondring = np.logical_and(dists > 1. / 2., dists <= 5. / 6.)

    # Region 2 -- right inner ring and left outer ring.
    region[np.logical_and(firstring, righthalf)] = 2
    region[np.logical_and(secondring, lefthalf)] = 2
    # Region 1 -- right outer ring, right center nut, whole left disc r <= 1/2.
    region[np.logical_and(secondring, righthalf)] = 1
    region[np.logical_and(lefthalf, dists <= 1. / 2.)] = 1
    region[np.logical_and(righthalf, nut)] = 1

    # Sanity check: all finite points are labelled by now; only NaN
    # coordinates (which fail every comparison) can remain at 0.
    assert np.all(region > 0)
    return region
def demo(n_samples=90000):
    """
    Scatter-plot labelled samples drawn uniformly from [-1,1]^2.

    Draws 'n_samples' points with numpy.random.rand (global RNG state),
    labels them with fourregions_labels and shows them in a grayscale
    matplotlib scatter plot.

    Raises ImportError if matplotlib is not installed; the import is
    deferred to this function so the rest of the module needs only NumPy.
    """
    import matplotlib.pyplot as plt
    # Map uniform [0, 1) samples onto the [-1, 1) square.
    x = np.random.rand(n_samples, 2) * 2 - 1
    plt.scatter(x[:, 0], x[:, 1], 10, fourregions_labels(x), cmap='gray')
    plt.axis('equal')
    plt.title('%d samples from the four regions benchmark' % n_samples)
    plt.show()
if __name__ == "__main__":
    # demo() needs matplotlib; report an import failure with a short
    # message instead of a traceback, without hiding any other error.
    try:
        demo()
    except ImportError as err:
        raise SystemExit("The four regions demo requires matplotlib (%s)." % err)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.