Last active June 13, 2023 05:25
A function that picks the device to run PyTorch code on, so the code stays device-agnostic: it uses an accelerator (CUDA or MPS) when one is available and falls back to the CPU otherwise.
import torch


# No separate check for ROCm: PyTorch's ROCm builds expose AMD GPUs through the same `torch.cuda` API.
def get_device():
    """Return the name of the best available device: 'cuda', 'mps', or 'cpu'."""
    if torch.cuda.is_available():
        return 'cuda'
    elif torch.backends.mps.is_available():
        return 'mps'
    else:
        return 'cpu'
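A minimal usage sketch, assuming the function is used to place both a model and its inputs on the selected device. The `torch.nn.Linear` layer, the tensor shapes, and the names `model`, `x`, and `y` are illustrative and not part of the gist. The function is repeated here only so the block runs on its own.

```python
import torch


def get_device():
    # Same selection order as the gist: CUDA (also covers ROCm builds), then MPS, then CPU.
    if torch.cuda.is_available():
        return 'cuda'
    elif torch.backends.mps.is_available():
        return 'mps'
    else:
        return 'cpu'


# Wrap the string in a torch.device so it can be passed to tensors and modules.
device = torch.device(get_device())

# Illustrative model and input; both must live on the same device.
model = torch.nn.Linear(3, 2).to(device)
x = torch.randn(4, 3, device=device)

# The forward pass runs on whichever device was selected.
y = model(x)
print(device.type, tuple(y.shape))
```

Creating tensors directly with `device=device` avoids an extra host-to-device copy compared with building them on the CPU and calling `.to(device)` afterwards.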