Skip to content

Instantly share code, notes, and snippets.

@syhw
Last active December 22, 2015 07:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save syhw/6436493 to your computer and use it in GitHub Desktop.
Save syhw/6436493 to your computer and use it in GitHub Desktop.
Encoding a checkerboard with an RBM
import numpy as np
import matplotlib.pyplot as plt
from sklearn.neural_network import BernoulliRBM
from sklearn import linear_model, metrics
from sklearn.pipeline import Pipeline
# Build a toy binary dataset: flattened 8x8 checkerboards (label 1) versus
# uniform random 0/1 noise images (label 0), one 64-element row per sample.
N_BOARD_COPIES = 5   # each of the two checkerboard phases is repeated this many times
N_NOISE = 10         # number of random-noise rows in each of X and X_test
SIDE = 8             # images are SIDE x SIDE pixels

# (row + col) % 2 gives the board whose top-left cell is 0; its complement is
# the other phase.  Both are flattened row-major.
board = (np.indices((SIDE, SIDE)).sum(axis=0) % 2).ravel()
# Rows alternate phase A, phase B, A, B, ... (2 * N_BOARD_COPIES rows).
X = np.tile(np.vstack([board, 1 - board]), (N_BOARD_COPIES, 1)).astype(float)

# Seeded so the noise samples (and hence the whole run) are reproducible.
# Each noise pixel is 0 or 1 with probability 1/2.  The test set shares the
# checkerboard rows with the training set but gets its own fresh noise rows.
rng = np.random.RandomState(0)
X_test = np.r_[X, rng.randint(0, 2, size=(N_NOISE, SIDE * SIDE)).astype(float)]
X = np.r_[X, rng.randint(0, 2, size=(N_NOISE, SIDE * SIDE)).astype(float)]

# Labels follow the row layout of both X and X_test: boards first, then noise.
Y = np.r_[np.ones(2 * N_BOARD_COPIES, dtype=int), np.zeros(N_NOISE, dtype=int)]
# Two-stage model: the RBM turns each 64-pixel sample into features for its
# n_components hidden units, and logistic regression classifies on those
# features.  Hyper-parameters are given to the constructors directly.
rbm = BernoulliRBM(n_components=4, learning_rate=0.01, n_iter=2000,
                   random_state=0, verbose=True)
logistic = linear_model.LogisticRegression(C=10.0)
classifier = Pipeline(steps=[('rbm', rbm), ('logistic', logistic)])
classifier.fit(X, Y)

# X_test has the same row layout as X (boards first, then noise), so Y also
# serves as the ground truth for X_test.
print("RBM+logistic:\n%s"
      % metrics.classification_report(Y, classifier.predict(X_test)))
# Display each learned RBM weight vector (one row of components_ per hidden
# unit) as an image, on a near-square grid sized from the number of units.
components = rbm.components_
n_components = len(components)
n_cols = int(np.ceil(np.sqrt(n_components)))
n_rows = int(np.ceil(n_components / float(n_cols)))
# Assumes square input images (64 inputs -> 8x8 here).
side = int(round(np.sqrt(components.shape[1])))

plt.figure(figsize=(4.2, 4))
for i, comp in enumerate(components):
    plt.subplot(n_rows, n_cols, i + 1)
    plt.imshow(comp.reshape((side, side)), cmap=plt.cm.gray_r,
               interpolation='nearest')
    plt.xticks(())
    plt.yticks(())
plt.suptitle('%d components extracted by RBM' % n_components, fontsize=16)
plt.subplots_adjust(0.08, 0.02, 0.92, 0.85, 0.08, 0.23)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment