In pyhf we've noticed that one of our unit tests, which was passing with jax `v0.2.7` and jaxlib `v0.1.57`, has started failing with the release of jaxlib `v0.1.58`. We've narrowed it down to jaxlib `v0.1.58` with `jax_enable_x64=True` in CPU mode (i.e., where our unit tests run).
In a fresh Python 3.8 virtual environment:
```console
$ python --version --version
Python 3.8.6 (default, Jan 5 2021, 00:14:15)
[GCC 9.3.0]
$ python -m pip install --quiet --upgrade pip setuptools wheel
$ python -m pip install jax jaxlib
$ python -m pip list
Package     Version
----------- -------
absl-py     0.11.0
flatbuffers 1.12
jax         0.2.8
jaxlib      0.1.58
numpy       1.19.5
opt-einsum  3.3.0
pip         20.3.3
scipy       1.6.0
setuptools  51.1.2
six         1.15.0
wheel       0.36.2
```
then for the following script
```python
# jaxlib_issue.py
import jax
import jaxlib
from jax.config import config
import jax.numpy as jnp
from jax.scipy.special import gammaln


class Poisson:
    def __init__(self, rate):
        self.rate = jnp.asarray(rate, dtype="float64")

    def log_prob(self, n):
        n = jnp.asarray(n, dtype="float64")
        return n * jnp.log(self.rate) - self.rate - gammaln(n + 1.0)


def main():
    config.update("jax_enable_x64", True)

    print(f"jax version: {jax.__version__}")
    print(f"jaxlib version: {jaxlib.__version__}")

    joint = gammaln(jnp.asarray([2.0, 3.0], dtype="float64")).tolist()
    individual = [
        *gammaln(jnp.asarray([2.0], dtype="float64")).tolist(),
        *gammaln(jnp.asarray([3.0], dtype="float64")).tolist(),
    ]
    print(f"joint: {joint}")
    print(f"individual: {individual}")
    assert joint == individual

    # This is more akin to what we're seeing
    joint = Poisson([10.0, 10.0]).log_prob([2.0, 3.0])
    poisson_1 = Poisson([10.0]).log_prob(2.0)
    poisson_2 = Poisson([10.0]).log_prob(3.0)
    print(f"\njoint: {joint.tolist()}")
    print(f"individual: {[*poisson_1.tolist(), *poisson_2.tolist()]}")
    assert joint.tolist() == [*poisson_1.tolist(), *poisson_2.tolist()]


if __name__ == "__main__":
    main()
```
```console
$ python jaxlib_issue.py
jax version: 0.2.8
jaxlib version: 0.1.58
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
joint: [8.881784197001252e-16, 0.693147180559945]
individual: [8.881784197001252e-16, 0.6931471805599432]
Traceback (most recent call last):
  File "jaxlib_issue.py", line 43, in <module>
    main()
  File "jaxlib_issue.py", line 29, in main
    assert joint == individual
AssertionError
```
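For a sense of how small the discrepancy is, here's a stdlib-only sketch (no jax required; `ulp_distance` is our own helper, not a jax or jaxlib API) that counts how many representable doubles separate the two `gammaln(3.0)` results printed above, and shows that a tolerance-based comparison would still treat them as equal:

```python
import math
import struct


def ulp_distance(a: float, b: float) -> int:
    """Number of representable doubles between a and b (positive floats)."""
    ia = struct.unpack("<q", struct.pack("<d", a))[0]
    ib = struct.unpack("<q", struct.pack("<d", b))[0]
    return abs(ia - ib)


# The two gammaln(3.0) results printed by the failing run above.
joint = 0.693147180559945       # batched evaluation, jaxlib 0.1.58 CPU
individual = 0.6931471805599432  # element-wise evaluation, jaxlib 0.1.58 CPU

print(ulp_distance(joint, individual))   # a few tens of ULPs at most
print(math.isclose(joint, individual))   # True (default rel_tol=1e-9)
```

So the batched and element-wise code paths are producing results that are only last-few-bits different, which exact `==` comparison is maximally sensitive to.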
however, for
```console
$ python -m pip install --quiet --upgrade "jaxlib<0.1.58"
$ python -m pip list | grep jax
jax              0.2.8
jaxlib           0.1.57
```
things pass as before:
```console
$ python jaxlib_issue.py
jax version: 0.2.8
jaxlib version: 0.1.57
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
joint: [8.881784197001252e-16, 0.693147180559945]
individual: [8.881784197001252e-16, 0.693147180559945]

joint: [-6.087976994571854, -4.884004190245918]
individual: [-6.087976994571854, -4.884004190245918]
```
If the CUDA-enabled version of jaxlib is installed, the issue is not seen:
```console
$ python -m pip install --quiet --upgrade jax jaxlib==0.1.57+cuda101 --find-links https://storage.googleapis.com/jax-releases/jax_releases.html
$ python -m pip list | grep jax
jax              0.2.8
jaxlib           0.1.57+cuda101
$ python jaxlib_issue.py
jax version: 0.2.8
jaxlib version: 0.1.57
joint: [8.881784197001252e-16, 0.6931471805599441]
individual: [8.881784197001252e-16, 0.6931471805599441]

joint: [-6.087976994571852, -4.884004190245918]
individual: [-6.087976994571852, -4.884004190245918]
$ python -m pip install --quiet --upgrade jax jaxlib==0.1.58+cuda101 --find-links https://storage.googleapis.com/jax-releases/jax_releases.html
$ python -m pip list | grep jax
jax              0.2.8
jaxlib           0.1.58+cuda101
$ python jaxlib_issue.py
jax version: 0.2.8
jaxlib version: 0.1.58
joint: [8.881784197001252e-16, 0.6931471805599441]
individual: [8.881784197001252e-16, 0.6931471805599441]

joint: [-6.087976994571852, -4.884004190245918]
individual: [-6.087976994571852, -4.884004190245918]
```
So this seems to be a CPU-only issue.
We do realize that this difference is incredibly small, but as it is a change in behavior that we didn't expect, we thought we'd still report it even if this gets a "won't fix" label.
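In the meantime, one possible mitigation on our side (a sketch, not pyhf's actual test code; `assert_allclose_list` is a hypothetical stdlib stand-in for something like `numpy.testing.assert_allclose`) is to compare with a floating-point tolerance instead of exact equality:

```python
import math


def assert_allclose_list(xs, ys, rel_tol=1e-9, abs_tol=0.0):
    """Assert two lists of floats are element-wise close (stdlib only)."""
    assert len(xs) == len(ys), (len(xs), len(ys))
    for x, y in zip(xs, ys):
        assert math.isclose(x, y, rel_tol=rel_tol, abs_tol=abs_tol), (x, y)


# The values from the failing jaxlib 0.1.58 CPU run above now compare equal
# up to a relative tolerance of 1e-9.
assert_allclose_list(
    [8.881784197001252e-16, 0.693147180559945],
    [8.881784197001252e-16, 0.6931471805599432],
)
```

This keeps the test robust to last-bit differences between batched and element-wise kernels, at the cost of no longer detecting bit-exact regressions.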
cc @lukasheinrich @kratsg