@wdevazelhes
Last active August 16, 2018 10:06
Code for comparing two implementations of the gradient for MLKR
from metric_learn import MLKR
from sklearn.utils import check_random_state
import numpy as np
from losses import _loss_non_optimized, _loss_optimized
from collections import defaultdict
from sklearn.datasets import make_regression
for n_features in [5, 100]:
    print('n_features={}'.format(n_features))
    X, y = make_regression(n_features=n_features)
    for seed in range(5):
        rng = check_random_state(seed)
        A = rng.randn(X.shape[0], X.shape[0])
        print('gradient differences:')
        print(np.linalg.norm(_loss_optimized(A, X, y)[1]
                             - _loss_non_optimized(A, X, y)[1]))
        print('loss differences:')
        print(_loss_optimized(A, X, y)[0]
              - _loss_non_optimized(A, X, y)[0])

# Printing the whole values (with fewer features for easier visualisation)
X, y = make_regression(n_features=5)
for seed in range(5):
    rng = check_random_state(seed)
    A = rng.randn(2, 5)
    for loss in [_loss_non_optimized, _loss_optimized]:
        print(loss(A, X, y))
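As a further sanity check, not part of the comparison script above, each analytic gradient could also be compared against a numerical finite-difference approximation with scipy.optimize.check_grad. The snippet below is a minimal sketch of that idea, reusing the two loss functions from losses.py; the variable names and the choice of a 2 x 5 transformation are illustrative only.

# Minimal sketch (assumption, not from the original gist): compare each
# analytic gradient with a finite-difference gradient via check_grad.
from scipy.optimize import check_grad
from sklearn.datasets import make_regression
from sklearn.utils import check_random_state

from losses import _loss_non_optimized, _loss_optimized

X, y = make_regression(n_features=5)
rng = check_random_state(0)
flatA = rng.randn(2 * 5)  # flattened 2 x 5 transformation matrix

for loss in [_loss_non_optimized, _loss_optimized]:
    # check_grad returns the 2-norm of the difference between the analytic
    # gradient and the finite-difference approximation of the cost
    err = check_grad(lambda a: loss(a, X, y)[0],
                     lambda a: loss(a, X, y)[1],
                     flatA)
    print(loss.__name__, err)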
from sklearn.datasets import make_regression
from metric_learn import MLKR
import numpy as np
from scipy.spatial.distance import pdist, squareform
from scipy.special import logsumexp


def _loss_optimized(flatA, X, y):
    # Vectorized MLKR loss/gradient: the gradient is expressed through the
    # embedded points A.dot(X.T) instead of explicit pairwise differences.
    A = flatA.reshape((-1, X.shape[1]))
    dist = pdist(X, metric='mahalanobis', VI=A.T.dot(A))
    dist = squareform(dist ** 2)
    np.fill_diagonal(dist, np.inf)
    softmax = np.exp(- dist - logsumexp(- dist, axis=1)[:, np.newaxis])
    yhat = softmax.dot(y)
    ydiff = yhat - y
    cost = (ydiff ** 2).sum()
    # also compute the gradient
    W = softmax * ydiff[:, np.newaxis] * (yhat[:, np.newaxis] - y)
    X_emb_t = A.dot(X.T)
    grad = (4 * (X_emb_t * (W.sum(axis=0))
                 - X_emb_t.dot(W + W.T)).dot(X))
    return cost, grad.ravel()


def _loss_non_optimized(flatA, X, y):
    # Reference MLKR loss/gradient: the gradient is built from the full
    # matrix of pairwise differences dX, which is simpler but more costly.
    dX = (X[None] - X[:, None]).reshape((-1, X.shape[1]))
    A = flatA.reshape((-1, X.shape[1]))
    dist = pdist(X, metric='mahalanobis', VI=A.T.dot(A))
    dist = squareform(dist ** 2)
    np.fill_diagonal(dist, np.inf)
    softmax = np.exp(- dist - logsumexp(- dist, axis=1)[:, np.newaxis])
    yhat = softmax.dot(y)
    ydiff = yhat - y
    cost = (ydiff ** 2).sum()
    # also compute the gradient
    W = 2 * softmax * ydiff[:, np.newaxis] * (yhat[:, np.newaxis] - y)
    # note: this is the part that the matlab impl drops to C for
    M = (dX.T * W.ravel()).dot(dX)
    grad = 2 * A.dot(M)
    return cost, grad.ravel()
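For context, MLKR is imported in both files but never called; a minimal usage of metric_learn's MLKR estimator on the same kind of data might look like the sketch below. This is an assumed illustration, not part of the gist: the default constructor is used deliberately to avoid relying on version-specific parameters.

# Minimal sketch (assumption): fit metric_learn's MLKR on a small regression
# problem and embed the data with the learned linear transformation.
from metric_learn import MLKR
from sklearn.datasets import make_regression

X, y = make_regression(n_samples=50, n_features=5, random_state=0)
mlkr = MLKR()                    # default settings
mlkr.fit(X, y)                   # optimizes the MLKR loss compared above
X_embedded = mlkr.transform(X)   # data mapped into the learned metric space
print(X_embedded.shape)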
n_features=5
gradient differences:
3.5345143236136213e-07
loss differences:
0.0
gradient differences:
1.4108694814314438e-07
loss differences:
0.0
gradient differences:
4.480441650559029e-07
loss differences:
0.0
gradient differences:
3.808513898809513e-07
loss differences:
0.0
gradient differences:
5.433925727007268e-07
loss differences:
0.0
n_features=100
gradient differences:
3.959822089906634e-05
loss differences:
0.0
gradient differences:
5.3507102171022624e-09
loss differences:
0.0
gradient differences:
0.00012580672385627842
loss differences:
0.0
gradient differences:
4.469019658445259e-05
loss differences:
0.0
gradient differences:
5.81210044182099e-06
loss differences:
0.0
(236393.51759114384, array([-219161.66316797, -170897.25052117, 59911.3633348 ,
186149.63871032, 5270.58498449, 53025.7184015 ,
64104.16815048, -80517.43248341, -40525.74730135,
40761.10429498]))
(236393.51759114384, array([-219161.66316797, -170897.25052117, 59911.3633348 ,
186149.63871032, 5270.58498449, 53025.7184015 ,
64104.16815048, -80517.43248341, -40525.74730135,
40761.10429498]))
(1501224.4227846859, array([-209351.22776078, -260401.58473759, 6864.50714577,
-227578.33921967, 172301.175685 , -140104.31132398,
80865.16959862, 103477.93303428, 123817.35534745,
428969.55979609]))
(1501224.4227846859, array([-209351.22776078, -260401.58473759, 6864.50714577,
-227578.33921967, 172301.175685 , -140104.31132398,
80865.16959862, 103477.93303428, 123817.35534746,
428969.55979609]))
(795249.879923564, array([ 46840.66296907, -75064.23955227, -31242.95390164,
7520.43747259, -27340.04496518, 671234.20586862,
456469.43030066, -241225.58944759, -17310.93917806,
113111.51540722]))
(795249.879923564, array([ 46840.66296907, -75064.23955227, -31242.95390164,
7520.43747259, -27340.04496518, 671234.20586862,
456469.43030066, -241225.58944759, -17310.93917806,
113111.51540722]))
(924194.0322471606, array([ 205955.58356486, -34547.5281954 , -65518.95190382,
162702.50449867, 91425.93225032, 801867.80294049,
158816.43702189, -397855.89401635, 780588.87373477,
326040.71320266]))
(924194.0322471606, array([ 205955.58356486, -34547.5281954 , -65518.95190383,
162702.50449868, 91425.93225032, 801867.80294049,
158816.43702189, -397855.89401635, 780588.87373476,
326040.71320266]))
(814925.4406213816, array([ -5983.30201068, -35876.07751133, -37340.34433632, 96135.50074118,
9191.34576283, 142508.76621226, -62013.01269183, 377700.75180631,
691628.55589295, 114744.10319056]))
(814925.4406213816, array([ -5983.30201068, -35876.07751133, -37340.34433632, 96135.50074118,
9191.34576283, 142508.76621226, -62013.01269183, 377700.75180631,
691628.55589295, 114744.10319056]))