Skip to content

Instantly share code, notes, and snippets.

@clemisch
Created January 8, 2020 11:21
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save clemisch/f1462d08591eceeb377b31de4eaa2b9e to your computer and use it in GitHub Desktop.
import functools

import jax
import jax.numpy as np
import numpy as onp
def slice_in_dim(operand, start_index, limit_index, stride=1, axis=0):
    """Slice ``operand`` along one axis, keeping every other axis whole.

    Thin wrapper around ``jax.lax.slice`` that accepts Python-style bounds for
    the chosen axis:

    * ``None`` for ``start_index`` means 0, and ``None`` for ``limit_index``
      means the full length of the axis.
    * A negative ``start_index`` / ``limit_index`` has the axis length added
      to it once. Bounds that are still out of range afterwards are passed to
      ``jax.lax.slice`` unchanged.

    Args:
        operand: array to slice.
        start_index: first index kept along ``axis``, or ``None``.
        limit_index: one past the last index kept along ``axis``, or ``None``.
        stride: step along ``axis``; every other axis uses stride 1.
        axis: the axis to slice (negative values index from the end).

    Returns:
        The result of ``jax.lax.slice``; the other axes keep their full extent.
    """
    extent = operand.shape[axis]

    def _resolve(index, default):
        # None -> default, then wrap a negative bound around the axis once.
        if index is None:
            index = default
        return index + extent if index < 0 else index

    lo = _resolve(start_index, 0)
    hi = _resolve(limit_index, extent)

    axis = int(axis)
    starts = [0] * operand.ndim
    limits = list(operand.shape)
    steps = [1] * operand.ndim
    starts[axis] = int(lo)
    limits[axis] = int(hi)
    steps[axis] = int(stride)
    return jax.lax.slice(operand, starts, limits, steps)
@functools.partial(jax.jit, static_argnums=1)
def gradient_along_axis_swapaxes(a, axis):
    """Return the gradient of ``a`` along ``axis`` with unit spacing.

    Interior points use central differences ``(f[i+1] - f[i-1]) / 2``. The two
    edges use one-sided first-order differences, as ``numpy.gradient`` does
    with its default ``edge_order=1``.

    This variant moves ``axis`` to the front with ``np.swapaxes`` so the
    stencil can be written with leading-axis indexing, then swaps it back.

    Args:
        a: input array.
        axis: axis to differentiate along. It is a static argument to
            ``jax.jit``, so each distinct value compiles separately.

    Returns:
        Array with the same shape as ``a``.

    Raises:
        ValueError: if ``a.shape[axis] < 2``. Without at least two points the
            edge stencils would index past the end of the axis.
    """
    # Shapes are static under jit, so this check runs once at trace time.
    if a.shape[axis] < 2:
        raise ValueError(
            "Shape of array too small to calculate a numerical gradient: "
            "at least 2 elements are required along the axis."
        )
    a_swap = np.swapaxes(a, 0, axis)
    a_grad = np.concatenate((
        (a_swap[1] - a_swap[0])[np.newaxis],    # forward difference at the first edge
        (a_swap[2:] - a_swap[:-2]) * 0.5,       # central differences in the interior
        (a_swap[-1] - a_swap[-2])[np.newaxis],  # backward difference at the last edge
    ), axis=0)
    return np.swapaxes(a_grad, 0, axis)
@functools.partial(jax.jit, static_argnums=1)
def gradient_along_axis_sliced(a, axis):
    """Return the gradient of ``a`` along ``axis`` with unit spacing.

    This computes the same stencil as ``gradient_along_axis_swapaxes``:

    * central differences ``(f[i+1] - f[i-1]) / 2`` in the interior;
    * one-sided first-order differences at the two edges.

    It is expressed with ``slice_in_dim`` (``jax.lax.slice``) directly on
    ``axis`` instead of transposing. Every slice keeps all axes, including
    length-1 edge slices, so the three pieces concatenate directly along
    ``axis``.

    Args:
        a: input array.
        axis: axis to differentiate along. It is a static argument to
            ``jax.jit``, so each distinct value compiles separately.

    Returns:
        Array with the same shape as ``a``.

    Raises:
        ValueError: if ``a.shape[axis] < 2``. Without at least two points the
            edge slices would run past the end of the axis.
    """
    # Shapes are static under jit, so this check runs once at trace time.
    if a.shape[axis] < 2:
        raise ValueError(
            "Shape of array too small to calculate a numerical gradient: "
            "at least 2 elements are required along the axis."
        )
    sliced = functools.partial(slice_in_dim, a, axis=axis)
    a_grad = np.concatenate((
        sliced(1, 2) - sliced(0, 1),                # forward difference at the first edge
        (sliced(2, None) - sliced(0, -2)) * 0.5,    # central differences in the interior
        sliced(-1, None) - sliced(-2, -1),          # backward difference at the last edge
    ), axis)
    return a_grad
@jax.jit
def gradient_swap(a):
    """Gradient of ``a`` along each of its axes (swapaxes implementation).

    Returns a list with one array per dimension of ``a``. Entry ``i`` is
    ``gradient_along_axis_swapaxes(a, i)``.
    """
    grads = []
    for axis in range(a.ndim):
        grads.append(gradient_along_axis_swapaxes(a, axis))
    return grads
@jax.jit
def gradient_sliced(a):
    """Gradient of ``a`` along each of its axes (lax.slice implementation).

    Returns a list with one array per dimension of ``a``. Entry ``i`` is
    ``gradient_along_axis_sliced(a, i)``.
    """
    grads = []
    for axis in range(a.ndim):
        grads.append(gradient_along_axis_sliced(a, axis))
    return grads
if __name__ == "__main__":
    import statistics
    import timeit

    key = jax.random.PRNGKey(0)
    x = jax.random.normal(key, (100, 100, 100))

    # Check both implementations against NumPy. These first calls also pay the
    # one-time JIT compilation cost, so the timings below are steady-state.
    onp.testing.assert_allclose(onp.gradient(x), gradient_swap(x))
    onp.testing.assert_allclose(onp.gradient(x), gradient_sliced(x))

    # The original benchmark used IPython's `%timeit`, which is a SyntaxError
    # in a plain Python file. `timeit.repeat` with the same 7 runs x 100 loops
    # gives an equivalent measurement. `jax.device_get` copies the results to
    # the host, so the timer also waits for JAX's asynchronous dispatch.
    n_loops, n_runs = 100, 7
    for name, fn in (("gradient_swap", gradient_swap),
                     ("gradient_sliced", gradient_sliced)):
        run_times = timeit.repeat(lambda fn=fn: jax.device_get(fn(x)),
                                  number=n_loops, repeat=n_runs)
        per_loop_ms = [t / n_loops * 1e3 for t in run_times]
        print(f"{name}: {statistics.mean(per_loop_ms):.3g} ms "
              f"± {statistics.pstdev(per_loop_ms):.3g} ms per loop "
              f"(mean ± std. dev. of {n_runs} runs, {n_loops} loops each)")

    # Reference results from the original `%timeit` runs. Within each pair the
    # order is presumably that of the original calls: gradient_swap first,
    # then gradient_sliced.
    # CPU (i7 8550U)
    # 14.3 ms ± 2.2 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
    # 8.89 ms ± 274 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    # GPU (GTX 1080 Ti)
    # 6.35 ms ± 1.24 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
    # 4.69 ms ± 54 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment