Skip to content

Instantly share code, notes, and snippets.

@matthewfeickert
Last active January 13, 2021 00:24
Show Gist options
  • Save matthewfeickert/b7107203549916f013235934fac5e255 to your computer and use it in GitHub Desktop.
Save matthewfeickert/b7107203549916f013235934fac5e255 to your computer and use it in GitHub Desktop.
reproducible issue with jaxlib v0.1.58

Regression in jaxlib v0.1.58 for jax.scipy.special.gammaln on CPU

In pyhf we've noticed that one of our unit tests, which was passing with jax v0.2.7 and jaxlib v0.1.57, started failing with the release of jaxlib v0.1.58. We've narrowed it down to jaxlib v0.1.58 with jax_enable_x64=True in CPU mode (i.e., where our unit tests run).

Minimal Reproducible Example

In a fresh Python 3.8 virtual environment

$ python --version --version
Python 3.8.6 (default, Jan  5 2021, 00:14:15)
[GCC 9.3.0]
$ python -m pip install --quiet --upgrade pip setuptools wheel
$ python -m pip install jax jaxlib
$ python -m pip list
Package     Version
----------- -------
absl-py     0.11.0
flatbuffers 1.12
jax         0.2.8
jaxlib      0.1.58
numpy       1.19.5
opt-einsum  3.3.0
pip         20.3.3
scipy       1.6.0
setuptools  51.1.2
six         1.15.0
wheel       0.36.2

then for

# jaxlib_issue.py
import jax
import jaxlib
from jax.config import config
import jax.numpy as jnp
from jax.scipy.special import gammaln


class Poisson:
    """Poisson distribution parameterised by its rate, evaluated with jax.numpy.

    The rate and the observed counts are both requested as float64. JAX only
    honours that dtype when its ``jax_enable_x64`` flag is switched on.
    """

    def __init__(self, rate):
        """Keep *rate* as a float64 jax array."""
        self.rate = jnp.asarray(rate, dtype="float64")

    def log_prob(self, n):
        """Return log P(n | rate) = n * log(rate) - rate - log(n!).

        ``log(n!)`` is evaluated as ``gammaln(n + 1)``. *n* is cast to
        float64 and combined elementwise (with broadcasting) with the rate.
        """
        counts = jnp.asarray(n, dtype="float64")
        mu = self.rate
        log_factorial = gammaln(counts + 1.0)
        # Same left-to-right association as the textbook formula, so the
        # rounding is identical: ((counts * log(mu)) - mu) - log_factorial.
        return counts * jnp.log(mu) - mu - log_factorial


def main() -> None:
    """Reproduce the jaxlib v0.1.58 CPU ``gammaln`` discrepancy.

    Enables 64-bit mode and prints the jax and jaxlib versions. It then checks
    that evaluating ``gammaln`` on a length-2 array gives exactly the same
    values as evaluating each element on its own. The check is done directly
    and again inside ``Poisson.log_prob``. Raises AssertionError on the first
    mismatch.
    """
    # Enable x64 before any array is created so that the float64 dtype
    # requests below produce genuine double-precision arrays.
    config.update("jax_enable_x64", True)
    print(f"jax version: {jax.__version__}")
    print(f"jaxlib version: {jaxlib.__version__}")

    # Batched evaluation vs. one-element-at-a-time evaluation of gammaln.
    joint = gammaln(jnp.asarray([2.0, 3.0], dtype="float64")).tolist()
    individual = [
        *gammaln(jnp.asarray([2.0], dtype="float64")).tolist(),
        *gammaln(jnp.asarray([3.0], dtype="float64")).tolist(),
    ]
    print(f"joint:      {joint}")
    print(f"individual: {individual}")
    # Exact float equality is intentional: the report is about the batched
    # and per-element results no longer being bit-for-bit identical.
    # NOTE: ``assert`` is stripped under ``python -O``; run without -O.
    assert joint == individual

    # This is more akin to what we're seeing
    # (gammaln used inside a Poisson log-probability, as in pyhf's unit test).
    joint = Poisson([10.0, 10.0]).log_prob([2.0, 3.0])

    poisson_1 = Poisson([10.0]).log_prob(2.0)
    poisson_2 = Poisson([10.0]).log_prob(3.0)

    print(f"\njoint:      {joint.tolist()}")
    print(f"individual: {[*poisson_1.tolist(), *poisson_2.tolist()]}")
    assert joint.tolist() == [*poisson_1.tolist(), *poisson_2.tolist()]


# Run the reproducer only when this file is executed as a script.
if __name__ == "__main__":
    main()
$ python jaxlib_issue.py
jax version: 0.2.8
jaxlib version: 0.1.58
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
joint:      [8.881784197001252e-16, 0.693147180559945]
individual: [8.881784197001252e-16, 0.6931471805599432]
Traceback (most recent call last):
  File "jaxlib_issue.py", line 43, in <module>
    main()
  File "jaxlib_issue.py", line 29, in main
    assert joint == individual
AssertionError

however for

$ python -m pip install --quiet --upgrade "jaxlib<0.1.58"
$ python -m pip list | grep jax
jax               0.2.8
jaxlib            0.1.57

things are passing as before

$ python jaxlib_issue.py
jax version: 0.2.8
jaxlib version: 0.1.57
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
joint:      [8.881784197001252e-16, 0.693147180559945]
individual: [8.881784197001252e-16, 0.693147180559945]

joint:      [-6.087976994571854, -4.884004190245918]
individual: [-6.087976994571854, -4.884004190245918]

If the CUDA-enabled version of jaxlib is installed, the issue is not seen:

$ python -m pip install --quiet --upgrade jax jaxlib==0.1.57+cuda101 --find-links https://storage.googleapis.com/jax-releases/jax_releases.html
$ python -m pip list | grep jax
jax               0.2.8
jaxlib            0.1.57+cuda101
$ python jaxlib_issue.py
jax version: 0.2.8
jaxlib version: 0.1.57
joint:      [8.881784197001252e-16, 0.6931471805599441]
individual: [8.881784197001252e-16, 0.6931471805599441]

joint:      [-6.087976994571852, -4.884004190245918]
individual: [-6.087976994571852, -4.884004190245918]
$ python -m pip install --quiet --upgrade jax jaxlib==0.1.58+cuda101 --find-links https://storage.googleapis.com/jax-releases/jax_releases.html
$ python -m pip list | grep jax
jax               0.2.8
jaxlib            0.1.58+cuda101
$ python jaxlib_issue.py
jax version: 0.2.8
jaxlib version: 0.1.58
joint:      [8.881784197001252e-16, 0.6931471805599441]
individual: [8.881784197001252e-16, 0.6931471805599441]

joint:      [-6.087976994571852, -4.884004190245918]
individual: [-6.087976994571852, -4.884004190245918]

so this seems to be a CPU-only issue.

We do realize that this difference is incredibly small, but as this is a change in behavior that we didn't expect, we thought we'd still report it even if it gets a "won't fix" label.

cc @lukasheinrich @kratsg

import jax
import jaxlib
from jax.config import config
import jax.numpy as jnp
from jax.scipy.special import gammaln
class Poisson:
    """Poisson distribution whose rate and counts are held as float64 jax arrays."""

    def __init__(self, rate):
        """Store *rate* as a float64 jax array."""
        self.rate = jnp.asarray(rate, dtype="float64")

    def log_prob(self, n):
        """Return the Poisson log-probability of the count(s) *n*.

        Computes ``n * log(rate) - rate - gammaln(n + 1)``, where
        ``gammaln(n + 1) == log(n!)``. *n* is cast to float64 and combined
        elementwise (with broadcasting) with ``self.rate``.
        """
        n = jnp.asarray(n, dtype="float64")
        return n * jnp.log(self.rate) - self.rate - gammaln(n + 1.0)


def main():
    """Reproduce the jaxlib v0.1.58 CPU ``gammaln`` discrepancy.

    Enables 64-bit mode and prints the jax and jaxlib versions. It then checks
    that evaluating ``gammaln`` on a length-2 array gives exactly the same
    values as evaluating each element on its own. The check is done directly
    and again inside ``Poisson.log_prob``. Raises AssertionError on the first
    mismatch.
    """
    # Enable x64 before any array is created so that the float64 dtype
    # requests below produce genuine double-precision arrays.
    config.update("jax_enable_x64", True)
    print(f"jax version: {jax.__version__}")
    print(f"jaxlib version: {jaxlib.__version__}")

    # Batched evaluation vs. one-element-at-a-time evaluation of gammaln.
    joint = gammaln(jnp.asarray([2.0, 3.0], dtype="float64")).tolist()
    individual = [
        *gammaln(jnp.asarray([2.0], dtype="float64")).tolist(),
        *gammaln(jnp.asarray([3.0], dtype="float64")).tolist(),
    ]
    print(f"joint:      {joint}")
    print(f"individual: {individual}")
    # Exact float equality is intentional: the report is about the batched
    # and per-element results no longer being bit-for-bit identical.
    # NOTE: ``assert`` is stripped under ``python -O``; run without -O.
    assert joint == individual

    # This is more akin to what we're seeing
    joint = Poisson([10.0, 10.0]).log_prob([2.0, 3.0])

    poisson_1 = Poisson([10.0]).log_prob(2.0)
    poisson_2 = Poisson([10.0]).log_prob(3.0)

    print(f"\njoint:      {joint.tolist()}")
    print(f"individual: {[*poisson_1.tolist(), *poisson_2.tolist()]}")
    assert joint.tolist() == [*poisson_1.tolist(), *poisson_2.tolist()]


# Run the reproducer only when this file is executed as a script.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment