
@oliverdutton
Created April 23, 2024 21:29
Reproduce NaN behaviour in unguarded indexing
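```python
import jax
from jax import numpy as jnp
# Needs an AlphaFold build whose `modules.Attention` defines `flash_kernel`.
from alphafold.model import model

key = jax.random.PRNGKey(42)
nrepeats = 100

# Wrap the kernel once, outside the loops; its tuning arguments are static.
f = jax.jit(
    model.modules.Attention.flash_kernel,
    static_argnames=('return_residual', 'block_q', 'block_k', 'num_warps',
                     'num_stages', 'grid', 'interpret', 'debug'),
)

for nres in range(128, 256):  # sequence lengths 128..255
    print(nres)
    # The key never changes, so q, k, v are identical on every repeat and any
    # non-finite output must come from the kernel, not from the inputs.
    q, k, v = jax.random.uniform(key, (3, 1024, nres, 8, 32))
    for i in range(nrepeats):
        assert jnp.isfinite(f(q, k, v)).all(), f"Failed with {nres} on run {i}"
```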
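The script above detects the symptom only. The following is a hedged, self-contained NumPy model of the suspected mechanism; it is not the AlphaFold kernel. Keys and values are processed in tiles of a hypothetical `block_k` (64 here, not taken from the kernel) with an online softmax, and when the sequence length is not a multiple of `block_k` the last tile reads past the end of the array. What an unguarded load finds there is modelled as NaN, which is an assumption: real out-of-bounds memory can hold anything, so the actual failure may well be intermittent, which would fit the script's repeated runs. The guarded variant zeroes out-of-range rows and sets their scores to `-inf`, mimicking a bounds-checked (masked) load.

```python
import numpy as np


def padded_read(x, start, block):
    """Read rows x[start:start+block]; rows past the end are NaN, standing in
    for whatever an unguarded load would find beyond the buffer (assumption)."""
    out = np.full((block,) + x.shape[1:], np.nan, dtype=x.dtype)
    valid = min(block, x.shape[0] - start)
    out[:valid] = x[start:start + valid]
    return out, np.arange(block) < valid  # tile data and in-bounds mask


def blocked_attention(q, k, v, block_k, guarded):
    """Single-head attention computed one key tile at a time (online softmax)."""
    n, d = k.shape
    m = np.full(q.shape[0], -np.inf)           # running row maximum
    l = np.zeros(q.shape[0])                   # running softmax denominator
    acc = np.zeros((q.shape[0], v.shape[1]))   # running weighted sum of values
    for start in range(0, n, block_k):         # ceil(n / block_k) tiles
        kb, mask = padded_read(k, start, block_k)
        vb, _ = padded_read(v, start, block_k)
        if guarded:
            # Masked load: out-of-range rows are replaced by 0 ...
            kb = np.where(mask[:, None], kb, 0.0)
            vb = np.where(mask[:, None], vb, 0.0)
        s = q @ kb.T / np.sqrt(d)
        if guarded:
            # ... and their scores by -inf, so they receive zero weight.
            s = np.where(mask[None, :], s, -np.inf)
        m_new = np.maximum(m, s.max(axis=1))   # a NaN score makes this NaN
        p = np.exp(s - m_new[:, None])
        scale = np.exp(m - m_new)
        l = l * scale + p.sum(axis=1)
        acc = acc * scale[:, None] + p @ vb
        m = m_new
    return acc / l[:, None]


def reference_attention(q, k, v):
    """Dense softmax attention, used to check the guarded blocked version."""
    s = q @ k.T / np.sqrt(k.shape[1])
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v


rng = np.random.default_rng(0)
d, block_k = 32, 64

# 128 is a multiple of block_k: no tile runs past the end, so even the
# unguarded version stays finite.
q, k, v = rng.uniform(size=(3, 128, d))
print(np.isfinite(blocked_attention(q, k, v, block_k, guarded=False)).all())  # True

# 130 is not: the third tile holds 2 real rows and 62 out-of-bounds rows.
q, k, v = rng.uniform(size=(3, 130, d))
print(np.isfinite(blocked_attention(q, k, v, block_k, guarded=False)).all())  # False
print(np.allclose(blocked_attention(q, k, v, block_k, guarded=True),
                  reference_attention(q, k, v)))  # True
```

A single NaN score in a row is enough to turn that row's running maximum, and therefore its whole output, into NaN, which is why one bad tile at the end of the sequence corrupts the entire result rather than just the padded positions.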