Skip to content

Instantly share code, notes, and snippets.

@EthanRosenthal
Created March 13, 2016 20:38
Show Gist options
  • Star 10 You must be signed in to star a gist
  • Fork 4 You must be signed in to fork a gist
  • Save EthanRosenthal/a0816d8fea4394baf732 to your computer and use it in GitHub Desktop.
Save EthanRosenthal/a0816d8fea4394baf732 to your computer and use it in GitHub Desktop.
Class for ALS training of an explicit matrix factorization model
import numpy as np
from numpy.linalg import solve
class ExplicitMF:
    """
    Explicit-feedback matrix factorization trained with Alternating
    Least Squares (ALS).

    The user x item ratings matrix is approximated by the product
    ``user_vecs.dot(item_vecs.T)`` of two latent-factor matrices.

    Attributes created by training
    ==============================
    user_vecs : (ndarray)
        n_users x n_factors matrix, set by train()
    item_vecs : (ndarray)
        n_items x n_factors matrix, set by train()
    train_mse, test_mse : (list)
        Set by calculate_learning_curve()
    """

    def __init__(self,
                 ratings,
                 n_factors=40,
                 item_reg=0.0,
                 user_reg=0.0,
                 verbose=False):
        """
        Train a matrix factorization model to predict empty
        entries in a matrix. The terminology assumes a
        ratings matrix which is ~ user x item

        Params
        ======
        ratings : (ndarray)
            User x Item matrix with corresponding ratings.
            Anything np.asarray can turn into a 2D array is
            accepted.

        n_factors : (int)
            Number of latent factors to use in matrix
            factorization model

        item_reg : (float)
            Regularization term for item latent factors

        user_reg : (float)
            Regularization term for user latent factors

        verbose : (bool)
            Whether or not to print training progress
        """
        # The solver indexes with ratings[u, :], which only works on real
        # arrays. np.asarray returns an ndarray unchanged (no copy) and
        # converts other array-likes so they do not fail later on.
        self.ratings = np.asarray(ratings)
        self.n_users, self.n_items = self.ratings.shape
        self.n_factors = n_factors
        self.item_reg = item_reg
        self.user_reg = user_reg
        self._v = verbose

    def als_step(self,
                 latent_vectors,
                 fixed_vecs,
                 ratings,
                 _lambda,
                 type='user'):
        """
        One of the two ALS steps. Solve for the latent vectors
        specified by type while the other side is held fixed.

        Params
        ======
        latent_vectors : (ndarray)
            Matrix to solve for; one row per user (type='user')
            or per item (type='item'). Overwritten in place.

        fixed_vecs : (ndarray)
            Latent vectors of the side that is held fixed.

        ratings : (ndarray)
            User x Item ratings matrix.

        _lambda : (float)
            Regularization strength added to the diagonal of the
            normal-equation matrix.

        type : (str)
            'user' or 'item'. (The name shadows the builtin but is
            kept because callers pass it by keyword.)

        Returns
        =======
        The same latent_vectors array, updated in place.

        Raises
        ======
        ValueError if type is neither 'user' nor 'item'.
        """
        ratings = np.asarray(ratings)
        if type == 'user':
            # Row u of rhs is ratings[u, :].dot(fixed_vecs)
            rhs = ratings.dot(fixed_vecs)
        elif type == 'item':
            # Row i of rhs is ratings[:, i].dot(fixed_vecs)
            rhs = ratings.T.dot(fixed_vecs)
        else:
            raise ValueError(
                "type must be 'user' or 'item', got {!r}".format(type))

        # Every row shares the same left-hand side, so build it once ...
        gram = fixed_vecs.T.dot(fixed_vecs)
        lhs = gram + np.eye(gram.shape[0]) * _lambda
        # ... and solve for all rows at once: each column of rhs.T is one
        # right-hand side. Slice-assign to keep the in-place update.
        latent_vectors[:] = solve(lhs, rhs.T).T
        return latent_vectors

    def train(self, n_iter=10):
        """ Train model for n_iter iterations from scratch."""
        # initialize latent vectors (users first, then items, so results
        # for a given numpy random seed stay the same)
        self.user_vecs = np.random.random((self.n_users, self.n_factors))
        self.item_vecs = np.random.random((self.n_items, self.n_factors))
        self.partial_train(n_iter)

    def partial_train(self, n_iter):
        """
        Train model for n_iter iterations. Can be
        called multiple times for further training.
        Requires user_vecs / item_vecs to exist already
        (train() creates them).
        """
        for ctr in range(1, n_iter + 1):
            if self._v and ctr % 10 == 0:
                # Single-argument print(...) behaves the same on
                # Python 2 and Python 3.
                print('\tcurrent iteration: {}'.format(ctr))
            # Users are solved against the current items, then items
            # against the freshly updated users.
            self.user_vecs = self.als_step(self.user_vecs,
                                           self.item_vecs,
                                           self.ratings,
                                           self.user_reg,
                                           type='user')
            self.item_vecs = self.als_step(self.item_vecs,
                                           self.user_vecs,
                                           self.ratings,
                                           self.item_reg,
                                           type='item')

    def predict_all(self):
        """ Predict ratings for every user and item. """
        # Entry (u, i) is user_vecs[u, :].dot(item_vecs[i, :]), i.e. the
        # same value predict(u, i) returns, computed for all pairs at once.
        return self.user_vecs.dot(self.item_vecs.T)

    def predict(self, u, i):
        """ Single user and item prediction. """
        return self.user_vecs[u, :].dot(self.item_vecs[i, :].T)

    def calculate_learning_curve(self, iter_array, test):
        """
        Keep track of MSE as a function of training iterations.

        Params
        ======
        iter_array : (list)
            List of numbers of iterations to train for each step of
            the learning curve. e.g. [1, 5, 10, 20]
            NOTE: the list is sorted in place.

        test : (2D ndarray)
            Testing dataset (assumed to be user x item).

        The function creates two new class attributes:

        train_mse : (list)
            Training data MSE values for each value of iter_array
        test_mse : (list)
            Test data MSE values for each value of iter_array

        Relies on a module-level function get_mse(predictions, actual)
        that is not defined in this class.
        """
        iter_array.sort()
        self.train_mse = []
        self.test_mse = []
        iter_diff = 0
        for (i, n_iter) in enumerate(iter_array):
            if self._v:
                print('Iteration: {}'.format(n_iter))
            # Only the first step starts from scratch; later steps add
            # the missing iterations on top of the current model.
            if i == 0:
                self.train(n_iter - iter_diff)
            else:
                self.partial_train(n_iter - iter_diff)

            predictions = self.predict_all()

            self.train_mse.append(get_mse(predictions, self.ratings))
            self.test_mse.append(get_mse(predictions, test))
            if self._v:
                print('Train mse: ' + str(self.train_mse[-1]))
                print('Test mse: ' + str(self.test_mse[-1]))
            iter_diff = n_iter
@PhanDuc
Copy link

PhanDuc commented Jul 21, 2016

Hi Ethan,

I got an error on line 58: unhashable type: 'slice'

How can I fix that error?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment