Created February 23, 2016 23:51
Save yamaguchiyuto/aa64eedfa72fef2f0dab to your computer and use it in GitHub Desktop.
Reproducing TransE experiments [NIPS'13]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Predict the top-n tail entities for a (head, relation) query with a
# trained TransE model.
#
# Usage: python predict.py <model_file> <head_entity> <relation> <n>
import sys
import pickle
import numpy as np
from transe import TRANSE

modelfilepath = sys.argv[1]  # pickled (transe, E, L) triple written by the training script
h = sys.argv[2]              # head entity name
r = sys.argv[3]              # relation name
n = int(sys.argv[4])         # number of candidates to report

# NOTE(review): pickle.load on an untrusted file can execute arbitrary code --
# only load model files you produced yourself.
with open(modelfilepath, 'rb') as f:  # binary mode: pickle streams are bytes, not text
    transe, E, L = pickle.load(f)

# Dissimilarity of every entity embedding to e_h + e_r; lower is better.
dsim = transe.predict(E[h], L[r])

top_n = np.argpartition(dsim, n)[:n]         # unsorted indices of the n smallest distances
sorted_top_n = top_n[dsim[top_n].argsort()]  # sort those n by increasing distance
values = dsim[sorted_top_n]

# Invert the entity->index map so results can be printed by name.
revE = {value: key for key, value in E.items()}
for i in range(sorted_top_n.shape[0]):
    print("%s\t%s" % (revE[sorted_top_n[i]], values[i]))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import pickle | |
import numpy as np | |
from transe import TRANSE | |
def load(filepath):
    """Read a tab-separated triple file and return the set of (head, label, tail) tuples.

    Each line of the file must contain exactly three tab-separated fields.
    """
    S = set()
    # 'with' closes the handle deterministically (the original leaked it).
    with open(filepath) as f:
        for line in f:
            h, l, t = line.strip().split('\t')
            S.add((h, l, t))
    return S
# Evaluate a trained TransE model: mean rank and Hits@10 over a test triple file.
# Usage: python test.py <model_file> <test_file>
modelfilepath = sys.argv[1]
testfilepath = sys.argv[2]

# NOTE(review): pickle.load on an untrusted file can execute arbitrary code.
with open(modelfilepath, 'rb') as f:  # binary mode: pickle streams are bytes, not text
    transe, E, L = pickle.load(f)

S = load(testfilepath)

meanrank = 0.
hit10 = 0.
for h, l, t in S:
    dsim = transe.predict(E[h], L[l])  # distance of every entity to e_h + e_l
    # 0-based rank of the true tail: number of entities strictly closer than t.
    # NOTE(review): the paper's mean rank is 1-based; add 1 when comparing
    # against published numbers.
    meanrank += (dsim < dsim[E[t]]).sum()
    if E[t] in np.argpartition(dsim, 10)[:10]:  # true tail among the 10 nearest
        hit10 += 1

meanrank /= len(S)
hit10 /= len(S)
print("MeanRank: %s" % meanrank)
print("Hit10: %s" % hit10)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import pickle | |
import numpy as np | |
from transe import TRANSE | |
def load(filepath, E=None, L=None):
    """Read tab-separated (head, label, tail) triples and index them.

    Entities (heads and tails) receive integer ids in E, labels in L.
    Existing E/L mappings may be passed in to extend a vocabulary; by
    default fresh mappings are built.  (The original accepted these
    parameters but silently discarded them -- a dead-parameter bug.)

    Returns (X, E, L) where X is an (n_triples, 3) int array of
    (head_id, label_id, tail_id) rows.
    """
    X = []
    if E is None:
        E = {}
    if L is None:
        L = {}
    # Next free ids; assumes any pre-existing ids are contiguous 0..len-1.
    i = len(E)
    j = len(L)
    with open(filepath) as f:  # close the handle deterministically
        for line in f:
            h, l, t = line.strip().split('\t')
            if h not in E:
                E[h] = i
                i += 1
            if t not in E:
                E[t] = i
                i += 1
            if l not in L:
                L[l] = j
                j += 1
            X.append((E[h], L[l], E[t]))
    return (np.array(X), E, L)
def load2(filepath, E, L):
    """Read triples and encode them with existing entity/label maps E and L.

    Unlike load(), unknown entities or labels raise KeyError instead of
    being assigned new ids.  Returns an (n_triples, 3) int array.
    """
    X = []
    with open(filepath) as f:  # close the handle deterministically
        for line in f:
            h, l, t = line.strip().split('\t')
            X.append((E[h], L[l], E[t]))
    return np.array(X)
# Train a TransE model and pickle (model, entity_map, label_map) to 'transe.model'.
# Usage: python train.py <training_file> <validation_file>
trainingfilepath = sys.argv[1]
validationfilepath = sys.argv[2]

# Hyper-parameters
r = 1          # margin
k = 50         # number of embedding dimensions
lamb = 0.01    # learning rate
b = 5000       # minibatch size
d = 'l1'       # dissimilarity measure: 'l1' or 'l2'
nepochs = 1000

X, E, L = load(trainingfilepath)
V = load2(validationfilepath, E, L)  # validation triples reuse the training vocabulary

transe = TRANSE(len(E), len(L), r, k, lamb, b, d)
transe.fit(X, nepochs=nepochs, validationset=V)

# Binary mode: pickle streams are bytes ('w' breaks on Python 3 and on Windows).
with open('transe.model', 'wb') as f:
    pickle.dump((transe, E, L), f)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from sklearn.preprocessing import normalize | |
class TRANSE:
    """TransE knowledge-graph embedding (Bordes et al., NIPS 2013).

    Entities and relations are embedded in the same k-dimensional space;
    a triple (h, l, t) is scored by the dissimilarity d(e_h + e_l, e_t).
    Training minimizes a margin-based ranking loss with minibatch SGD,
    corrupting the head or tail of each triple uniformly at random.
    """

    def __init__(self, n, m, r, k, lamb, b, d):
        self.n = n        # number of entities
        self.m = m        # number of relationships
        self.r = r        # margin
        self.k = k        # number of dimensions
        self.lamb = lamb  # learning rate
        self.b = b        # size of minibatch
        self.d = d        # distance measure ('l1' or 'l2')

    def meanrank(self, V, nsamples=1000):
        """Mean (0-based) rank of the true tail over at most nsamples triples of V."""
        ret = 0.
        k = 0
        for h, r, t in V:
            dsim = self.predict(h, r)
            ret += (dsim < dsim[t]).sum()  # entities strictly closer than the true tail
            k += 1
            if k >= nsamples:
                break
        return ret / k

    def get_batches(self, S, b):
        """Yield minibatches of about b rows each; the last batch absorbs the remainder."""
        # '//' keeps this integer on both Python 2 and 3 -- plain '/' is
        # float division on Python 3 and breaks range()/slicing.
        nfull = S.shape[0] // b - 1
        for i in range(nfull):
            yield S[i * b:(i + 1) * b]
        yield S[nfull * b:]

    def negative_sampling(self, batch):
        """Append a corruption plan to each triple: a coin flip rnd and a random entity e.

        rnd == 1 corrupts the tail, rnd == 0 corrupts the head.
        NOTE(review): the corrupting entity is drawn uniformly and may
        coincide with the original (or form another true triple); the
        paper filters such cases, this implementation does not.
        """
        r = np.random.randint(2, size=batch.shape[0])
        e = np.random.randint(self.n, size=batch.shape[0])
        return np.vstack([batch.T, r, e]).T

    def evaluate_grad(self, batch):
        """Averaged margin-loss gradients for entity ('e') and label ('l') embeddings."""
        grad = {'e': np.zeros((self.n, self.k)), 'l': np.zeros((self.m, self.k))}
        count = {'e': np.zeros(self.n), 'l': np.zeros(self.m)}
        T = self.negative_sampling(batch)  # uniform corruption
        for h, r, t, rnd, e in T:
            if rnd == 1:
                h2, t2 = h, e  # corrupt the tail
            else:
                h2, t2 = e, t  # corrupt the head
            # Only margin-violating pairs contribute (hinge loss).
            if self.f(h, r, t) + self.r - self.f(h2, r, t2) > 0:
                g1 = self.grad_f(h, r, t)
                g2 = self.grad_f(h2, r, t2)
                grad['e'][h] += g1
                grad['e'][t] -= g1
                grad['e'][h2] -= g2
                grad['e'][t2] += g2
                grad['l'][r] += g1 - g2
                count['e'][h] += 1
                count['e'][t] += 1
                count['e'][h2] += 1
                count['e'][t2] += 1
                count['l'][r] += 1
        count['e'][count['e'] == 0] = 1  # avoid division by zero
        count['l'][count['l'] == 0] = 1  # avoid division by zero
        grad['e'] = (grad['e'].T / count['e']).T  # per-embedding average (1/n)
        grad['l'] = (grad['l'].T / count['l']).T  # per-embedding average (1/n)
        return grad

    def grad_f(self, h, r, t):
        """Gradient of f(h, r, t) with respect to (e_h + e_l)."""
        x = self.params['e'][h] + self.params['l'][r] - self.params['e'][t]
        if self.d == 'l1':
            # Elementwise sign: np.sign(0) == 0, whereas the original
            # x / |x| produced 0/0 -> nan for exactly-zero components.
            return np.sign(x)
        elif self.d == 'l2':
            return x / np.linalg.norm(x)

    def f(self, h, r, t):
        """Dissimilarity d(e_h + e_l, e_t) under the configured norm."""
        x = self.params['e'][h] + self.params['l'][r] - self.params['e'][t]
        if self.d == 'l1':
            return np.abs(x).sum()
        elif self.d == 'l2':
            return np.linalg.norm(x)

    def init_params(self):
        """Initialize embeddings uniformly in [-6/sqrt(k), 6/sqrt(k)].

        Label embeddings are L2-normalized once at initialization.
        NOTE(review): the original used 6/k; the TransE paper specifies
        6/sqrt(k) (Glorot-style initialization).
        """
        w = 6. / np.sqrt(self.k)
        e = np.random.uniform(-w, w, (self.n, self.k))
        l = normalize(np.random.uniform(-w, w, (self.m, self.k)))
        return {'e': e, 'l': l}

    def fit(self, S, nepochs=1000, validationset=None):
        """Train with minibatch SGD over the triple array S.

        Entity embeddings are renormalized to unit norm before each batch.
        If validationset is given, prints the mean rank after every epoch.
        Note: S is shuffled in place, reordering the caller's array.
        """
        self.params = self.init_params()
        for epoch in range(nepochs):
            np.random.shuffle(S)
            for batch in self.get_batches(S, self.b):
                self.params['e'] = normalize(self.params['e'])  # enforce ||e|| = 1
                grad = self.evaluate_grad(batch)
                self.params['e'] -= self.lamb * grad['e']
                self.params['l'] -= self.lamb * grad['l']
            # 'is not None': the original '== None' is elementwise on a
            # numpy array and truth-testing it raises ValueError.
            if validationset is not None:
                print("%s %s" % (epoch, self.meanrank(validationset)))
        return self

    def predict(self, h, r):
        """Return the dissimilarity of every entity embedding to e_h + e_l."""
        q = self.params['e'][h] + self.params['l'][r]
        if self.d == 'l1':
            return np.abs(q - self.params['e']).sum(axis=1)
        elif self.d == 'l2':
            return np.linalg.norm(q - self.params['e'], axis=1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment