Skip to content

Instantly share code, notes, and snippets.

@vene
Created June 22, 2020 16:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save vene/fb7726af1ed658d90de396f66cb513e4 to your computer and use it in GitHub Desktop.
Save vene/fb7726af1ed658d90de396f66cb513e4 to your computer and use it in GitHub Desktop.
check Monte-Carlo perturbed gradients (Berthet et al.)
# https://arxiv.org/abs/2002.08676
# code by vlad niculae
# license: mit
import numpy as np
import matplotlib.pyplot as plt
def main():
    """Compare Monte-Carlo estimators of the perturbed-argmax
    Jacobian-vector product (Berthet et al., arXiv:2002.08676) against a
    central finite-difference reference, plotting the absolute error of
    each estimator as a function of the number of noise samples.

    Side effects: prints each reference JVP and writes "fig.pdf".
    """
    dim = 4
    n_thetas = 2  # NOTE(review): unused below — kept from the original setup
    n_directions = 3
    n_samples = 10_000
    n_samples_almost_exact = 1_000_000

    def _draw_noise():
        # Fixed seed: every estimator sees the *same* perturbations, so
        # the error curves differ only by the estimator formula.
        return np.random.RandomState(55).randn(n_samples, dim)

    def _running_mean(samples):
        # Cumulative Monte-Carlo average: row k is the mean of the first
        # k + 1 per-sample estimates. Returns (sample counts, running means).
        denom = np.arange(1, len(samples) + 1)
        return denom, np.cumsum(samples, axis=0) / denom[:, np.newaxis]

    def y_star_eps(theta):
        """Near-exact perturbed argmax expectation: the empirical
        distribution of argmax(theta + Z) over many Gaussian draws."""
        Z = np.random.RandomState(42).randn(n_samples_almost_exact, dim)
        ix = np.argmax(theta + Z, axis=1)
        counts = np.zeros_like(theta)
        uniq, counts_sparse = np.unique(ix, return_counts=True)
        counts[uniq] = counts_sparse
        return counts / n_samples_almost_exact

    def jvp_f1(theta, d):
        """JVP estimator (1): running mean of (Y Z^T) d, with Y the
        one-hot argmax of theta + Z."""
        Z = _draw_noise()
        ix = np.argmax(theta + Z, axis=1)
        Y = np.zeros((n_samples, dim))
        Y[np.arange(n_samples), ix] = 1
        # YZt * d: one rank-1 term per sample
        YZtd = (Z * d).sum(axis=1)[:, np.newaxis] * Y
        return _running_mean(YZtd)

    def jvp_f1_symm(theta, d):
        """Symmetrized variant of (1): .5 (Y Z^T + Z Y^T) d."""
        Z = _draw_noise()
        ix = np.argmax(theta + Z, axis=1)
        Y = np.zeros((n_samples, dim))
        Y[np.arange(n_samples), ix] = 1
        YZtd = (Z * d).sum(axis=1)[:, np.newaxis] * Y
        ZYtd = (Y * d).sum(axis=1)[:, np.newaxis] * Z
        return _running_mean(.5 * (YZtd + ZYtd))

    def jvp_f2(theta, d):
        """JVP estimator (2): running mean of F (Z Z^T - I) d, with
        F = max(theta + Z) the perturbed max value."""
        Z = _draw_noise()
        F = np.max(theta + Z, axis=1)
        # (ZZt - I)d = (z.t d) z - d
        Ztd = (Z * d).sum(axis=1)[:, np.newaxis]
        return _running_mean(F[:, np.newaxis] * (Ztd * Z - d))

    rng = np.random.RandomState(1312)
    theta = .1 * rng.randn(dim)

    # pick spherically-uniform gradient directions
    dirs = rng.randn(n_directions, dim)
    dirs /= np.linalg.norm(dirs, axis=1)[:, np.newaxis]

    fd_eps = 1e-4
    fig, axes_dir = plt.subplots(n_directions, dim, sharey=True,
                                 figsize=(6, 4),
                                 constrained_layout=True)

    for axes_d, d in zip(axes_dir, dirs):
        # central finite-difference reference for the true JVP along d
        jvp = ((y_star_eps(theta + fd_eps * d)
                - y_star_eps(theta - fd_eps * d))
               / (2 * fd_eps))
        print(jvp)

        ix, approx_jvp_1 = jvp_f1(theta, d)
        err_1 = np.abs(approx_jvp_1 - jvp)
        _, approx_jvp_1_symm = jvp_f1_symm(theta, d)
        err_1_symm = np.abs(approx_jvp_1_symm - jvp)
        _, approx_jvp_2 = jvp_f2(theta, d)
        err_2 = np.abs(approx_jvp_2 - jvp)

        # one column per coordinate of the JVP, log-scale error axis
        for j in range(dim):
            axes_d[j].plot(ix, err_1[:, j], label="(1)")
            axes_d[j].plot(ix, err_1_symm[:, j], label="(1) symm")
            axes_d[j].plot(ix, err_2[:, j], label="(2)")
            axes_d[j].semilogy()

    for i in range(n_directions):
        axes_dir[i][0].set_ylim((1e-3, 1))
        axes_dir[i][0].set_ylabel(
            "$|(\\tilde{{J}}v_{i})_d - (Jv_{i})_d|$".format(i=i))
    for j in range(dim):
        axes_dir[0][j].set_title("$d={}$".format(j))
        axes_dir[-1][j].set_xlabel("n samples")
    axes_dir[0][-1].legend(bbox_to_anchor=(1.05, 1.05))
    plt.savefig("fig.pdf")
# Script entry point: run the experiment only when executed directly,
# not when imported as a module.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment