
@ayaka14732
Last active January 16, 2024 07:48
Track TPU memory usage while running the training script. See https://twitter.com/ayaka14732/status/1565016471323156481 for more details.
# monitor the memory profile with `watch --color -n1 go tool pprof -tags /dev/shm/memory.prof`
import functools
import jax
import jax.numpy as np
import random
import threading
devices = jax.devices()
n_devices = jax.device_count()
def initialise_memory_tracking():
    # Background thread: once a second, dump the device memory profile to
    # /dev/shm so it can be inspected with pprof while the script is running.
    def inner():
        import posix
        import time
        while True:
            jax.profiler.save_device_memory_profile('/dev/shm/memory.prof.new')
            posix.rename('/dev/shm/memory.prof.new', '/dev/shm/memory.prof')  # atomic replace
            time.sleep(1.)
    thread = threading.Thread(target=inner, daemon=True)
    thread.start()
@functools.partial(jax.pmap, axis_name='n_devices')
def some_heavy_computation(a, b):
    c = np.einsum('abcd,ebcd->ae', a, b)
    d = jax.lax.pmean(c, axis_name='n_devices')
    return d
def main():
    initialise_memory_tracking()
    for i in range(1000):
        print(i)
        # Random leading dimensions so the memory footprint varies from step to step
        x = random.randrange(100, 32000)
        y = random.randrange(100, 32000)
        a = np.zeros((x, 11, 4, 2), dtype=np.float32)
        b = np.zeros((y, 11, 4, 2), dtype=np.float32)
        # Replicate the arrays onto every device before the pmapped call
        a = jax.device_put_replicated(a, devices=devices)
        b = jax.device_put_replicated(b, devices=devices)
        some_heavy_computation(a, b)
    print('Done')

if __name__ == '__main__':
    main()
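As a lighter-weight check alongside the pprof-based approach above (not part of the original script), recent JAX versions also expose per-device memory counters directly. Below is a minimal sketch, assuming a backend and JAX version where Device.memory_stats() is available (it may return None on unsupported backends); the helper name print_memory_usage is only for illustration.

# Minimal sketch, assuming Device.memory_stats() is supported by the backend
import jax

def print_memory_usage():
    for d in jax.local_devices():
        stats = d.memory_stats()  # may be None on backends without support
        if stats:
            used = stats.get('bytes_in_use', 0)
            limit = stats.get('bytes_limit')
            line = f'{d}: {used / 2**30:.2f} GiB in use'
            if limit:
                line += f' / {limit / 2**30:.2f} GiB'
            print(line)

print_memory_usage()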