@lukasbm
Created December 12, 2023 15:31
Re-usable setup function for PyTorch (most common performance and debugging flags)
from typing import Optional
import random

import numpy as np
import torch


def setup_pytorch(seed: Optional[int] = None) -> torch.device:
    """
    Sets up PyTorch with common performance optimizations and debugging flags.
    Copied from mad-project.
    """
    torch.backends.cudnn.enabled = True
    # Set seeds for reproducibility.
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.random.manual_seed(seed)  # redundant with manual_seed, kept for clarity
        torch.cuda.manual_seed_all(seed)
        # Some cuDNN operations/solvers are not deterministic even with a fixed seed.
        # Force usage of deterministic implementations when a seed is set.
        # Note: use_deterministic_algorithms(True) may also require setting the
        # CUBLAS_WORKSPACE_CONFIG=:4096:8 environment variable on CUDA >= 10.2.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        torch.use_deterministic_algorithms(True)
    else:
        # benchmark=True lets cuDNN auto-tune and pick the fastest algorithm,
        # which pays off especially when input sizes do not change.
        # https://discuss.pytorch.org/t/what-is-the-differenc-between-cudnn-deterministic-and-cudnn-benchmark/38054/2
        torch.backends.cudnn.benchmark = True
    # Allow reduced-precision (TF32) matmuls so tensor cores can be used;
    # regular 32-bit float matmuls cannot run on tensor cores.
    torch.set_float32_matmul_precision("medium")
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    # Use oneDNN graph fusion with TorchScript for inference.
    torch.jit.enable_onednn_fusion(True)

    cuda_avail = torch.cuda.is_available()
    print("backend info ====")
    print("cuda avail:", cuda_avail)
    print("cuda bfloat16 avail:", cuda_avail and torch.cuda.is_bf16_supported())
    print("cuda is initialized:", cuda_avail and torch.cuda.is_initialized())
    print("torch.backends.cudnn enabled:", torch.backends.cudnn.enabled)
    print("========")
    return torch.device("cuda" if cuda_avail else "cpu")
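
A minimal usage sketch: call setup_pytorch once at program start and move modules and tensors to the returned device. The toy linear model and tensor shapes below are illustrative and not part of the gist.

device = setup_pytorch(seed=None)  # pass an int seed instead for a deterministic run

model = torch.nn.Linear(16, 4).to(device)
x = torch.randn(8, 16, device=device)
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([8, 4])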