Skip to content

Instantly share code, notes, and snippets.

@Geson-anko
Created December 20, 2023 08:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Geson-anko/3b11be4cbc0e9e9056bc0f15615f57c0 to your computer and use it in GitHub Desktop.
Save Geson-anko/3b11be4cbc0e9e9056bc0f15615f57c0 to your computer and use it in GitHub Desktop.
デバイス情報を取得できる `nn.Module`
import torch
import torch.nn as nn
import torch.backends.mps
class ModuleWithDevice(nn.Module):
    """An ``nn.Module`` that exposes the device it lives on via ``.device``.

    The device is taken from the module's first parameter. A module that
    has no parameters reports the ``default_device`` given at construction
    time instead. Only parameters are consulted; buffers are not.
    """

    def __init__(self, *args, default_device=torch.device("cpu"), **kwargs) -> None:
        """Initialise the module.

        Args:
            *args: Forwarded unchanged to ``nn.Module.__init__``.
            default_device: Device reported by ``device`` while the module
                has no parameters. Anything accepted by ``torch.device``
                (a ``torch.device``, or a spec such as ``"cpu"``) is allowed.
            **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
        """
        super().__init__(*args, **kwargs)
        # Normalise here so that `device` always returns a real torch.device,
        # even when the caller passed a plain string such as "cpu".
        self._default_device = torch.device(default_device)

    @property
    def device(self) -> torch.device:
        """Return the device of the first parameter, or the default device.

        Returns:
            The ``device`` of the first tensor yielded by
            ``self.parameters()``; the stored default device when the
            module has no parameters.
        """
        first_param = next(iter(self.parameters()), None)
        if first_param is None:
            return self._default_device
        return first_param.device
class Encoder(ModuleWithDevice):
    """Example subclass holding a single linear layer (10 inputs, 20 outputs)."""

    def __init__(self) -> None:
        """Build the encoder and register its linear layer as ``self.layer``."""
        super().__init__()
        self.layer = nn.Linear(in_features=10, out_features=20)
if __name__ == "__main__":
    # Demo: show the reported device before and after moving the module.
    encoder = Encoder()
    print(encoder.device)

    # Pick the target backend in priority order: CUDA, then MPS, then CPU.
    # The MPS check is only evaluated when CUDA is unavailable.
    device = (
        "cuda"
        if torch.cuda.is_available()
        else "mps"
        if torch.backends.mps.is_available()
        else "cpu"
    )

    encoder.to(device)
    print(encoder.device)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment