@ezyang
Created February 4, 2026 03:11
Gist: ezyang/53a2151fce613cf4cd6b04581146aebe

Select an option

Save ezyang/53a2151fce613cf4cd6b04581146aebe to your computer and use it in GitHub Desktop.
jax.typeof(x)=ShapedArray(float32[4,8@dp,16]), jax.typeof(w1)=ShapedArray(float32[16,32@tp]), jax.typeof(w3)=ShapedArray(float32[16,32@tp]), jax.typeof(w2)=ShapedArray(float32[32@tp,16])
jax.typeof(rx)=ShapedArray(float32[4,8@dp,16]{R:tp}), jax.typeof(rw1)=ShapedArray(float32[16,32@tp]{R:dp}), jax.typeof(rw3)=ShapedArray(float32[16,32@tp]{R:dp}), jax.typeof(rw2)=ShapedArray(float32[32@tp,16]{R:dp})
jax.typeof(h1)=ShapedArray(float32[4,8@dp,32@tp])
jax.typeof(h3)=ShapedArray(float32[4,8@dp,32@tp])
jax.typeof(h)=ShapedArray(float32[4,8@dp,32@tp])
jax.typeof(out)=ShapedArray(float32[4,8@dp,16]{U:tp})
float32[4,8@dp,16]{U:tp}