Skip to content

Instantly share code, notes, and snippets.

@nqbao
Created August 27, 2017 17:58
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save nqbao/5fd9c0aaba7268483b8a84ecffccc7ed to your computer and use it in GitHub Desktop.
Save nqbao/5fd9c0aaba7268483b8a84ecffccc7ed to your computer and use it in GitHub Desktop.
gpu_allocation.py
import os
import logging
import gpustat
# Module-level logger; auto_acquire_gpus reports the GPU indices it claims here.
logger = logging.getLogger(__name__)
def get_free_gpu_indices():
    """Return the indices of GPUs that have no running processes.

    Queries the current GPU state through ``gpustat`` and keeps every GPU
    whose UUID does not appear among the ``gpu_uuid`` values of the running
    processes. Memory usage and utilization are not considered: a GPU counts
    as free only when no process at all is reported on it.

    Returns:
        A list of ``gpu.entry['index']`` values, in the order gpustat lists
        the GPUs. NOTE(review): presumably integer nvidia-smi indices --
        confirm against the installed gpustat version.
    """
    stats = gpustat.GPUStatCollection.new_query()
    # UUIDs of GPUs that currently host at least one process.
    busy_uuids = {proc['gpu_uuid'] for proc in stats.running_processes()}
    return [gpu.entry['index'] for gpu in stats if gpu.uuid not in busy_uuids]
def _acquire_gpus():
    """Run a trivial TensorFlow session so TensorFlow initializes the GPUs.

    Called by ``auto_acquire_gpus`` right after it sets
    ``CUDA_VISIBLE_DEVICES``. NOTE(review): the intent appears to be that
    TensorFlow's device initialization claims the selected GPUs for this
    process, so other processes then see them as in use -- confirm against
    the TensorFlow version in use.

    Uses the TF1 graph/session API (``tf.Graph``, ``tf.Session``); it will
    not run unmodified on TensorFlow 2.x, where ``tf.Session`` only exists
    as ``tf.compat.v1.Session``.

    Raises:
        ImportError: if TensorFlow is not installed.
        Any error TensorFlow raises while creating the session or running
        the op propagates unchanged.
    """
    # Imported inside the function so importing this module does not
    # require TensorFlow.
    import tensorflow as tf

    graph = tf.Graph()
    # Build the op explicitly in `graph` so it is unambiguously runnable by a
    # session bound to that graph.
    with graph.as_default():
        one = tf.constant(1)
    with tf.Session(graph=graph) as session:
        session.run(one)
def auto_acquire_gpus(num_gpus=1):
    """Select ``num_gpus`` idle GPUs, expose only them to CUDA, and claim them.

    If ``CUDA_VISIBLE_DEVICES`` is already set in the environment, nothing
    is selected: the existing value is returned unchanged and ``num_gpus``
    is ignored.

    Otherwise the first ``num_gpus`` GPUs reported free by
    ``get_free_gpu_indices`` are chosen. The function then sets
    ``CUDA_DEVICE_ORDER`` to ``PCI_BUS_ID`` and ``CUDA_VISIBLE_DEVICES`` to
    the comma-separated indices, and calls ``_acquire_gpus``. Selection and
    acquisition are not atomic: another process may claim the same GPU in
    between.

    Args:
        num_gpus: Number of GPUs to acquire. ``0`` selects none and sets
            ``CUDA_VISIBLE_DEVICES`` to an empty string.

    Returns:
        The value of ``CUDA_VISIBLE_DEVICES`` (comma-separated GPU indices).

    Raises:
        ValueError: if ``num_gpus`` is negative.
        RuntimeError: if fewer than ``num_gpus`` GPUs are currently free.
    """
    existing = os.environ.get('CUDA_VISIBLE_DEVICES')
    if existing is not None:
        # Respect an explicit user/launcher choice.
        return existing

    if num_gpus < 0:
        # A negative count would otherwise slice free_gpus[:-1] ("all but
        # the last") instead of failing.
        raise ValueError('num_gpus must be non-negative, got %r' % (num_gpus,))

    free_gpus = get_free_gpu_indices()
    if num_gpus > len(free_gpus):
        raise RuntimeError("Unable to acquire %d GPUs, there are only %d available." % (
            num_gpus,
            len(free_gpus)
        ))

    gpus = ','.join(map(str, free_gpus[:num_gpus]))
    # Make CUDA number devices by PCI bus ID (its default is FASTEST_FIRST),
    # presumably so the indices match those gpustat/nvidia-smi report.
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = gpus
    logger.info('Acquiring GPUs: %s', gpus)
    _acquire_gpus()
    return os.environ['CUDA_VISIBLE_DEVICES']
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment