Created
December 12, 2023 15:31
-
-
Save lukasbm/5c0975494c37dde6b48261fecaac0bb7 to your computer and use it in GitHub Desktop.
Re-usable setup function for PyTorch (most common performance and debugging flags)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def setup_pytorch(seed: Optional[int] = None) -> torch.device:
    """Configure global PyTorch flags for performance or reproducibility.

    When *seed* is given, all relevant RNGs (``random``, NumPy, PyTorch
    CPU + CUDA) are seeded and deterministic algorithm implementations are
    forced; this trades speed for repeatable results. When *seed* is
    ``None``, cuDNN autotuning (``benchmark``) is enabled instead, which is
    fastest when input shapes do not change between iterations.

    In both modes TF32 tensor-core math is enabled for matmuls/convolutions
    and oneDNN graph fusion is turned on for TorchScript inference.

    Parameters
    ----------
    seed : Optional[int]
        Seed for all random number generators, or ``None`` for
        performance mode (non-deterministic).

    Returns
    -------
    torch.device
        ``cuda`` if a CUDA device is available, else ``cpu``.
    """
    import os  # local import: only needed for the determinism env var below

    torch.backends.cudnn.enabled = True
    if seed is not None:
        # Seed every RNG we might touch. Note: torch.manual_seed() already
        # seeds the CPU generator AND all CUDA devices, so separate
        # torch.random.manual_seed / torch.cuda.manual_seed_all calls
        # would be redundant.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        # Some cuDNN operations/solvers are not deterministic even with a
        # fixed seed — force deterministic implementations instead.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        # Required for deterministic cuBLAS on CUDA >= 10.2; without it,
        # use_deterministic_algorithms(True) raises at the first GPU matmul.
        # setdefault() respects a value the user already exported.
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
        torch.use_deterministic_algorithms(True)
    else:
        # benchmark mode picks the fastest cuDNN algorithm via a heuristic;
        # especially useful when input sizes do not change.
        # https://discuss.pytorch.org/t/what-is-the-differenc-between-cudnn-deterministic-and-cudnn-benchmark/38054/2
        torch.backends.cudnn.benchmark = True
    # Allow TF32 so float32 matmuls/convolutions can run on tensor cores
    # (plain fp32 cannot); "medium" permits bf16/TF32-level precision.
    torch.set_float32_matmul_precision("medium")
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    # Use oneDNN graph fusion with TorchScript for faster CPU inference.
    torch.jit.enable_onednn_fusion(True)
    cuda_avail = torch.cuda.is_available()
    print("backend info ====")
    print("cuda avail:", cuda_avail)
    print("cuda bfloat16 avail:", cuda_avail and torch.cuda.is_bf16_supported())
    print("cuda is initialized:", cuda_avail and torch.cuda.is_initialized())
    print("torch.backends.cudnn enabled:", torch.backends.cudnn.enabled)
    print("========")
    return torch.device("cuda" if cuda_avail else "cpu")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment