Skip to content

Instantly share code, notes, and snippets.

@Oil3
Created December 9, 2023 00:07
Show Gist options
  • Save Oil3/92a93d36f00cc76a8b18c99783bc7c9c to your computer and use it in GitHub Desktop.
Save Oil3/92a93d36f00cc76a8b18c99783bc7c9c to your computer and use it in GitHub Desktop.
PyTorch MPS-aware device-selection helpers, adapted from CodeFormer
def parse_version_triplet(version):
    """Return the leading ``(major, minor, patch)`` integers of a version string.

    Only the first three dot-separated numeric components are read. Anything
    after them is ignored, including pre-release tags (``a0``, ``.dev2023…``)
    and local labels (``+cu118``, ``+git1234``, ``+44dac51``,
    ``+cpu.cxx11.abi``). If the string does not begin with ``X.Y.Z``, this
    returns ``(0, 0, 0)`` so that version comparisons fail closed instead of
    raising.
    """
    match = re.match(r"([0-9]+)\.([0-9]+)\.([0-9]+)", version)
    if match is None:
        return (0, 0, 0)
    return tuple(int(part) for part in match.groups())


# True when the installed torch is >= 1.12.0. It gates the
# torch.backends.mps checks used by the device helpers in this file.
IS_HIGH_VERSION = parse_version_triplet(torch.__version__) >= (1, 12, 0)
def gpu_is_available():
    """Return True if PyTorch can use a GPU backend.

    When ``IS_HIGH_VERSION`` is true (torch >= 1.12.0), an available MPS
    backend counts as a GPU. Otherwise, CUDA counts only when cuDNN is also
    available.

    Returns:
        bool: whether a usable GPU backend was found.
    """
    # `and` short-circuits, so torch.backends.mps is only probed when
    # IS_HIGH_VERSION is true, exactly as with the original nested `if`.
    if IS_HIGH_VERSION and torch.backends.mps.is_available():
        return True
    return bool(torch.cuda.is_available() and torch.backends.cudnn.is_available())
def get_device(gpu_id=None):
if gpu_id is None:
gpu_str = ''
elif isinstance(gpu_id, int):
gpu_str = f':{gpu_id}'
else:
raise TypeError('Input should be int value.')
if IS_HIGH_VERSION:
if torch.backends.mps.is_available():
return torch.device('mps'+gpu_str)
return torch.device('cuda'+gpu_str if torch.c
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment