Skip to content

Instantly share code, notes, and snippets.

@arjun180
Created April 4, 2016 14:36
Show Gist options
  • Save arjun180/71124392b0b70f7b96a8826b59400b99 to your computer and use it in GitHub Desktop.
Save arjun180/71124392b0b70f7b96a8826b59400b99 to your computer and use it in GitHub Desktop.
Implementation of ALS
def weighted_alternating(Q=None, lambda_=0.1, n_factors=2, n_iterations=20):
    """Factorise a ratings matrix with weighted alternating least squares.

    The matrix ``Q`` (users x items) is approximated by ``X.dot(Y)`` where
    ``X`` is ``(users, n_factors)`` and ``Y`` is ``(n_factors, items)``.
    Only strictly positive entries of ``Q`` are treated as observed; every
    other entry gets weight 0 and does not contribute to the fit.

    Each iteration solves one ridge-regularised least-squares problem per
    user row and then one per item column, and prints the weighted squared
    reconstruction error measured after that sweep.

    Args:
        Q: 2-D array-like of ratings. Defaults to the 5x4 example matrix
            this function was originally written around.
        lambda_: L2 regularisation strength added to the diagonal of each
            normal-equation system.
        n_factors: Number of latent factors.
        n_iterations: Number of full (users, then items) sweeps.

    Returns:
        A tuple ``(Q_hat, weighted_errors, elapsed)``: the dense
        reconstruction ``X.dot(Y)``, the list of per-iteration weighted
        squared errors, and the wall-clock seconds spent in the loop.

    Raises:
        ValueError: If ``Q`` is not two-dimensional.
        numpy.linalg.LinAlgError: If a system is singular, which can happen
            when ``lambda_`` is 0 and a row or column has too few
            observed entries.
    """
    # Imported locally so the function works on its own: the surrounding
    # file has no import lines for these names.
    import time

    import numpy as np

    if Q is None:
        Q = [[5, 3, 0, 1], [4, 0, 0, 1], [1, 1, 0, 5], [1, 0, 0, 4], [0, 1, 5, 4]]
    Q = np.asarray(Q, dtype=np.float64)
    if Q.ndim != 2:
        raise ValueError("Q must be a 2-D ratings matrix")

    # Binary weight matrix: 1.0 where a rating was observed, 0.0 elsewhere.
    W = (Q > 0).astype(np.float64)

    m, n = Q.shape
    # Factors are drawn from the global NumPy RNG, so np.random.seed()
    # makes a run reproducible.
    X = 5 * np.random.rand(m, n_factors)
    Y = 5 * np.random.rand(n_factors, n)

    # The regularisation term is loop-invariant; build it once.
    ridge = lambda_ * np.eye(n_factors)

    weighted_errors = []
    start = time.time()
    for _ in range(n_iterations):
        # Fix Y, solve for every user's factor row.
        for u, Wu in enumerate(W):
            # Scaling the columns of Y by Wu equals Y.dot(diag(Wu)) without
            # materialising the dense diagonal matrix.
            Yw = Y * Wu
            X[u] = np.linalg.solve(np.dot(Yw, Y.T) + ridge, np.dot(Yw, Q[u]))
        # Fix X, solve for every item's factor column.
        for i, Wi in enumerate(W.T):
            Xw = X.T * Wi
            Y[:, i] = np.linalg.solve(np.dot(Xw, X) + ridge, np.dot(Xw, Q[:, i]))
        error = np.sum((W * (Q - np.dot(X, Y))) ** 2)
        # Function-call form prints identically on Python 2 and Python 3.
        print(error)
        weighted_errors.append(error)
    # End minus start: the original subtraction was inverted and always
    # yielded a negative duration.
    elapsed = time.time() - start

    Q_hat = np.dot(X, Y)
    return Q_hat, weighted_errors, elapsed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment