@fxmarty · Created July 12, 2023 11:21
ONNX Runtime multiprocessing for multi-GPU
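The script below runs ONNX Runtime inference on several GPUs at once by spawning one worker process per device. Each worker owns its own InferenceSession pinned to a distinct device_id via CUDAExecutionProvider, and exchanges inputs and (small) outputs with the parent through multiprocessing queues.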
import torch  # just because of cuDNN/CUDA dependency not met on the cluster
import time
import onnxruntime as ort
import numpy as np
import multiprocessing as mp
from multiprocessing import Queue
class PickableInferenceSession:
    """
    This is a wrapper to make the current InferenceSession class picklable.
    See https://github.com/microsoft/onnxruntime/issues/7846#issuecomment-850217402
    """

    def __init__(self, model_path, device_id: int):
        import onnxruntime as ort

        self.model_path = model_path
        self.device_id = device_id
        self.providers = ["CUDAExecutionProvider"]
        # The session is deliberately NOT created here: it is created in
        # __setstate__, i.e. only once the object has been pickled into the
        # worker process, so that each child loads the model on its own GPU.
        # self.sess = ort.InferenceSession(self.model_path, providers=self.providers, provider_options=[{"device_id": device_id}])

    def run(self, *args):
        # Only valid in the worker process, after __setstate__ has run.
        return self.sess.run(*args)

    def __getstate__(self):
        # Pickle only the lightweight construction arguments, not the session itself.
        return {"model_path": self.model_path, "providers": self.providers, "device_id": self.device_id}

    def __setstate__(self, values):
        import onnxruntime as ort

        self.model_path = values["model_path"]
        self.providers = values["providers"]
        self.device_id = values["device_id"]
        self.sess = ort.InferenceSession(self.model_path, providers=self.providers, provider_options=[{"device_id": self.device_id}])
class RunProcess(mp.Process):
    def __init__(self, model_path, device_id, input_queue, output_queue):
        # The session is still "empty" here; the "spawn" start method pickles
        # this object into the child process, where __setstate__ loads the model.
        self.session = PickableInferenceSession(model_path, device_id)
        self.input_queue = input_queue
        self.output_queue = output_queue
        super().__init__()

    def run(self):
        # Worker loop: block on the input queue, exit on the "stop" sentinel.
        while True:
            inp = self.input_queue.get()
            if inp == "stop":
                return
            res = self.session.run(None, inp)
            self.output_queue.put(res[0].max())  # dummy output; the idea is not to return a data-intensive output
if __name__ == '__main__':
    mp.set_start_method("spawn", force=True)  # see https://pytorch.org/docs/master/notes/multiprocessing.html#cuda-in-multiprocessing

    inp = {
        "input_ids": np.random.randint(0, 5, size=(16, 1000)),
        "attention_mask": np.ones((16, 1000)).astype(np.int64),
    }

    # One input/output queue pair per worker process.
    input_queue1 = Queue()
    input_queue2 = Queue()
    input_queue3 = Queue()
    input_queue4 = Queue()

    output_queue1 = Queue()
    output_queue2 = Queue()
    output_queue3 = Queue()
    output_queue4 = Queue()

    model_path = "/fsx/felix/optimum/gpt2_onnx/decoder_model.onnx"

    # One worker per GPU, each pinned to a distinct device.
    p1 = RunProcess(model_path, device_id=4, input_queue=input_queue1, output_queue=output_queue1)
    p2 = RunProcess(model_path, device_id=5, input_queue=input_queue2, output_queue=output_queue2)
    p3 = RunProcess(model_path, device_id=6, input_queue=input_queue3, output_queue=output_queue3)
    p4 = RunProcess(model_path, device_id=7, input_queue=input_queue4, output_queue=output_queue4)

    p1.start()
    p2.start()
    p3.start()
    p4.start()

    time.sleep(25)  # so the timing below does not include model load time; this is quite hacky

    start = time.time()
    input_queue1.put(inp)
    input_queue2.put(inp)
    input_queue3.put(inp)
    input_queue4.put(inp)

    res1 = output_queue1.get()
    res2 = output_queue2.get()
    res3 = output_queue3.get()
    res4 = output_queue4.get()
    end = time.time()

    print(f"Took {end - start} s")

    # Shut the workers down with the "stop" sentinel.
    input_queue1.put("stop")
    input_queue2.put("stop")
    input_queue3.put("stop")
    input_queue4.put("stop")

    p1.join()
    p2.join()
    p3.join()
    p4.join()
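
As a side note, the four hardcoded queue/process pairs can be collapsed into a loop over device IDs. Below is a minimal sketch under the same assumptions (one queue pair per worker, the string "stop" as shutdown sentinel), meant to replace the body under the __main__ guard; the names device_ids and workers are illustrative, not from the original script:

    device_ids = [4, 5, 6, 7]
    workers = []
    for device_id in device_ids:
        in_queue, out_queue = Queue(), Queue()
        p = RunProcess(model_path, device_id=device_id, input_queue=in_queue, output_queue=out_queue)
        p.start()
        workers.append((p, in_queue, out_queue))

    start = time.time()
    for _, in_queue, _ in workers:
        in_queue.put(inp)
    results = [out_queue.get() for _, _, out_queue in workers]
    print(f"Took {time.time() - start} s")

    for p, in_queue, _ in workers:
        in_queue.put("stop")
        p.join()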