@ahwillia
Last active March 28, 2021 00:06
Simple Nonnegative Matrix Factorization in PyTorch
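
This gist factorizes a data matrix X (m x n) as X ≈ W^T H, where W (rank x m) and H (rank x n) are constrained to be elementwise nonnegative. It alternately minimizes the squared Frobenius error ||X - W^T H||^2 over H with W held fixed, then over W with H held fixed, taking a few projected gradient steps per sub-problem. The step size for each update comes from a backtracking line search, implemented in the second file, torch_nonneg_linesearch.py.
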
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch_nonneg_linesearch import nonneg_projected_gradient_step
# Data dimensions
m, n = 100, 101
rank = 3
# Data matrix, detached from the graph.
X = torch.rand(m, rank) @ torch.rand(rank, n)
# Initialize factor matrices.
W = torch.rand(rank, m, requires_grad=True)
H = torch.rand(rank, n, requires_grad=True)
# Set up optimization.
losses = []
num_iters = 100
inner_iters = 2
learning_rate_multiplier = 2.0  # Multiplies the baseline step size of 1 / (Lipschitz bound); set near 1.0.
# Define loss for H parameter.
def loss_H():
    return (
        -2 * torch.sum(WX * H) +
        torch.sum(WWt * (H @ H.T))
    )

# Define loss for W parameter.
def loss_W():
    return (
        -2 * torch.sum(W.T * XHt) +
        torch.sum(HHt * (W @ W.T))
    )
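
# Note: both losses are the expansion of the squared error ||X - W.T @ H||**2
# with the constant ||X||**2 term dropped:
#   ||X - W.T @ H||**2 = ||X||**2 - 2 * trace(X.T @ W.T @ H) + trace((W @ W.T) @ (H @ H.T)),
# so minimizing loss_H over H (with W fixed), or loss_W over W (with H fixed),
# minimizes the full reconstruction error. The cached products WX, WWt, XHt,
# and HHt are globals that the main loop refreshes before each sub-problem.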
# Holds copies of the params, detached from graph.
W_ = torch.empty_like(W)
H_ = torch.empty_like(H)
# Main loop.
for itr in range(num_iters):

    # === UPDATE H === #

    # Cached matrix products. Note that these don't require
    # gradients, so we detach W from the graph.
    W_.copy_(W.data)
    WWt = W_ @ W_.T
    WX = W_ @ X

    # WWt.sum(dim=0).max() upper-bounds the largest eigenvalue of WWt, which
    # sets the Lipschitz constant of the gradient of loss_H (see note below).
    learning_rate = learning_rate_multiplier / WWt.sum(dim=0).max()

    for j in range(inner_iters + 1):
        losses.append(nonneg_projected_gradient_step(
            loss_H, H, H_, learning_rate
        ))
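
    # (Why this bound works: the largest eigenvalue of WWt is at most its
    # largest column sum, and the gradient of loss_H, 2 * (WWt @ H - WX), has
    # Lipschitz constant equal to twice that eigenvalue. A step size near
    # 1 / WWt.sum(dim=0).max() is therefore a conservative starting point;
    # learning_rate_multiplier can push past it because the line search
    # backtracks whenever the loss fails to drop. The W update below uses the
    # same bound built from HHt.)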

    # === UPDATE W === #

    # Cached matrix products. Note that these
    # don't require gradient tracking.
    H_.copy_(H.data)
    HHt = H_ @ H_.T
    XHt = X @ H_.T

    # Analogous step-size bound for the W update (built from HHt, not WWt).
    learning_rate = learning_rate_multiplier / HHt.sum(dim=0).max()

    for j in range(inner_iters + 1):
        losses.append(nonneg_projected_gradient_step(
            loss_W, W, W_, learning_rate
        ))

# Print final relative reconstruction error.
rel_error = (torch.norm(X - W.T @ H) / torch.norm(X)).item()
print(f"Converged to a relative reconstruction error of {100 * rel_error:0.2f}%")
# Plot result.
fig, ax = plt.subplots(1, 1)
ax.plot(losses)
ax.set_ylabel("NMF loss")
plt.show()
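
An optional sanity check, not part of the original gist: after the loop finishes, the low-rank reconstruction can be compared visually against the data matrix (this continues the script above and reuses X, W, H, and plt).

with torch.no_grad():
    X_hat = W.T @ H  # Nonnegative rank-3 reconstruction of X.
fig, axes = plt.subplots(1, 2, sharex=True, sharey=True)
axes[0].imshow(X, aspect="auto")
axes[0].set_title("X")
axes[1].imshow(X_hat, aspect="auto")
axes[1].set_title("W.T @ H")
plt.show()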

torch_nonneg_linesearch.py (the helper module imported above):

import torch

TOL = 1e-6

def nonneg_projected_gradient_step(
    evaluate_loss, param, init_param, learning_rate, maxsteps=10, shrink_rate=0.5
):
    # Clear any stale gradient so backward() doesn't accumulate into it
    # (param.grad is reused below to hold the scaled descent vector).
    param.grad = None

    # Evaluate initial loss, compute gradient.
    init_loss = evaluate_loss()
    init_loss.backward()
    init_loss = init_loss.item()

    # Take initial step.
    with torch.no_grad():
        init_param.copy_(param.data)
        param.grad.mul_(learning_rate)  # param.grad now holds the scaled descent vector.
        param.sub_(param.grad)
        param.relu_()  # Project onto the nonnegative orthant.

    for step in range(maxsteps):
        # Evaluate new loss.
        with torch.no_grad():
            loss = evaluate_loss()

        # Return if we've descended (decrease exceeds a relative tolerance).
        if (init_loss - loss.item()) > TOL * abs(init_loss):
            return loss.item()

        # Otherwise, backtrack.
        with torch.no_grad():
            # Restore initial param.
            param.copy_(init_param)
            # Shrink the descent vector held in param.grad.
            param.grad.mul_(shrink_rate)
            # Take a projected gradient step.
            param.sub_(param.grad)
            param.relu_()

    raise ValueError("Line search failed!")
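
For reference, a minimal standalone usage sketch of this helper (not part of the gist; the problem and variable names below are made up for illustration): a single projected gradient step on a small nonnegative least-squares problem.

import torch

from torch_nonneg_linesearch import nonneg_projected_gradient_step

torch.manual_seed(0)

# Tiny nonnegative least-squares problem: min_x ||A @ x - b||**2  s.t.  x >= 0.
A = torch.rand(20, 5)
b = torch.rand(20)
x = torch.rand(5, requires_grad=True)
x_backup = torch.empty_like(x)  # Scratch buffer, same role as W_ / H_ above.

def loss_x():
    return torch.sum((A @ x - b) ** 2)

# Same column-sum step-size bound used in the NMF loop above.
learning_rate = 1.0 / (A.T @ A).sum(dim=0).max()

new_loss = nonneg_projected_gradient_step(loss_x, x, x_backup, learning_rate)
print(new_loss, (x >= 0).all().item())  # Loss after one step; x stays nonnegative.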