Modern way to auto-enable TF32 in PyTorch
import logging

import torch


def check_ampere_gpu():
    """
    Check if the GPU supports NVIDIA Ampere or later and enable TF32 in PyTorch if it does.
    """
    # Check if CUDA is available
    if not torch.cuda.is_available():
        logging.info("No GPU detected, running on CPU.")
        return

    try:
        # Get the compute capability of the GPU
        device = torch.cuda.current_device()
        capability = torch.cuda.get_device_capability(device)
        major, minor = capability

        # Check if the GPU is Ampere or newer (compute capability >= 8.0)
        if major >= 8:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            gpu_name = torch.cuda.get_device_name(device)
            print(
                f"{gpu_name} (compute capability {major}.{minor}) supports NVIDIA Ampere or later, "
                "enabled TF32 in PyTorch."
            )
        else:
            gpu_name = torch.cuda.get_device_name(device)
            print(
                f"{gpu_name} (compute capability {major}.{minor}) does not support NVIDIA Ampere or later."
            )

    except Exception as e:
        logging.warning(f"Error occurred while checking GPU: {e}")


check_ampere_gpu()
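
A related sketch, not part of the original gist: on recent PyTorch releases (roughly 1.12 and later, an assumption here), the higher-level torch.set_float32_matmul_precision API can enable TF32 matmuls without setting the backend flag directly; the cuDNN flag still has to be set separately.

import torch

# Sketch assuming PyTorch >= 1.12: enable TF32 only on Ampere-or-newer GPUs
# (compute capability >= 8.0), mirroring the check in check_ampere_gpu().
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    torch.set_float32_matmul_precision("high")  # allow TF32 in float32 matmuls
    torch.backends.cudnn.allow_tf32 = True      # allow TF32 in cuDNN convolutions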