Skip to content

Instantly share code, notes, and snippets.

@andres-fr
Created February 23, 2022 00:28
Show Gist options
  • Save andres-fr/7de831daa849996ae91179dd6cac18af to your computer and use it in GitHub Desktop.
Save andres-fr/7de831daa849996ae91179dd6cac18af to your computer and use it in GitHub Desktop.
Inverse regular sampling from a black-box function using JAX+SGD with momentum
import numpy as np
import jax.numpy as jnp
from jax import jit, value_and_grad
from jax.config import config
# Enable JAX's NaN checking (``jax_debug_nans``), which raises
# FloatingPointError when a JAX computation produces NaNs instead of letting
# them propagate silently. ``RegularInv1dSampler.__call__`` catches that error
# to suggest a smaller learning rate.
# NOTE(review): ``from jax.config import config`` was removed in recent JAX
# releases; ``jax.config.update(...)`` is the modern spelling -- confirm the
# targeted JAX version.
config.update("jax_debug_nans", True)
class RegularInv1dSampler:
    """
    This regular inverse sampler deals with the following problem: given a
    smooth function ``y=f(x)`` for ``x`` scalar and ``y`` n-dimensional,
    retrieve ``N`` monotonically increasing values of ``x``, such that the
    respective ``f(x)`` are evenly spaced.

    This is achieved numerically (via gradient descent with momentum), by
    minimizing the variance of the successive squared euclidean distances.
    Since distances are non-negative, equal squared distances imply equal
    distances; squaring avoids differentiating ``sqrt`` at zero distance.

    Usage: ``xxx, yyy = RegularInv1dSampler()(fn, (beg, end))``.
    """

    @staticmethod
    def numeric_fwprop(fn, vals, epsilon=1e-2):
        """
        Evaluate ``fn`` and a forward finite-difference estimate of its
        derivative at every input value.

        We implement this because the sampled function (e.g. a scipy
        interpolator) may not be part of JAX and thus can't be
        autodifferentiated. Symbolic differentiation should also be possible.

        :param fn: Callable evaluated on the whole ``vals`` array at once.
        :param vals: Array of input values.
        :param epsilon: Step used for the finite difference.
        :returns: The pair ``(y_vals, y_grads)``. The former is simply
          ``fn(vals)``. The latter has the same shape as ``y_vals`` and
          approximates how much each ``y_val`` changes per unit of ``val``,
          computed as ``(fn(vals + epsilon) - fn(vals)) / epsilon``.
        """
        y_vals = fn(vals)
        # Out-of-place division: integer-valued outputs get promoted to float
        # instead of failing an in-place true division.
        y_grads = (fn(vals + epsilon) - y_vals) / epsilon
        return y_vals, y_grads

    @staticmethod
    def successive_dist_loss(arr):
        """
        :param arr: Array of shape ``(num_elts, num_dims)``. A 1D array is
          treated as ``(num_elts, 1)``; extra trailing axes are flattened.
        :returns: Scalar: the variance of the ``num_elts - 1`` squared
          euclidean distances between successive entries ``i`` and ``i+1``
          of ``arr``. It is zero iff all successive distances are equal.
        """
        flat = jnp.reshape(arr, (arr.shape[0], -1))
        diff = jnp.diff(flat, axis=0)
        sq_dists = (diff * diff).sum(axis=1)
        return sq_dists.var()

    @classmethod
    def __call__(cls, fn, domain_range, num_samples=1000,
                 lrate=1, momentum=0.999, loss_thresh=1e-3,
                 epsilon=1e-2, max_iters=None):
        """
        :param fn: The function ``y = f(x)``. Assumed to be smooth and
          differentiable, and evaluated on the whole 1D array of ``x``
          values at once, returning one ``y`` per ``x`` along the first axis.
          It is also evaluated at ``x + epsilon`` (see ``numeric_fwprop``).
        :param domain_range: A pair ``(beg, end)`` for the ``x`` range to
          be sampled from.
        :param num_samples: Number of ``x`` values to retrieve.
        :param lrate: Gradient descent step size.
        :param momentum: Momentum coefficient applied to the previous update.
        :param loss_thresh: Optimization stops once the loss (see
          ``successive_dist_loss``) is not above this value.
        :param epsilon: Finite-difference step passed to ``numeric_fwprop``.
        :param max_iters: Optional cap on the number of update steps. If
          ``None`` (default), iterate until ``loss_thresh`` is reached.
        :returns: A pair ``(xxx, yyy)``, both arrays with ``num_samples``,
          where ``xxx`` is monotonically increasing and starts and ends with
          the given ``domain_range``, and ``yyy = fn(xxx)`` elements are
          evenly spaced in terms of their successive euclidean distances.
        :raises FloatingPointError: If one is raised during optimization
          (e.g. NaNs caught by JAX's ``jax_debug_nans`` mode), which usually
          means the learning rate is too large.
        """
        beg, end = domain_range
        x = np.linspace(beg, end, num_samples)
        y, x_grad = cls.numeric_fwprop(fn, x, epsilon=epsilon)
        grad_fn = jit(value_and_grad(cls.successive_dist_loss, argnums=0))
        loss, y_grad = grad_fn(y)
        # Momentum buffer, one entry per sample.
        update = np.zeros(len(x))
        num_iters = 0
        try:
            while loss > loss_thresh and (max_iters is None
                                          or num_iters < max_iters):
                print("loss:", loss)
                # Chain rule: dL/dx_i = sum_d dL/dy_id * dy_id/dx_i.
                prod = np.asarray(x_grad) * np.asarray(y_grad)
                backprop = prod.reshape(len(x), -1).sum(axis=1)
                update = backprop + momentum * update
                # Endpoints are never updated, so they stay at domain_range.
                x[1:-1] -= lrate * update[1:-1]
                x.sort()
                # ndarray.clip returns a new array; write back into ``x``.
                np.clip(x, beg, end, out=x)
                y, x_grad = cls.numeric_fwprop(fn, x, epsilon=epsilon)
                loss, y_grad = grad_fn(y)
                num_iters += 1
            return x, y
        except FloatingPointError as fpe:
            print(fpe)
            raise FloatingPointError(
                "Try with a smaller learning rate!") from fpe
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment