Inverse regular sampling from a black-box function using JAX+SGD with momentum
import numpy as np
import jax.numpy as jnp
from jax import jit, value_and_grad
from jax.config import config
config.update("jax_debug_nans", True)
class RegularInv1dSampler:
    """
    This regular inverse sampler deals with the following problem: given a
    smooth function ``y = f(x)`` for ``x`` scalar and ``y`` n-dimensional,
    retrieve ``N`` monotonically increasing values of ``x``, such that the
    respective ``f(x)`` are evenly spaced.
    This is achieved numerically (via gradient descent), by minimizing the
    variance of the squared euclidean distances between successive samples.
    """
    @staticmethod
    def numeric_fwprop(fn, vals, epsilon=1e-2):
        """
        :returns: The pair ``(y_vals, y_grads)``. The former is simply
          ``fn(vals)``. The latter has the same shape as ``y_vals`` and is an
          approximation of the rate of change per output at each input value
          (i.e. how much each ``y_val`` changes per unit change of ``val``).
          This is approximated numerically via a forward finite difference,
          using ``epsilon`` as the step size.

        We implement this because our scipy interpolator function isn't part
        of JAX and can't be autodifferentiated. Symbolic differentiation
        would also be possible.
        """
        y_vals = fn(vals)
        y_grads = fn(vals + epsilon) - y_vals
        y_grads /= epsilon
        return y_vals, y_grads
    @staticmethod
    def successive_dist_loss(arr):
        """
        :param arr: Array of shape ``(num_elts, num_dims)``.
        :returns: A scalar loss: the variance of the squared euclidean
          distances between consecutive entries of ``arr``. It is zero
          exactly when all consecutive entries are equally spaced.
        """
        diff = jnp.diff(arr, axis=0)  # shape (num_elts - 1, num_dims)
        diff_l2 = (diff * diff).sum(axis=1)  # squared distances
        loss = diff_l2.var()
        return loss
    @classmethod
    def __call__(cls, fn, domain_range, num_samples=1000,
                 lrate=1, momentum=0.999, loss_thresh=1e-3):
        """
        :param fn: The function ``y = f(x)``. Assumed to be smooth and
          differentiable.
        :param domain_range: A pair ``(beg, end)`` for the ``x`` range to
          be sampled from.
        :returns: A pair ``(xxx, yyy)`` of arrays with ``num_samples``
          entries each, where ``xxx`` is monotonically increasing and starts
          and ends with the given ``domain_range``, and the ``yyy`` elements
          are evenly spaced in terms of their successive euclidean distances.
        """
        x = np.linspace(*domain_range, num_samples)
        y, x_grad = cls.numeric_fwprop(fn, x, epsilon=0.01)
        grad_fn = jit(value_and_grad(cls.successive_dist_loss, argnums=0))
        loss, y_grad = grad_fn(y)
        update = np.zeros_like(x_grad[:, 0])
        try:
            while loss > loss_thresh:
                print("loss:", loss)
                # chain rule: dloss/dx = sum over dims of (dy/dx * dloss/dy)
                backprop = (x_grad * y_grad).sum(axis=1)
                # SGD with momentum: only the interior points are updated,
                # so the endpoints of the domain range stay fixed.
                update = backprop + momentum * update
                x[1:-1] -= lrate * update[1:-1]
                x.sort()
                x = x.clip(min=domain_range[0], max=domain_range[1])
                y, x_grad = cls.numeric_fwprop(fn, x)
                loss, y_grad = grad_fn(y)
                # debug backprop
                # from jax import make_jaxpr
                # make_jaxpr(grad_fn)(y)
                #
            return x, y
        except FloatingPointError as fpe:
            print(fpe)
            raise FloatingPointError("Try with a smaller learning rate!") from fpe
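Usage sketch (not part of the gist above): the toy curve, the scipy interpolator and the hyperparameter values are made up for illustration, assuming ``fn`` maps a 1-d array of ``x`` values to an array of shape ``(len(x), num_dims)``. Since ``__call__`` is a classmethod, the sampler is invoked through a (stateless) instance.

import numpy as np
from scipy.interpolate import interp1d

# Toy black-box curve: 2-d points whose spacing is uneven when x is sampled
# uniformly. A cubic interpolator over a few knots stands in for the
# "interpolator scipy function" mentioned in the docstrings.
knots_x = np.linspace(0.0, 1.0, 20)
knots_y = np.stack([knots_x ** 3, np.sin(3.0 * knots_x)], axis=1)
fn = interp1d(knots_x, knots_y, axis=0, kind="cubic",
              fill_value="extrapolate")  # extrapolate so fn(x + epsilon) is defined

sampler = RegularInv1dSampler()
# Hyperparameters are placeholders and will likely need tuning per function;
# try a smaller lrate if a FloatingPointError (NaN) is raised.
xxx, yyy = sampler(fn, (0.0, 1.0), num_samples=100, loss_thresh=1e-8)
# xxx is monotonically increasing in [0, 1]; consecutive yyy points should be
# (approximately) equally spaced along the curve.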