Twitter thread: https://twitter.com/theshawwn/status/1456925974919004165
Hacker News thread: https://news.ycombinator.com/item?id=29128998
November 6, 2021
How does JAX allocate memory on a TPU?
jax.device_put(1) is deceptively simple to write in JAX. But on a TPU, what actually happens? How does a tensor containing the value
1 actually get onto a TPU?
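
For concreteness, here is a minimal sketch of the call in question. It uses device_put as exposed on the top-level jax module and assumes nothing about the host: on a TPU VM the first device should be a TPU core, and on any other machine the same code falls back to whatever backend JAX finds. The variable names are illustrative only.

```python
import jax

# Pick the first device JAX can see: a TPU core on a TPU host,
# otherwise whatever backend is available (CPU or GPU).
device = jax.devices()[0]

# Copy the Python scalar 1 into a buffer that lives on that device.
x = jax.device_put(1, device)

print(device.platform)   # backend name, e.g. "tpu" on a TPU host
print(x.shape, x.dtype)  # a scalar, so the shape is empty
```

Everything interesting happens inside that one device_put line, which is why it is worth tracing.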
Turns out, the answer is "C++", and a lot of it.