Skip to content

Instantly share code, notes, and snippets.

@yamaguchiyuto
Created February 14, 2016 01:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save yamaguchiyuto/3303398a61aa00f5707c to your computer and use it in GitHub Desktop.
Save yamaguchiyuto/3303398a61aa00f5707c to your computer and use it in GitHub Desktop.
import numpy as np
from scipy import linalg,sparse,random
class RESCAL:
    """RESCAL factorization of a list of square relation matrices.

    Every slice ``X[k]`` is approximated by ``A.dot(R[k]).dot(A.T)``, where
    ``A`` has one row per entity and ``r`` columns, and ``R[k]`` is an
    ``r x r`` matrix for relation ``k``.  ``fit`` alternates between a
    regularized update of ``A`` and a regularized update of every ``R[k]``.
    """

    def __init__(self, r, lamb_A, lamb_R):
        """Store the hyper-parameters.

        r      -- number of latent components (columns of ``A``).
        lamb_A -- regularization weight for ``A``; the update adds
                  ``m * lamb_A * I`` to the ``r x r`` system matrix.
        lamb_R -- regularization weight for ``R``; the update adds
                  ``lamb_R * I`` to the ``r**2 x r**2`` system matrix.
        """
        self.r = r
        self.lamb_A = lamb_A
        self.lamb_R = lamb_R

    def fit(self, X, niter=30):
        """Fit ``self.A`` and ``self.R`` to the slices in ``X``.

        X     -- non-empty sequence of square matrices of identical shape
                 (dense arrays or scipy sparse matrices), one per relation.
        niter -- number of alternating sweeps; at least one sweep is always
                 performed, even when ``niter`` is smaller than 1.

        The factors are initialised from ``np.random.randn``, so seed
        ``np.random`` beforehand for reproducible results.

        Returns ``self``.  Raises ``ValueError`` when ``X`` is empty.
        """
        m = len(X)
        if m == 0:
            raise ValueError("X must contain at least one relation matrix")
        n, _ = X[0].shape
        r = self.r

        self.A = np.random.randn(n, r)
        self.R = [np.random.randn(r, r) for _ in range(m)]

        eye_r = np.identity(r)
        eye_r2 = np.identity(r ** 2)

        for _ in range(max(1, niter)):
            # ---- update A ----
            AA = self.A.T.dot(self.A)
            F = sum(
                X[k].dot(self.A).dot(self.R[k].T)
                + X[k].T.dot(self.A).dot(self.R[k])
                for k in range(m)
            )
            S = sum(
                self.R[k].dot(AA).dot(self.R[k].T)
                + self.R[k].T.dot(AA).dot(self.R[k])
                for k in range(m)
            )
            S = S + m * self.lamb_A * eye_r
            # A = F S^{-1} is the solution of S^T A^T = F^T; solving the
            # linear system avoids forming the explicit inverse.
            self.A = linalg.solve(S.T, np.asarray(F).T).T

            # ---- update R ----
            # With A = Q A_bar, each slice satisfies
            # Q^T X_k Q ~ A_bar R_k A_bar^T, i.e. (row-major vec)
            # vec(Q^T X_k Q) ~ kron(A_bar, A_bar) vec(R_k).
            Q, A_bar = linalg.qr(self.A, mode='economic')
            Z = np.kron(A_bar, A_bar)
            # The system matrix is the same for every relation: build it
            # once and solve for all right-hand sides together.
            ZtZ = Z.T.dot(Z) + self.lamb_R * eye_r2
            rhs = np.column_stack([
                Z.T.dot(np.asarray(Q.T.dot(X[k].dot(Q))).reshape(-1))
                for k in range(m)
            ])
            solution = linalg.solve(ZtZ, rhs)
            self.R = [solution[:, k].reshape(r, r) for k in range(m)]

        return self
if __name__ == '__main__':
    # Demo: factorize a tiny two-relation graph and list the triples the
    # model scores highly although they are absent from the input.

    # Example graph from ICML'11 paper
    m = 2  # number of edge types
    X = [sparse.lil_matrix((5, 5)) for _ in range(m)]
    X[0][0, 1] = 1  # vicePresidentOf
    X[0][2, 3] = 1  # vicePresidentOf
    X[1][0, 4] = 1  # party
    X[1][1, 4] = 1  # party
    X[1][2, 4] = 1  # party
    nodenames = {0: 'Lyndon', 1: 'John', 2: 'AI', 3: 'Bill', 4: 'Party X'}
    edgetypes = {0: 'vicePresidentOf', 1: 'party'}

    # Parameters
    r = 3  # number of latent components
    lamb_A = 0.00001  # regularization
    lamb_R = 0.00001  # regularization

    rescal = RESCAL(r, lamb_A, lamb_R)
    rescal.fit(X)

    # TEST (link prediction)
    # Reconstruct every slice from the learned factors.
    X_bar = [rescal.A.dot(rescal.R[k]).dot(rescal.A.T) for k in range(m)]

    estimated_facts = []
    for k in range(m):
        # Triples with a high reconstructed score ...
        rows, cols = np.where(X_bar[k] > 0.1)
        for i, j in zip(rows, cols):
            # ... that do not exist in the observed data.
            if X[k][i, j] == 0:
                estimated_facts.append(
                    (nodenames[i], edgetypes[k], nodenames[j]))

    for f in estimated_facts:
        # Function-call form: valid on Python 3 and prints the same tuple
        # on Python 2.
        print(f)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment