Skip to content

Instantly share code, notes, and snippets.

@beomjunshin-ben
Created May 13, 2015 15:57
Show Gist options
  • Save beomjunshin-ben/733f03a6bcf7151dbb67 to your computer and use it in GitHub Desktop.
Save beomjunshin-ben/733f03a6bcf7151dbb67 to your computer and use it in GitHub Desktop.
RBM
from sklearn.datasets import fetch_mldata
import numpy as np
import matplotlib.pyplot as plt
class RBM:
    """Bernoulli-Bernoulli restricted Boltzmann machine trained with CD-1.

    Attributes:
        num_visible: number of visible units (columns of the training data).
        num_hidden: number of hidden units.
        learning_rate: step size of the contrastive-divergence updates.
        w: weight matrix, shape (num_visible, num_hidden).
        b: hidden-unit biases, shape (num_hidden,).
        c: visible-unit biases, shape (num_visible,).
    """

    def __init__(self, num_visible, num_hidden, learning_rate=0.1):
        self.num_hidden = num_hidden
        self.num_visible = num_visible
        self.learning_rate = learning_rate
        # Small random initial parameters break the symmetry between hidden units.
        self.w = 0.1 * np.random.randn(self.num_visible, self.num_hidden)
        self.b = 0.1 * np.random.randn(self.num_hidden)
        self.c = 0.1 * np.random.randn(self.num_visible)

    def train(self, data_all, batch_size=100, max_epochs=1000,
              image_shape=(28, 28), verbose=True):
        """Fit w, b and c with one-step contrastive divergence (CD-1).

        Args:
            data_all: array of shape (num_data, num_visible). Values are
                expected to lie in [0, 1] (binary pixels or probabilities).
            batch_size: number of samples per mini-batch. The last batch is
                smaller when num_data is not a multiple of batch_size.
            max_epochs: number of full passes over ``data_all``.
            image_shape: (height, width) used to save the first hidden unit's
                filter with ``visualize`` after every epoch; None disables it.
            verbose: print the per-epoch reconstruction error and the final
                weight matrix.

        Raises:
            ValueError: if ``data_all`` is empty or ``batch_size`` < 1.
        """
        data_all = np.asarray(data_all, dtype=float)
        num_data = data_all.shape[0]
        if num_data == 0:
            raise ValueError("data_all must contain at least one sample")
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")
        for epoch in range(max_epochs):
            # Accumulate over the whole epoch, not just the last mini-batch.
            error = 0.0
            for start in range(0, num_data, batch_size):
                error += self._cd1_step(data_all[start:start + batch_size])
            if image_shape is not None:
                height, width = image_shape
                visualize(self.w[:, 0], width=width, height=height)
            if verbose:
                print("Epoch %s: error is %s" % (epoch, error))
        if verbose:
            print(self.w)

    def _cd1_step(self, data):
        """Apply one CD-1 update for mini-batch ``data``; return its squared reconstruction error."""
        n = data.shape[0]
        # Positive phase: hidden probabilities driven by the data.
        pos_hidden_probs = self.sigmoid(self.b + np.dot(data, self.w))
        # Sample binary hidden states before reconstructing, so the hidden layer
        # acts as an information bottleneck (Hinton's practical guide, sec. 3.1).
        pos_hidden_states = (
            pos_hidden_probs > np.random.rand(*pos_hidden_probs.shape)
        ).astype(float)
        # Negative phase: reconstruct visible units, then re-infer hidden units.
        neg_visible_probs = self.sigmoid(self.c + np.dot(pos_hidden_states, self.w.T))
        neg_visible_states = (
            neg_visible_probs > np.random.rand(*neg_visible_probs.shape)
        ).astype(float)
        neg_hidden_probs = self.sigmoid(self.b + np.dot(neg_visible_states, self.w))
        # Gradients are averaged over the actual batch size (the last batch may be short).
        pos_phase_w = np.dot(data.T, pos_hidden_probs)
        neg_phase_w = np.dot(neg_visible_states.T, neg_hidden_probs)
        self.w += self.learning_rate * (pos_phase_w - neg_phase_w) / n
        self.b += self.learning_rate * (pos_hidden_probs.mean(axis=0)
                                        - neg_hidden_probs.mean(axis=0))
        self.c += self.learning_rate * (data.mean(axis=0)
                                        - neg_visible_states.mean(axis=0))
        return float(np.sum((data - neg_visible_states) ** 2))

    def sigmoid(self, x):
        """Return the element-wise logistic function 1 / (1 + exp(-x))."""
        # tanh form is mathematically identical but cannot overflow for large |x|.
        return 0.5 * (1.0 + np.tanh(0.5 * x))
def visualize(x, width=-1, height=-1, pad=5, title='Filter.png'):
    """Save the vector ``x`` as a greyscale image file.

    Args:
        x: array whose elements are laid out row-major as a (height, width) image.
        width, height: image size in pixels. One of them may be -1 to let NumPy
            infer it; if both are negative the image is assumed to be square.
        pad: currently unused; reserved for tiling several filters into one image.
        title: path of the image file written by ``savefig``.

    Raises:
        ValueError: if ``x`` cannot be reshaped to the requested size.
    """
    # TODO: tile several filters (separated by ``pad`` pixels) into one image.
    x = np.asarray(x)
    if width < 0 and height < 0:
        width = height = int(round(np.sqrt(x.size)))
    # Rows correspond to height, columns to width.
    image = x.reshape((height, width))
    fig = plt.figure()
    try:
        plt.imshow(image, cmap='Greys')
        fig.savefig(title)
    finally:
        # Always release the figure: this function is called repeatedly from
        # RBM.train, and unclosed pyplot figures accumulate in memory.
        plt.close(fig)
if __name__ == '__main__':
    # NOTE: fetch_mldata was deprecated in scikit-learn 0.20 and removed in 0.22
    # (mldata.org is offline); on newer versions use
    # sklearn.datasets.fetch_openml('mnist_784', version=1) instead.
    mnist = fetch_mldata('MNIST original', data_home='./')
    # MNIST pixels are 0-255 intensities; the Bernoulli RBM expects values in [0, 1].
    X = mnist.data / 255.0
    r = RBM(num_visible=28*28, num_hidden=10)
    r.train(X, batch_size=100, max_epochs=5000)
    # Persist the learned parameters as plain-text matrices/vectors.
    np.savetxt('weights.out', r.w)
    np.savetxt('bias_b.out', r.b)
    np.savetxt('bias_c.out', r.c)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment