Task Driven Dictionary Learning : backprop and analytic gradients comparison
# Louis THIRY, 4.11.2019
# reference paper for Task Driven Dictionary Learning : https://www.di.ens.fr/~fbach/taskdriven_mairal2012.pdf
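# The task-driven dictionary learning objective is a bilevel problem: learn a dictionary D and a
# linear model W minimizing a supervised loss evaluated on the elastic net code alpha*(x, D),
#     min_{D, W}  E_{(x, y)} [ l_s(y, W, alpha*(x, D)) ],
#     where alpha*(x, D) = argmin_alpha 0.5 ||x - D alpha||_2^2 + lambda_1 ||alpha||_1 + 0.5 lambda_2 ||alpha||_2^2.
# This script compares, for a regression (MSE) loss, the gradients w.r.t. D and W obtained by
# backpropagating through unrolled ISTA iterations against the analytic gradients of the paper.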
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
def elastic_net_loss(inputs, dictionary, alpha, lambda_1, lambda_2):
    return (0.5 * ((torch.mm(alpha, dictionary.t()) - inputs)**2).sum(dim=1) +
            lambda_1 * torch.norm(alpha, p=1, dim=1) +
            0.5 * lambda_2 * (alpha**2).sum(dim=1))
def solve_elastic_net_with_ISTA(inputs, dictionary, lambda_1, lambda_2, maxiter=1000, plot_loss_and_support_size=False):
"""
Solve elastic net problem :
min_alpha 0.5 * || x - D alpha ||_2**2 + lambda_1 ||alpha||_1 + 0.5 * lambda_2 ||alpha||_2**2
using the ISTA algorithm.
Parameters:
input: torch tensor, size (batch_size, input_dimension)
dictionary: torch Tensor, size (input_dimension, n_atoms)
dictionary matrix
lambda_1: float
regulariation parameter in front of the l1 norm
lambda_2: float
regulariation parameter in front of the l2 square norm
maxiter: int, default 1000
maxmum number of iterations of the ISTA algorithm
Returns:
alpha: torch Tensor, size (batch_size, dict_size)
the sparse code of the input batch in the dictionary D
n_iter: int
number of iterations of the FISTA algorithm
diff_mean: float
mean diff in l1 norm between the two last iterates
diff_max: float
max diff in l1 norm between the two last iterates
"""
    n_atoms = dictionary.size(1)
    identity = torch.eye(n_atoms, out=dictionary.new(n_atoms, n_atoms))
    DtD = torch.mm(dictionary.t(), dictionary) + lambda_2 * identity
    with torch.no_grad():
        # step size 1/L, where L is the Lipschitz constant of the smooth part,
        # i.e. the largest eigenvalue of D^T D + lambda_2 * Id
        L = torch.symeig(DtD)[0].max().item()
    if plot_loss_and_support_size:
        mean_loss, max_loss, support_size = [], [], []
    alpha = nn.functional.softshrink(1 / L * torch.mm(inputs, dictionary), lambda_1 / L)
    for i_iter in range(1, maxiter):
        # ISTA update: gradient step on the smooth part, then soft-thresholding
        alpha = alpha + 1 / L * (torch.mm(inputs, dictionary) - torch.mm(alpha, DtD))
        alpha = nn.functional.softshrink(alpha, lambda_1 / L)
        if plot_loss_and_support_size:
            support_size.append((alpha != 0).sum(dim=1).float().mean().item())
            loss = elastic_net_loss(inputs, dictionary, alpha, lambda_1, lambda_2)
            mean_loss.append(loss.mean().item())
            max_loss.append(loss.max().item())
    if plot_loss_and_support_size:
        plt.figure()
        plt.yscale('log')
        plt.plot(range(1, maxiter), mean_loss, label='mean loss')
        plt.plot(range(1, maxiter), max_loss, label='max loss')
        plt.legend()
        plt.figure()
        plt.plot(range(1, maxiter), support_size, label='support_size')
        plt.legend()
        plt.show()
    return alpha
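# A minimal usage sketch of the solver above (hypothetical shapes and values, not part of the original script):
#     D = torch.randn(20, 40)
#     D = D / torch.norm(D, dim=0, p=2, keepdim=True)
#     x = torch.randn(8, 20)
#     alpha = solve_elastic_net_with_ISTA(x, D, lambda_1=0.1, lambda_2=1e-2, maxiter=500)
#     # alpha has shape (8, 40); the soft-thresholding sets many of its entries exactly to zero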
def solve_TDDL_regression_autograd(inputs, targets, dictionary, classifier, lambda_1, lambda_2, lr=0.1, maxiter=200, maxiter_EN=100):
    input_dim, n_atoms = dictionary.size()
    dictionary.requires_grad = True
    classifier.requires_grad = True
    loss = nn.MSELoss()
    optimizer = torch.optim.SGD([dictionary, classifier], lr=lr, momentum=0)
    for i_iter in range(maxiter):
        # backpropagate the MSE loss through the unrolled ISTA iterations
        alpha = solve_elastic_net_with_ISTA(inputs, dictionary, lambda_1, lambda_2, maxiter=maxiter_EN)
        y = torch.mm(alpha, classifier)
        output = loss(y, targets)
        optimizer.zero_grad()
        output.backward()
        optimizer.step()
        if (i_iter + 1) % 10 == 0:
            print(" - iter {}, loss : {:.7f}".format(i_iter, output.item()))
    # return the gradients to compare them
    gradients = (classifier.grad.view(-1).clone(), dictionary.grad.view(-1).clone())
    return gradients
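# Analytic gradients (see the reference paper above): with Lambda the active set of alpha* (indices of
# its nonzero entries) and l_s the supervised loss, define beta by
#     beta_{Lambda^c} = 0,
#     beta_{Lambda} = (D_Lambda^T D_Lambda + lambda_2 * Id)^{-1} grad_{alpha_Lambda} l_s(y, W, alpha*),
# then
#     grad_W = E[ grad_W l_s(y, W, alpha*) ],
#     grad_D = E[ -D beta alpha*^T + (x - D alpha*) beta^T ].
# The function below implements these formulas batch-wise for the MSE regression loss.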
def solve_TDDL_regression_analytic(inputs, targets, dictionary, classifier, lambda_1, lambda_2, lr=0.1, maxiter=200, maxiter_EN=100):
    b_size, input_dim, n_atoms = inputs.size(0), dictionary.size(0), dictionary.size(1)
    lambda_2_identity = (lambda_2 * torch.eye(n_atoms, out=dictionary.new(n_atoms, n_atoms))).view(1, n_atoms, n_atoms).expand(b_size, n_atoms, n_atoms)
    with torch.no_grad():
        for i_iter in range(maxiter):
            alpha = solve_elastic_net_with_ISTA(inputs, dictionary, lambda_1, lambda_2, maxiter=maxiter_EN)
            grad_classifier = 2 / alpha.size(0) * torch.mm(alpha.t(), torch.mm(alpha, classifier) - targets)
            # zero out the dictionary columns outside the active set (nonzero coordinates) of each code
            active_set = (alpha != 0).float()
            active_set_dictionary = dictionary.view(1, input_dim, n_atoms).expand(b_size, input_dim, n_atoms) * active_set.view(b_size, 1, n_atoms)
            active_set_DtD = torch.bmm(active_set_dictionary.transpose(1, 2), active_set_dictionary) + lambda_2_identity
            active_set_DtD_inverse = torch.inverse(active_set_DtD)
            grad_alpha = active_set.view(b_size, n_atoms, 1) * 2 * torch.mm(torch.mm(alpha, classifier) - targets, classifier.t()).view(b_size, n_atoms, 1)
            beta = torch.bmm(active_set_DtD_inverse, grad_alpha)
            expanded_dictionary = dictionary.view(1, input_dim, n_atoms).expand(b_size, input_dim, n_atoms)
            grad_dictionary = (
                - torch.bmm(expanded_dictionary, torch.bmm(beta, alpha.view(b_size, 1, n_atoms)))
                + torch.bmm((inputs - torch.mm(alpha, dictionary.t())).view(b_size, input_dim, 1), beta.transpose(1, 2))
            ).mean(dim=0)
            loss = torch.nn.functional.mse_loss(torch.mm(alpha, classifier), targets)
            if (i_iter + 1) % 10 == 0:
                print(" - iter {}, loss : {:.7f}".format(i_iter, loss.item()))
            classifier = classifier - lr * grad_classifier
            dictionary = dictionary - lr * grad_dictionary
            dictionary = dictionary / torch.norm(dictionary, dim=0, p=2, keepdim=True)
    # return the gradients to compare them
    gradients = (grad_classifier.view(-1).clone(), grad_dictionary.view(-1).clone())
    return gradients
if __name__ == '__main__':
    torch.manual_seed(7)
    input_dimension = 20
    n_atoms = 40
    print("Defining random dictionary with {} atoms in dimension {}".format(n_atoms, input_dimension))
    # note: tensors are allocated on the GPU (torch.cuda.FloatTensor), so a CUDA device is required
    D = torch.cuda.FloatTensor(input_dimension, n_atoms).normal_(mean=0, std=1)
    D = D / torch.norm(D, dim=0, p=2, keepdim=True)
    # synthetic signals: each sample is the sum of signal_support_size randomly chosen atoms
    signal_support_size = 3
    n_samples = 4 * n_atoms
    print("n samples {}".format(n_samples))
    samples_supports = torch.randint(0, n_atoms, (n_samples, signal_support_size))
    samples = D.t()[samples_supports].sum(dim=1)
    print("samples shape {}".format(samples.shape))
    lambda_1 = 0.1
    lambda_2 = 1e-2
    print("Elastic Net Parameters: lam_1 {}, lam_2 {}".format(lambda_1, lambda_2))
    classifier = torch.cuda.FloatTensor(n_atoms, 1).fill_(0)
    classifier[:n_atoms // 2, :] = 1
    classifier[n_atoms // 2:, :] = -1
    classifier /= classifier.view(-1).norm(p=2)
    random_classifier = torch.zeros_like(classifier).uniform_(-1, 1)
    random_dictionary = torch.zeros_like(D).uniform_(-1, 1)
    random_dictionary = random_dictionary / torch.norm(random_dictionary, dim=0, p=2, keepdim=True)
    samples_indices = torch.randperm(n_samples)[:n_atoms]
    random_signal_dictionary = samples[samples_indices].t()
    random_signal_dictionary /= torch.norm(random_signal_dictionary, dim=0, p=2, keepdim=True)
    n_iter_EN_list = [50, 100, 1000, 10000]
    n_iter_EN_max = max(n_iter_EN_list)
    alpha_samples = solve_elastic_net_with_ISTA(samples, D, lambda_1, lambda_2, maxiter=n_iter_EN_max)
    y_samples = torch.mm(alpha_samples, classifier)
    print("target values computed with {} iterations Elastic Net".format(n_iter_EN_max))
    print("")

    # computing reference gradients with the maximal number of Elastic Net iterations
    _, true_grad_D_randW_trueD = solve_TDDL_regression_analytic(samples, y_samples, D.clone(), random_classifier.clone(), lambda_1, lambda_2, maxiter_EN=n_iter_EN_max)
    _, true_grad_D_trueW_randD = solve_TDDL_regression_analytic(samples, y_samples, random_dictionary.clone(), classifier.clone(), lambda_1, lambda_2, maxiter_EN=n_iter_EN_max)
    _, true_grad_D_trueW_randsignalD = solve_TDDL_regression_analytic(samples, y_samples, random_signal_dictionary.clone(), classifier.clone(), lambda_1, lambda_2, maxiter_EN=n_iter_EN_max)
    _, true_grad_D_randW_randD = solve_TDDL_regression_analytic(samples, y_samples, random_dictionary.clone(), random_classifier.clone(), lambda_1, lambda_2, maxiter_EN=n_iter_EN_max)

    for n_iter_EN in n_iter_EN_list:
        print("n iterations Elastic Net : {}".format(n_iter_EN))

        # compare gradients; the diffs below are relative differences 2 * ||a - b|| / (||a|| + ||b||)
        print("With W = W_true, D = D_true")
        grad_W_analytic, grad_D_analytic = solve_TDDL_regression_analytic(samples, y_samples, D.clone(), classifier.clone(), lambda_1, lambda_2, maxiter_EN=n_iter_EN)
        grad_W_autograd, grad_D_autograd = solve_TDDL_regression_autograd(samples, y_samples, D.clone(), classifier.clone(), lambda_1, lambda_2, maxiter_EN=n_iter_EN, lr=0.1)
        print(" - analytic gradients norms W {}, D {}".format(grad_W_analytic.norm().item(), grad_D_analytic.norm().item()))
        print(" - autograd gradients norms W {}, D {}".format(grad_W_autograd.norm().item(), grad_D_autograd.norm().item()))

        print("Random W, D = D_true")
        grad_W_analytic, grad_D_analytic = solve_TDDL_regression_analytic(samples, y_samples, D.clone(), random_classifier.clone(), lambda_1, lambda_2, maxiter_EN=n_iter_EN)
        grad_W_autograd, grad_D_autograd = solve_TDDL_regression_autograd(samples, y_samples, D.clone(), random_classifier.clone(), lambda_1, lambda_2, maxiter_EN=n_iter_EN, lr=0.1)
        grad_W_diff = (grad_W_analytic - grad_W_autograd).norm() * 2 / (grad_W_analytic.norm() + grad_W_autograd.norm())
        grad_D_diff = (grad_D_analytic - grad_D_autograd).norm() * 2 / (grad_D_analytic.norm() + grad_D_autograd.norm())
        grad_D_analytic_true_grad_diff, grad_D_autograd_true_grad_diff = (grad_D_analytic - true_grad_D_randW_trueD).norm(), (grad_D_autograd - true_grad_D_randW_trueD).norm()
        print(f" - grad_W diff {grad_W_diff}")
        print(f" - grad_D diff {grad_D_diff}, |grad_D_analytic - true_grad|={grad_D_analytic_true_grad_diff:.3f}, |grad_D_autograd - true_grad|={grad_D_autograd_true_grad_diff:.3f}")

        print("W = W_true, Random D")
        grad_W_analytic, grad_D_analytic = solve_TDDL_regression_analytic(samples, y_samples, random_dictionary.clone(), classifier.clone(), lambda_1, lambda_2, maxiter_EN=n_iter_EN)
        grad_W_autograd, grad_D_autograd = solve_TDDL_regression_autograd(samples, y_samples, random_dictionary.clone(), classifier.clone(), lambda_1, lambda_2, maxiter_EN=n_iter_EN, lr=0.1)
        grad_W_diff = (grad_W_analytic - grad_W_autograd).norm() * 2 / (grad_W_analytic.norm() + grad_W_autograd.norm())
        grad_D_diff = (grad_D_analytic - grad_D_autograd).norm() * 2 / (grad_D_analytic.norm() + grad_D_autograd.norm())
        grad_D_analytic_true_grad_diff, grad_D_autograd_true_grad_diff = (grad_D_analytic - true_grad_D_trueW_randD).norm(), (grad_D_autograd - true_grad_D_trueW_randD).norm()
        print(f" - grad_W diff {grad_W_diff}")
        print(f" - grad_D diff {grad_D_diff}, |grad_D_analytic - true_grad|={grad_D_analytic_true_grad_diff:.3f}, |grad_D_autograd - true_grad|={grad_D_autograd_true_grad_diff:.3f}")

        print("W = W_true, D set of randomly selected samples")
        grad_W_analytic, grad_D_analytic = solve_TDDL_regression_analytic(samples, y_samples, random_signal_dictionary.clone(), classifier.clone(), lambda_1, lambda_2, maxiter_EN=n_iter_EN)
        grad_W_autograd, grad_D_autograd = solve_TDDL_regression_autograd(samples, y_samples, random_signal_dictionary.clone(), classifier.clone(), lambda_1, lambda_2, maxiter_EN=n_iter_EN, lr=0.1)
        grad_W_diff = (grad_W_analytic - grad_W_autograd).norm() * 2 / (grad_W_analytic.norm() + grad_W_autograd.norm())
        grad_D_diff = (grad_D_analytic - grad_D_autograd).norm() * 2 / (grad_D_analytic.norm() + grad_D_autograd.norm())
        grad_D_analytic_true_grad_diff, grad_D_autograd_true_grad_diff = (grad_D_analytic - true_grad_D_trueW_randsignalD).norm(), (grad_D_autograd - true_grad_D_trueW_randsignalD).norm()
        print(f" - grad_W diff {grad_W_diff}")
        print(f" - grad_D diff {grad_D_diff}, |grad_D_analytic - true_grad|={grad_D_analytic_true_grad_diff:.3f}, |grad_D_autograd - true_grad|={grad_D_autograd_true_grad_diff:.3f}")

        print("Random W, random D")
        grad_W_analytic, grad_D_analytic = solve_TDDL_regression_analytic(samples, y_samples, random_dictionary.clone(), random_classifier.clone(), lambda_1, lambda_2, maxiter_EN=n_iter_EN)
        grad_W_autograd, grad_D_autograd = solve_TDDL_regression_autograd(samples, y_samples, random_dictionary.clone(), random_classifier.clone(), lambda_1, lambda_2, maxiter_EN=n_iter_EN, lr=0.1)
        grad_W_diff = (grad_W_analytic - grad_W_autograd).norm() * 2 / (grad_W_analytic.norm() + grad_W_autograd.norm())
        grad_D_diff = (grad_D_analytic - grad_D_autograd).norm() * 2 / (grad_D_analytic.norm() + grad_D_autograd.norm())
        grad_D_analytic_true_grad_diff, grad_D_autograd_true_grad_diff = (grad_D_analytic - true_grad_D_randW_randD).norm(), (grad_D_autograd - true_grad_D_randW_randD).norm()
        print(f" - grad_W diff {grad_W_diff}")
        print(f" - grad_D diff {grad_D_diff}, |grad_D_analytic - true_grad|={grad_D_analytic_true_grad_diff:.3f}, |grad_D_autograd - true_grad|={grad_D_autograd_true_grad_diff:.3f}")

        print('------------')
        print('')
louity commented Nov 4, 2019

Output:

Defining random dictionary with 40 atoms in dimension 20
n samples 160
samples shape torch.Size([160, 20])
Elastic Net Parameters: lam_1 0.1, lam_2 0.01
target values computed with 10000 iterations Elastic Net

n iterations Elastic Net : 50
With W = W_true, D = D_true
 - analytic gradients norms W 0.007564618717879057, D 0.009144509211182594
 - autograd gradients norms W 0.007564618717879057, D 0.004748645704239607
Random W, D = D_true
 - grad_W diff 1.191091669738853e-07
 - grad_D diff 0.44299694895744324, |grad_D_analytic - true_grad|=0.358, |grad_D_autograd - true_grad|=0.301
W = W_true, Random D
 - grad_W diff 1.664014774860334e-07
 - grad_D diff 1.328352689743042, |grad_D_analytic - true_grad|=0.396, |grad_D_autograd - true_grad|=0.153
W = W_true, D set of randomly selected samples
 - grad_W diff 1.1243663777804613e-07
 - grad_D diff 1.4515023231506348, |grad_D_analytic - true_grad|=0.520, |grad_D_autograd - true_grad|=0.115
Random W,  random D
 - grad_W diff 1.1629506246890742e-07
 - grad_D diff 1.2500759363174438, |grad_D_analytic - true_grad|=5.378, |grad_D_autograd - true_grad|=1.690
------------

n iterations Elastic Net : 100
With W = W_true, D = D_true
 - analytic gradients norms W 0.0021897191181778908, D 0.002045433735474944
 - autograd gradients norms W 0.0021897191181778908, D 0.001650303485803306
Random W, D = D_true
 - grad_W diff 9.260929800802842e-08
 - grad_D diff 0.18540027737617493, |grad_D_analytic - true_grad|=0.139, |grad_D_autograd - true_grad|=0.151
W = W_true, Random D
 - grad_W diff 1.532720972363677e-07
 - grad_D diff 0.9258350133895874, |grad_D_analytic - true_grad|=0.203, |grad_D_autograd - true_grad|=0.128
W = W_true, D set of randomly selected samples
 - grad_W diff 9.593397010121407e-08
 - grad_D diff 1.0290964841842651, |grad_D_analytic - true_grad|=0.261, |grad_D_autograd - true_grad|=0.094
Random W,  random D
 - grad_W diff 1.1782324804698874e-07
 - grad_D diff 0.8438147306442261, |grad_D_analytic - true_grad|=3.089, |grad_D_autograd - true_grad|=1.553
------------

n iterations Elastic Net : 1000
With W = W_true, D = D_true
 - analytic gradients norms W 2.0597740331140812e-06, D 6.570136065420229e-06
 - autograd gradients norms W 2.0597740331140812e-06, D 6.523570391436806e-06
Random W, D = D_true
 - grad_W diff 5.962872506870553e-08
 - grad_D diff 0.0010793383698910475, |grad_D_analytic - true_grad|=0.000, |grad_D_autograd - true_grad|=0.001
W = W_true, Random D
 - grad_W diff 1.3392539699452755e-07
 - grad_D diff 0.07532138377428055, |grad_D_analytic - true_grad|=0.065, |grad_D_autograd - true_grad|=0.054
W = W_true, D set of randomly selected samples
 - grad_W diff 9.572848824745961e-08
 - grad_D diff 0.10772456228733063, |grad_D_analytic - true_grad|=0.021, |grad_D_autograd - true_grad|=0.024
Random W,  random D
 - grad_W diff 1.3093085726723075e-07
 - grad_D diff 0.0581936240196228, |grad_D_analytic - true_grad|=0.343, |grad_D_autograd - true_grad|=0.312
------------

n iterations Elastic Net : 10000
With W = W_true, D = D_true
 - analytic gradients norms W 0.0, D 0.0
 - autograd gradients norms W 0.0, D 0.0
Random W, D = D_true
 - grad_W diff 7.472370811001383e-08
 - grad_D diff 1.2733049516100436e-05, |grad_D_analytic - true_grad|=0.000, |grad_D_autograd - true_grad|=0.000
W = W_true, Random D
 - grad_W diff 1.4900091116487602e-07
 - grad_D diff 5.4614101827610284e-05, |grad_D_analytic - true_grad|=0.000, |grad_D_autograd - true_grad|=0.000
W = W_true, D set of randomly selected samples
 - grad_W diff 8.771534965035244e-08
 - grad_D diff 4.7366633225465193e-05, |grad_D_analytic - true_grad|=0.000, |grad_D_autograd - true_grad|=0.000
Random W,  random D
 - grad_W diff 1.2236972679602331e-07
 - grad_D diff 5.043438795837574e-05, |grad_D_analytic - true_grad|=0.000, |grad_D_autograd - true_grad|=0.000
------------
