@daskol
Created February 20, 2022 21:21
Naive implementation of cross-entropy beats the library one (flax/optax)
"""This script performs benchmarking default implementation of cross-entropy
(in flax/optax) and naive one in plain JAX. One can run the script with the
code below.
$ mv bench-entropy.{py,ipy}
$ ipython bench-entropy.ipy
naive: 63.6 µs ± 4.27 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
optax: 67.3 µs ± 3.98 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
Naive implementation is faster a bit on Nvidia V100 as well as user-end CPU.
"""
import jax
import jax.numpy as jnp
from flax.training.common_utils import onehot
from optax import softmax_cross_entropy

@jax.jit
def loss_entropy(y_true: jnp.ndarray, y_pred: jnp.ndarray):
    """Mean cross-entropy over a batch from integer labels (unused below)."""
    assert y_pred.ndim == 2
    ps = jax.nn.log_softmax(y_pred)
    ts = jnp.take_along_axis(ps, y_true[:, None], axis=-1)
    return -ts.mean()


@jax.jit
def entropy(y_true: jnp.ndarray, y_pred: jnp.ndarray):
    """Per-example cross-entropy from integer labels (naive implementation)."""
    assert y_pred.ndim == 2
    ps = jax.nn.log_softmax(y_pred)
    ts = jnp.take_along_axis(ps, y_true[:, None], axis=-1)
    return -ts.squeeze()


# Ensure the optax implementation is jitted too, for a fair comparison.
softmax_cross_entropy = jax.jit(softmax_cross_entropy)

# Synthetic binary-classification batch: 128 pairs of logits with labels.
key = jax.random.PRNGKey(42)
y_pred = jax.random.normal(key, (128, 2))
y_true = jax.random.randint(key, (128,), 0, 2)
y_true_onehot = onehot(y_true, 2)

print('naive:', end=' ')
%timeit entropy(y_true, y_pred).block_until_ready()
print('optax:', end=' ')
%timeit softmax_cross_entropy(y_pred, y_true_onehot).block_until_ready()
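
# A minimal sanity check (not part of the original benchmark): assuming the
# two routines are meant to be interchangeable, the per-example values from
# the naive `entropy` and from optax's `softmax_cross_entropy` should agree
# up to floating-point tolerance.
naive_values = entropy(y_true, y_pred)
optax_values = softmax_cross_entropy(y_pred, y_true_onehot)
assert jnp.allclose(naive_values, optax_values, atol=1e-6), \
    'naive and optax cross-entropy values diverge'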