Created December 23, 2023 20:31
-
-
Save fakerybakery/3d10d87f37d2dbbdd703f4181cd5dfa2 to your computer and use it in GitHub Desktop.
Support multiple PyTorch GPU backends (NVIDIA CUDA and Apple MPS), with automatic fallback to CPU!
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# License: Unlicense, attribution optional but appreciated
# Copyright: Copyright (c) 2023 mrfakename. All rights reserved. Distributed under the Unlicense license.
# Author: mrfakename
# Date created: Dec 23, 2023
# Published on: GitHub Gist
import torch
def get_device() -> str:
    """Return the name of the best available PyTorch device.

    Backends are tried in this order: CUDA, then Apple's Metal Performance
    Shaders (MPS), then CPU as the fallback that always works.

    Returns:
        One of ``'cuda'``, ``'mps'`` or ``'cpu'``. The value can be passed
        directly to ``torch.device``, ``Tensor.to`` or ``Module.to``.
    """
    if torch.cuda.is_available():
        return 'cuda'
    # torch.backends.mps does not exist in older PyTorch releases (it was added
    # in 1.12). Look it up defensively so those installs fall back to CPU
    # instead of raising AttributeError.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return 'mps'
    return 'cpu'
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# License: Unlicense, attribution optional but appreciated
# Copyright: Copyright (c) 2023 mrfakename. All rights reserved. Distributed under the Unlicense license.
# Author: mrfakename
# Date created: Dec 23, 2023
# Published on: GitHub Gist
from device_picker import get_device

# ... create or load `model` here (e.g. a torch.nn.Module or a torch.Tensor) ... #

# Pick the device once so the same value can also be used for input tensors.
device = get_device()
# Keep the return value. nn.Module.to() moves the module in place and returns
# self, but Tensor.to() returns a new tensor and leaves the original untouched.
# Reassigning works for both.
model = model.to(device)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment