Instantly share code, notes, and snippets.

# yamaguchiyuto/rescal.py Created Feb 14, 2016

What would you like to do?
 # RESCAL tensor factorization fitted by alternating least squares, plus a toy link-prediction demo.
"""RESCAL factorization of a multi-relational adjacency tensor.

Every relation slice ``X[k]`` (an ``n x n`` matrix) is approximated as
``A . R[k] . A^T`` where ``A`` (``n x r``) is shared by all relations and
``R[k]`` (``r x r``) is specific to relation ``k``.  Running the file as a
script fits the model on a five-node example graph and prints the predicted
triples that are absent from the input.
"""
import numpy as np
from scipy import linalg, sparse

# Reconstructed scores above this value are reported as predicted links.
LINK_THRESHOLD = 0.1


class RESCAL:
    """Alternating-least-squares solver for ``X[k] ~ A . R[k] . A^T``.

    Parameters
    ----------
    r : int
        Number of latent components (columns of ``A``).
    lamb_A : float
        Ridge weight used in the ``A`` update.  It is multiplied by the
        number of relations before being added to the normal equations.
    lamb_R : float
        Ridge weight used in every ``R[k]`` update.

    Attributes
    ----------
    A : numpy.ndarray
        ``(n, r)`` factor matrix, available after :meth:`fit`.
    R : list of numpy.ndarray
        One ``(r, r)`` matrix per relation, available after :meth:`fit`.
    """

    def __init__(self, r, lamb_A, lamb_R):
        self.r = r
        self.lamb_A = lamb_A
        self.lamb_R = lamb_R

    def fit(self, X, niter=30):
        """Fit ``A`` and ``R`` to the relation slices in ``X``.

        Parameters
        ----------
        X : sequence of 2-D matrices
            One square ``n x n`` slice per relation.  SciPy sparse matrices
            and dense arrays are both accepted; each slice is converted to
            CSR once up front.
        niter : int, optional
            Number of ALS sweeps.  At least one sweep is always performed,
            even when ``niter`` is zero or negative.

        Raises
        ------
        ValueError
            If ``X`` contains no slices.

        Notes
        -----
        The factors are initialised from NumPy's global random state (``A``
        first, then each ``R[k]``), so ``np.random.seed`` makes a run
        reproducible.
        """
        if len(X) == 0:
            raise ValueError("X must contain at least one relation slice")

        # CSR is the efficient format for the repeated products below; the
        # transposes are cached because they are needed on every sweep.
        slices = [sparse.csr_matrix(Xk) for Xk in X]
        slices_T = [Xk.T.tocsr() for Xk in slices]
        m = len(slices)
        n, _ = slices[0].shape
        r = self.r

        self.A = np.random.randn(n, r)
        self.R = [np.random.randn(r, r) for _ in range(m)]

        ridge_A = m * self.lamb_A * np.identity(r)
        ridge_R = self.lamb_R * np.identity(r * r)

        for _ in range(max(1, int(np.ceil(niter)))):
            # ---- update A ------------------------------------------------
            AtA = self.A.T.dot(self.A)
            F = np.zeros((n, r))
            S = ridge_A.copy()
            for Xk, XkT, Rk in zip(slices, slices_T, self.R):
                F += Xk.dot(self.A).dot(Rk.T) + XkT.dot(self.A).dot(Rk)
                S += Rk.dot(AtA).dot(Rk.T) + Rk.T.dot(AtA).dot(Rk)
            # A = F S^{-1}.  S is symmetric (a sum of M B M^T terms with
            # symmetric B), so solving S A^T = F^T avoids forming an inverse.
            self.A = linalg.solve(S, F.T).T

            # ---- update R ------------------------------------------------
            # With A = Q A_bar, the problem is projected onto Q so that each
            # R[k] comes from a ridge regression with r^2 unknowns:
            #   vec(Q^T X[k] Q) ~ (A_bar kron A_bar) vec(R[k])
            # where vec() is NumPy's default row-major flattening.
            Q, A_bar = linalg.qr(self.A, mode='economic')
            Z = np.kron(A_bar, A_bar)
            # The left-hand side is the same for every relation: build it
            # once and solve for all relations together, one column each.
            gram = Z.T.dot(Z) + ridge_R
            rhs = np.column_stack(
                [Z.T.dot(Q.T.dot(Xk.dot(Q)).ravel()) for Xk in slices]
            )
            solution = linalg.solve(gram, rhs)
            self.R = [solution[:, k].reshape(r, r) for k in range(m)]


def main():
    """Fit RESCAL on the example graph and print newly predicted triples."""
    # Example graph from ICML'11 paper
    m = 2  # number of edge types
    X = [sparse.lil_matrix((5, 5)) for _ in range(m)]
    X[0][0, 1] = 1  # vicePresidentOf
    X[0][2, 3] = 1  # vicePresidentOf
    X[1][0, 4] = 1  # party
    X[1][1, 4] = 1  # party
    X[1][2, 4] = 1  # party
    nodenames = {0: 'Lyndon', 1: 'John', 2: 'AI', 3: 'Bill', 4: 'Party X'}
    edgetypes = {0: 'vicePresidentOf', 1: 'party'}

    # Parameters
    r = 3  # number of latent components
    lamb_A = 0.00001  # regularization
    lamb_R = 0.00001  # regularization

    rescal = RESCAL(r, lamb_A, lamb_R)
    rescal.fit(X)

    # TEST (link prediction): reconstruct every slice from the factors.
    X_bar = [rescal.A.dot(rescal.R[k]).dot(rescal.A.T) for k in range(m)]

    estimated_facts = []
    for k in range(m):
        # Candidate triples are the entries with a high reconstructed score.
        rows, cols = np.where(X_bar[k] > LINK_THRESHOLD)
        for i, j in zip(rows, cols):
            if X[k][i, j] == 0:  # keep only triples absent from the input
                estimated_facts.append(
                    (nodenames[i], edgetypes[k], nodenames[j])
                )

    for fact in estimated_facts:
        print(fact)


if __name__ == '__main__':
    main()