@priyathamkat (last active February 14, 2024)
Multi GPU inference using `torch.multiprocessing`
import torch
import torch.multiprocessing as mp
from absl import app, flags
from torchvision.models import AlexNet

FLAGS = flags.FLAGS
flags.DEFINE_integer("num_processes", 2, "Number of subprocesses to use")


def infer(rank, queue):
    """Each subprocess runs this function on a different GPU, indicated by `rank`."""
    model = AlexNet()  # randomly initialized here; a stand-in for any real model
    model.eval()  # disable dropout and use running batch-norm statistics
    device = torch.device(f"cuda:{rank}")
    model.to(device)
    while True:
        x = queue.get()
        if x is None:  # check for the sentinel value
            break
        x = x.to(device)
        with torch.no_grad():  # inference only, so skip building the autograd graph
            model(x)
        del x  # free memory
        print(f"Inference on process {rank}")


def main(argv):
    # CUDA cannot be used in subprocesses created with the default "fork"
    # start method; "spawn" (or "forkserver") is required.
    mp.set_start_method("spawn")
    queue = mp.Queue()
    processes = []
    for rank in range(FLAGS.num_processes):
        p = mp.Process(target=infer, args=(rank, queue))
        p.start()
        processes.append(p)
    for _ in range(10):
        queue.put(torch.randn(1, 3, 224, 224))
    for _ in range(FLAGS.num_processes):
        queue.put(None)  # sentinel value telling each subprocess to exit
    for p in processes:
        p.join()  # wait for all subprocesses to finish


if __name__ == "__main__":
    app.run(main)
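
Run it with, for example, `python multi_gpu_inference.py --num_processes=2` (the filename is an assumption; save the gist under any name and absl parses the flag).

The gist discards the model outputs. As a minimal sketch of one way to collect them, the hypothetical worker below (not part of the gist) ships predictions back to the parent through a second `mp.Queue`. Outputs are moved to the CPU before being enqueued, since passing CUDA tensors between processes requires the producing process to keep them alive until the consumer has received them.

import torch
import torch.multiprocessing as mp
from torchvision.models import AlexNet


def infer_with_results(rank, task_queue, result_queue):
    """Hypothetical variant of `infer` that returns predictions to the parent."""
    model = AlexNet().eval()
    device = torch.device(f"cuda:{rank}")
    model.to(device)
    while True:
        x = task_queue.get()
        if x is None:  # sentinel: no more work
            break
        with torch.no_grad():
            y = model(x.to(device))
        # Move outputs to CPU before enqueueing so the child does not have to
        # keep CUDA memory alive for the parent.
        result_queue.put(y.cpu())


def main():
    mp.set_start_method("spawn", force=True)
    num_processes, num_batches = 2, 10
    task_queue, result_queue = mp.Queue(), mp.Queue()
    workers = [
        mp.Process(target=infer_with_results, args=(rank, task_queue, result_queue))
        for rank in range(num_processes)
    ]
    for w in workers:
        w.start()
    for _ in range(num_batches):
        task_queue.put(torch.randn(1, 3, 224, 224))
    for _ in range(num_processes):
        task_queue.put(None)
    # Drain exactly as many results as batches were submitted.
    outputs = [result_queue.get() for _ in range(num_batches)]
    for w in workers:
        w.join()
    print(f"Collected {len(outputs)} predictions of shape {outputs[0].shape}")


if __name__ == "__main__":
    main()

Note that the parent drains the result queue before joining the workers: a process that has put items on a queue will not exit until those items are consumed, so joining first can deadlock.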