Created
June 22, 2020 16:47
-
-
Save vene/fb7726af1ed658d90de396f66cb513e4 to your computer and use it in GitHub Desktop.
check MC perturbed gradients (Berthet et al.)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# https://arxiv.org/abs/2002.08676
# code by vlad niculae
# license: mit
import numpy as np | |
import matplotlib.pyplot as plt | |
def main(): | |
dim = 4 | |
n_thetas = 2 | |
n_directions = 3 | |
n_samples = 10_000 | |
n_samples_almost_exact = 1_000_000 | |
def y_star_eps(theta): | |
Z = np.random.RandomState(42).randn(n_samples_almost_exact, dim) | |
ix = np.argmax(theta + Z, axis=1) | |
counts = np.zeros_like(theta) | |
uniq, counts_sparse = np.unique(ix, return_counts=True) | |
counts[uniq] = counts_sparse | |
return counts / n_samples_almost_exact | |
def jvp_f1(theta, d): | |
Z = np.random.RandomState(55).randn(n_samples, dim) | |
ix = np.argmax(theta + Z, axis=1) | |
Y = np.zeros((n_samples, dim)) | |
Y[np.arange(n_samples), ix] = 1 | |
# YZt * d | |
YZtd = (Z * d).sum(axis=1)[:, np.newaxis] * Y | |
denom = np.arange(1, n_samples + 1) | |
approx = np.cumsum(YZtd, axis=0) / denom[:, np.newaxis] | |
return denom, approx | |
def jvp_f1_symm(theta, d): | |
Z = np.random.RandomState(55).randn(n_samples, dim) | |
ix = np.argmax(theta + Z, axis=1) | |
Y = np.zeros((n_samples, dim)) | |
Y[np.arange(n_samples), ix] = 1 | |
YZtd = (Z * d).sum(axis=1)[:, np.newaxis] * Y | |
ZYtd = (Y * d).sum(axis=1)[:, np.newaxis] * Z | |
sym = .5 * (YZtd + ZYtd) | |
denom = np.arange(1, n_samples + 1) | |
approx = np.cumsum(sym, axis=0) / denom[:, np.newaxis] | |
return denom, approx | |
def jvp_f2(theta, d): | |
Z = np.random.RandomState(55).randn(n_samples, dim) | |
F = np.max(theta + Z, axis=1) | |
# (ZZt - I)d = (z.t d) z - d | |
Ztd = (Z * d).sum(axis=1)[:, np.newaxis] | |
approx = F[:, np.newaxis] * (Ztd * Z - d) | |
denom = np.arange(1, n_samples + 1) | |
approx = np.cumsum(approx, axis=0) / denom[:, np.newaxis] | |
return denom, approx | |
rng = np.random.RandomState(1312) | |
theta = .1 * rng.randn(dim) | |
# pick spherically-uniform gradient directions | |
dirs = rng.randn(n_directions, dim) | |
dirs /= np.linalg.norm(dirs, axis=1)[:, np.newaxis] | |
fd_eps = 1e-4 | |
fig, axes_dir = plt.subplots(n_directions, dim, sharey=True, | |
figsize=(6, 4), | |
constrained_layout=True) | |
for axes_d, d in zip(axes_dir, dirs): | |
jvp = ((y_star_eps(theta + fd_eps * d) | |
- y_star_eps(theta - fd_eps * d)) | |
/ (2 * fd_eps)) | |
print(jvp) | |
# n_samples, jvp_f1 = approx_jvp_formula_1(theta, d) | |
ix, approx_jvp_1 = jvp_f1(theta, d) | |
err_1 = np.abs(approx_jvp_1 - jvp) | |
_, approx_jvp_1_symm = jvp_f1_symm(theta, d) | |
err_1_symm = np.abs(approx_jvp_1_symm - jvp) | |
_, approx_jvp_2 = jvp_f2(theta, d) | |
err_2 = np.abs(approx_jvp_2 - jvp) | |
for j in range(dim): | |
axes_d[j].plot(ix, err_1[:, j], label="(1)") | |
axes_d[j].plot(ix, err_1_symm[:, j], label="(1) symm") | |
axes_d[j].plot(ix, err_2[:, j], label="(2)") | |
axes_d[j].semilogy() | |
for i in range(n_directions): | |
axes_dir[i][0].set_ylim((1e-3, 1)) | |
axes_dir[i][0].set_ylabel( | |
"$|(\\tilde{{J}}v_{i})_d - (Jv_{i})_d|$".format(i=i)) | |
for j in range(dim): | |
axes_dir[0][j].set_title("$d={}$".format(j)) | |
axes_dir[-1][j].set_xlabel("n samples") | |
axes_dir[0][-1].legend(bbox_to_anchor=(1.05, 1.05)) | |
plt.savefig("fig.pdf") | |
# Run the experiment only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment