Skip to content

Instantly share code, notes, and snippets.

@KeremTurgutlu
Created June 14, 2018 19:15
Show Gist options
  • Save KeremTurgutlu/d7303fcf5ff72156c77956149764f9ea to your computer and use it in GitHub Desktop.
Save KeremTurgutlu/d7303fcf5ff72156c77956149764f9ea to your computer and use it in GitHub Desktop.
def construct(A, B, C):
    """
    Reconstruct a 3-D tensor from its CP (CANDECOMP/PARAFAC) factor matrices.

    Computes X_tilde[i, j, k] = sum_r A[i, r] * B[j, r] * C[k, r], i.e. the
    sum over r of the rank-1 outer products A[:, r] x B[:, r] x C[:, r].

    A : (I, R) factor matrix for the first mode
    B : (J, R) factor matrix for the second mode
    C : (K, R) factor matrix for the third mode

    Returns a tensor of shape (I, J, K). All three factors must have the
    same number of columns R; a mismatch raises an error instead of being
    silently truncated to A's rank.
    """
    # One einsum contraction replaces the per-rank Python loop of outer
    # products built with the deprecated torch.ger. It also returns a
    # correctly shaped zero tensor when R == 0 instead of the integer 0.
    return torch.einsum("ir,jr,kr->ijk", A, B, C)
def CP_decomposition(factors=None, max_iter=10000, lr=0.1, target=None):
    """
    Fit CP factor matrices to a 3-D target tensor with the Adam optimiser.

    Minimises the mean squared reconstruction error
    mean((target - construct(*factors)) ** 2). This is the squared Frobenius
    norm ||target - X_tilde||_F^2 divided by the number of entries. The
    factor tensors are updated in place.

    factors  : sequence of the three factor tensors (A, B, C); they are
               optimised in place, so they should be leaf tensors with
               requires_grad=True. Defaults to the module-level A, B, C,
               looked up at call time.
    max_iter : number of optimisation steps.
    lr       : Adam learning rate.
    target   : tensor being decomposed. Defaults to the module-level X,
               looked up at call time.

    Returns the list of per-iteration loss values as Python floats.
    """
    if factors is None:
        # Resolved at call time. The old default `[A, B, C]` was evaluated
        # when the function was defined, which raised NameError unless
        # A, B, C already existed and froze whatever tensors were bound then.
        factors = [A, B, C]
    if target is None:
        target = X
    # Materialise once: Adam consumes the iterable, and construct(*factors)
    # must still see all three tensors on every step.
    factors = list(factors)
    # The target is a constant. Detaching keeps backward() from walking into
    # (or accumulating gradients on) whatever graph produced it.
    target = target.detach()

    opt = Adam(factors, lr=lr)
    losses = []
    for _ in range(max_iter):
        opt.zero_grad()
        X_tilde = construct(*factors)
        loss = torch.mean((target - X_tilde) ** 2)
        losses.append(loss.item())
        # A fresh graph is built every iteration, so retain_graph=True is
        # not needed; it only kept the previous step's buffers alive.
        loss.backward()
        opt.step()
    return losses
# Example run: a rank-1 CP fit. The random factors are sized for a 3 x 4 x 5
# tensor and optimised in place by CP_decomposition.
# NOTE(review): assumes the module-level target X has shape (3, 4, 5) -- confirm.
r = 1
A, B, C = (torch.randn((rows, r), requires_grad=True) for rows in (3, 4, 5))
rank1_loss = CP_decomposition([A, B, C])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment