

@neubig
Created April 30, 2013 02:30
This is an example of how to train a restricted Boltzmann machine language model
#!/usr/bin/python
# This code implements the training part of the Restricted Boltzmann Machine
# language model described by:
# Three New Graphical Models for Statistical Language Modeling
# Andriy Mnih and Geoffrey Hinton
# ICML 2007
# http://www.gatsby.ucl.ac.uk/~amnih/papers/threenew.pdf
#
# Usage: train-rbmlm.py training-file.txt
# There are various settings near the top of the source below that you can
# adjust
#
# TODO: The current implementation does not update the biases
# TODO: There is no regularization, so the optimization sometimes diverges
# TODO: Because this does not take advantage of sparsity it is VERY SLOW
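#
# Model sketch (a loose reading of the code below; see the paper for the
# exact notation):
#   hidden pre-activations:  a_h = b_h + sum_i W_i^T R^T v_i             (cf. Eq. 10)
#   next-word distribution:  P(w_n|h) = softmax( R (W_n h + b_r) + b_v ) (cf. Eq. 9)
# where v_i is the one-hot (or mean-field) vector for position i, R holds the
# word feature vectors, and W_i connects position i's features to the hidden units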
import math
import random
import sys
import numpy as np
from collections import defaultdict
# Settings
n = 3              # n-gram order
num_iters = 100    # number of training iterations over the data
N_f = 5            # dimensionality of the word feature (representation) vectors
N_h = 10           # number of hidden units
alpha = 0.1        # learning rate
MEAN_FIELD = True  # use the mean-field reconstruction instead of sampling a single word
# A dictionary that maps words to unique integer ids
wid_dic = defaultdict(lambda: len(wid_dic))
start_id = wid_dic["<s>"]
end_id = wid_dic["</s>"]
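# e.g. the first lookup of a new word such as wid_dic["the"] assigns the next
# free id (2 here, after <s>=0 and </s>=1), and later lookups return the same id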
# Load the n-grams into memory
input_file = open(sys.argv[1], "r")
ngrams = []
unigrams = []
for line in input_file:
    line = line.strip()
    # Create the sentence
    sent = ([start_id]*(n-1) +                      # Starting context
            [wid_dic[i] for i in line.split(" ")] + # Words
            [end_id])                               # Final symbol
    # For all n-grams
    for j in range(n-1, len(sent)):
        ngrams.append( sent[j-n+1:j+1] )
        while len(unigrams) <= sent[j]: unigrams.append(1e-100)
        unigrams[sent[j]] += 1
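# e.g. with n=3 the sentence "a b" becomes <s> <s> a b </s> and yields the
# trigrams (<s>,<s>,a), (<s>,a,b), (a,b,</s>)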
# Initialize the matrices
N_w = len(wid_dic) # Number of words
W = [ (np.random.rand(N_f, N_h)-0.5)*0.01 for i in range(n) ] # One N_f x N_h matrix for each of the n positions (n-1 context words plus the predicted word)
R = (np.random.rand(N_w, N_f)-0.5)*0.01 # One N_w x N_f matrix of word representations
b_h = np.zeros( (N_h, 1) ) # Hidden unit biases
b_r = np.zeros( (N_f, 1) ) # Feature biases
unigrams = map( math.log, unigrams )
b_v = np.array( unigrams ).reshape( (N_w, 1) )
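# Note: b_v starts at the log unigram counts, so the model initially predicts
# roughly a unigram distribution; the counts are unnormalized, but the softmax
# is invariant to a constant shift of its inputs, so only relative values matter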
# Auxiliary functions for sampling
# Apply the sigmoid to x and sample a binary value with that probability
def sigmoid_samp(x):
    if x > 500: return 1
    elif x < -500: return 0
    prob = 1 / (1 + math.exp(-x))
    return 1 if random.random() < prob else 0
# Numerically stable softmax over a vector of scores
def softmax(w):
    e = np.exp(w - np.max(w))
    return e / np.sum(e)
# Sample a single index from an (unnormalized) probability vector
def sample_one(probs):
    left = random.random()*sum(probs)
    for i, v in enumerate(probs):
        left -= v
        if left <= 0:
            return i
    raise Exception('Overflow in sample_one:\n%r\n%r' % (probs, sum(probs)))
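# Illustrative behavior of the helpers (values chosen for clarity):
#   sigmoid_samp(0.0)            -> 0 or 1, each with probability 0.5
#   softmax(np.array([0., 0.]))  -> array([0.5, 0.5])
#   sample_one([0.1, 0.9])       -> 1 about 90% of the time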
# Do iterations
print len(ngrams) # Report the number of training n-grams
for iter_num in range(num_iters):
    num_words = 0
    log_prob = 0
    for ngram_1 in ngrams:
        num_words += 1
        if num_words % 100 == 0: print >> sys.stderr, num_words
        ###### Create Column Vectors ######
        v_1 = [ np.zeros( (N_w, 1) ) for i in range(n) ]
        for i in range(n):
            v_1[i][ngram_1[i],0] = 1
        ###### Calculate P(h|w_{1:n}) (Equation 10) ######
        h_1 = np.array(b_h.T)
        for i in range(n):
            h_1 += np.dot( np.dot(v_1[i].T,R), W[i] )
        ###### Sample from P(h|w_{1:n}) ######
        for i in range(h_1.shape[0]):
            for j in range(h_1.shape[1]):
                h_1[i,j] = sigmoid_samp(h_1[i,j])
        ###### Calculate v_2 (Equation 9) ######
        v_2 = [ v_1[i] for i in range(n) ] # First elements are the same as v_1
        v_2_n = softmax( np.dot((np.dot(h_1, W[n-1].T) + b_r.T), R.T) + b_v.T ).T
        # Collect statistics
        log_prob += math.log(v_2_n[ngram_1[n-1],0]) # Count the log probability
        # Either use the mean-field distribution directly, or sample a single word from it
        if not MEAN_FIELD:
            w_2_n = sample_one(v_2_n[:,0]) # Sample a word id from the reconstruction distribution
            v_2_n = np.zeros( (N_w, 1) )
            v_2_n[w_2_n,0] = 1
        v_2[n-1] = v_2_n
        ###### Calculate P(h|v_2) for the reconstruction (Equation 10) ######
        h_2 = np.array(b_h.T)
        for i in range(n):
            h_2 += np.dot( np.dot(v_2[i].T,R), W[i] ) # Uses the reconstruction v_2
        ###### Sample from P(h|v_2) ######
        for i in range(h_2.shape[0]):
            for j in range(h_2.shape[1]):
                h_2[i,j] = sigmoid_samp(h_2[i,j])
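        # The updates below are one step of contrastive divergence: the
        # positive-phase statistics come from the data (v_1, h_1) and the
        # negative-phase statistics from the one-step reconstruction (v_2, h_2)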
        ###### Update W (Equation 8) ######
        oldW = np.array(W)
        for i in range(n):
            W[i] += (np.dot( np.dot(R.T, v_1[i]), h_1*alpha ) -
                     np.dot( np.dot(R.T, v_2[i]), h_2*alpha ))
        ###### Update R (Equation 9) ######
        R += np.dot(v_1[n-1], b_r.T*alpha)
        R -= np.dot(v_2[n-1], b_r.T*alpha)
        for i in range(n):
            R += np.dot(np.dot(v_1[i], h_1*alpha), oldW[i].T)
            R -= np.dot(np.dot(v_2[i], h_2*alpha), oldW[i].T)
    print >> sys.stderr, "Iteration %r: log_prob=%r, PPL=%r" % (iter_num, log_prob, math.exp(-log_prob/num_words))
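# A possible next-word scorer built from the trained parameters (a minimal
# sketch, not part of the original training loop): it approximates the hidden
# state with mean-field activations from the context alone, rather than samples
def next_word_probs(context):
    # context: list of n-1 word ids
    h = np.array(b_h.T)
    for i in range(n-1):
        v = np.zeros( (N_w, 1) )
        v[context[i],0] = 1
        h += np.dot( np.dot(v.T, R), W[i] )
    h = 1 / (1 + np.exp(-h)) # mean-field hidden activations
    return softmax( np.dot(np.dot(h, W[n-1].T) + b_r.T, R.T) + b_v.T ).T
# Example: next_word_probs([start_id]*(n-1)) gives a distribution over the
# first word of a sentence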