# Create a gist now
#
# Instantly share code, notes, and snippets.
#
# What would you like to do?
import numpy as np
class Tucker:
    """Tucker decomposition of a sparse 3-way tensor by alternating SVD updates.

    ``X`` (passed to :meth:`fit` and the helper methods) is a mapping from
    ``(i, j, k)`` index tuples to numeric values. Only the listed entries
    contribute; unlisted entries act as zeros. Each round replaces one factor
    matrix in turn. The replacement is the leading left singular vectors of
    ``X`` unfolded along that mode, after the other two modes are projected
    onto their current factors (higher-order orthogonal iteration). The core
    tensor is ``X`` projected onto all three factors.

    Attributes set by :meth:`fit`:
        data_shape: ``(I, J, K)``, one past the largest index seen per mode.
        A: ``{0: (I, R), 1: (J, S), 2: (K, T)}`` factor matrices.
        G: core tensor of shape ``(R, S, T)``.
        _losses: value of :meth:`_loss` after each round.
    """

    def __init__(self, R, S, T, max_iter):
        """Configure the model.

        Args:
            R, S, T: core-tensor (latent) size along modes 0, 1 and 2.
            max_iter: number of update rounds performed by :meth:`fit`.
        """
        self.latent_size = (R, S, T)
        self.max_iter = max_iter

    @staticmethod
    def _as_arrays(X):
        """Return ``(idx, vals)`` for the entries of ``X``.

        ``idx`` is an ``(nnz, 3)`` integer array of ``(i, j, k)`` indices.
        ``vals`` holds the matching ``(nnz,)`` float values, in iteration order.
        """
        keys = list(X)
        idx = np.asarray(keys, dtype=np.intp).reshape(len(keys), 3)
        vals = np.asarray([X[key] for key in keys], dtype=float)
        return idx, vals

    def _calc_data_shape(self, X):
        """Return ``(I, J, K)``: one past the largest index in each mode.

        Raises:
            ValueError: if ``X`` has no entries, so no shape can be inferred.
        """
        idx, _ = self._as_arrays(X)
        if len(idx) == 0:
            raise ValueError("cannot infer a tensor shape from an empty X")
        return tuple(int(top) + 1 for top in idx.max(axis=0))

    def _calc_core_tensor(self, X):
        """Return the core tensor: ``X`` projected onto all three factors.

        ``G[a, b, c]`` is the sum over entries of
        ``X[i, j, k] * A[0][i, a] * A[1][j, b] * A[2][k, c]``.
        It is evaluated with a single ``einsum`` rather than a Python loop
        over every core cell and every entry.
        """
        idx, vals = self._as_arrays(X)
        return np.einsum(
            "n,na,nb,nc->abc",
            vals,
            self.A[0][idx[:, 0]],
            self.A[1][idx[:, 1]],
            self.A[2][idx[:, 2]],
        )

    def _init_latent_vectors(self):
        """Return random initial factors ``{m: (data_shape[m], latent_size[m])}``.

        Entries are drawn with ``np.random.normal(loc=0, scale=0.1)`` from
        NumPy's global random state.
        """
        return {
            m: np.random.normal(
                loc=0, scale=0.1, size=(self.data_shape[m], self.latent_size[m])
            )
            for m in range(3)
        }

    @staticmethod
    def _core_loss(G):
        """Return the negative mean of the squared core entries."""
        return -(G ** 2).sum() / G.size

    def _loss(self, X):
        """Return :meth:`_core_loss` of the core tensor computed from ``X``."""
        return self._core_loss(self._calc_core_tensor(X))

    def _update(self, X, modes):
        """Return an updated factor matrix for mode ``modes[0]``.

        ``modes = (n, p, q)`` is a permutation of ``(0, 1, 2)``. ``X`` is
        projected onto the current factors of modes ``p`` and ``q``. It is
        then unfolded along mode ``n`` into a ``(data_shape[n], L_p * L_q)``
        matrix, where column ``a * L_q + b`` holds latent pair ``(a, b)``.
        The leading ``latent_size[n]`` left singular vectors of that matrix
        are returned.
        """
        n, p, q = modes
        idx, vals = self._as_arrays(X)
        Ap = self.A[p][idx[:, p]]  # (nnz, L_p): factor row of each entry
        Aq = self.A[q][idx[:, q]]  # (nnz, L_q)
        cols = Ap.shape[1] * Aq.shape[1]
        # Per-entry outer product; C-order flattening maps (a, b) -> a*L_q + b.
        contrib = vals[:, None, None] * Ap[:, :, None] * Aq[:, None, :]
        contrib = contrib.reshape(len(vals), cols)
        X_bar = np.zeros((self.data_shape[n], cols))
        # add.at (unlike fancy-index +=) accumulates repeated row indices.
        np.add.at(X_bar, idx[:, n], contrib)
        k = self.latent_size[n]
        # A reduced SVD avoids materialising a rows x rows U, which is huge
        # for large modes. The full SVD is only needed when more than
        # min(rows, cols) orthonormal columns are requested.
        U, _, _ = np.linalg.svd(X_bar, full_matrices=k > min(X_bar.shape))
        return U[:, :k]

    def fit(self, X):
        """Fit the factors and core tensor to the entries of ``X``; return ``self``.

        Performs ``max_iter`` rounds. Each round updates ``A[0]``, ``A[1]``
        and ``A[2]`` in turn, then appends that round's loss to ``_losses``.

        Raises:
            ValueError: if ``X`` is empty, or if a latent size exceeds the data
                dimension of its mode (a factor with ``d`` rows cannot have
                more than ``d`` orthonormal columns).
        """
        self.data_shape = self._calc_data_shape(X)
        for mode, (dim, rank) in enumerate(zip(self.data_shape, self.latent_size)):
            if rank > dim:
                raise ValueError(
                    "latent size {} exceeds data dimension {} for mode {}".format(
                        rank, dim, mode
                    )
                )
        self.A = self._init_latent_vectors()
        self._losses = []
        G = None
        remaining = self.max_iter
        while remaining > 0:
            self.A[0] = self._update(X, (0, 1, 2))
            self.A[1] = self._update(X, (1, 2, 0))
            self.A[2] = self._update(X, (2, 0, 1))
            # Reuse this round's core tensor for both the loss and the result.
            G = self._calc_core_tensor(X)
            self._losses.append(self._core_loss(G))
            remaining -= 1
        self.G = self._calc_core_tensor(X) if G is None else G
        return self

    def predict(self, indices):
        """Return the reconstructed value at ``indices = (i, j, k)``.

        Contracts the core tensor with row ``k`` of ``A[2]``, then row ``j``
        of ``A[1]``, then row ``i`` of ``A[0]``.
        """
        i, j, k = indices
        GT = np.tensordot(self.G, self.A[2][k], axes=(2, 0))  # (R, S)
        GS = np.tensordot(GT, self.A[1][j], axes=(1, 0))  # (R,)
        return np.tensordot(GS, self.A[0][i], axes=(0, 0))
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment