More flexible context managers than `torch.random.fork_rng()` that preserve the state of PyTorch's random number generators for the desired devices.
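For reference, the built-in `torch.random.fork_rng()` restores the selected RNG states on exit but leaves seeding to the caller, and `torch.manual_seed()` can only seed every device at once. A minimal sketch of that baseline behavior (assuming at least one CUDA device is visible):

import torch

with torch.random.fork_rng(devices=range(torch.cuda.device_count())):
    torch.manual_seed(0)         # seeds the CPU and every CUDA device together
    print(torch.rand(1).item())  # deterministic inside the fork
print(torch.rand(1).item())      # the original RNG stream resumes here

The two context managers below add per-device seeding and delayed reuse on top of this behavior.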
import torch


class RNG:
    """Preserve the state of the random number generators of torch.

    https://gist.github.com/ModarTensai/2328b13bdb11c6309ba449195a6b551a

    Inspired by torch.random.fork_rng().

    Seeding random number generators (RNGs):
        - (PyTorch) torch.manual_seed(seed)
        - (NumPy) numpy.random.seed(seed)
        - (Python) random.seed(seed)

    Example:
        seed = 0
        torch.manual_seed(seed)
        with RNG(seed, devices=['cpu', 'cuda:0']):
            print(torch.rand(1, device='cpu').item())
            print(torch.rand(1, device='cuda:0').item())
            print(torch.rand(1, device='cuda:1').item())
        print('........')
        print(torch.rand(1, device='cpu').item())
        print(torch.rand(1, device='cuda:0').item())
        print(torch.rand(1, device='cuda:1').item())

        # Outputs:
        # 0.49625658988952637
        # 0.08403993397951126
        # 0.08403993397951126
        # ........
        # 0.49625658988952637
        # 0.08403993397951126
        # 0.41885504126548767
    """

    def __init__(self, seed=None, devices=None):
        """Seed the random number generators of the given torch devices.

        Args:
            seed: The initial seed value or a list of per-device values.
                If None, don't seed.
            devices: List of devices to seed. If None, seed all devices.
        """
        if devices is None:
            num_gpus = torch.cuda.device_count()
            devices = ['cpu'] + [f'cuda:{i}' for i in range(num_gpus)]
        self.devices = [torch.device(d) for d in devices]

        def is_iterable(value):
            try:
                iter(value)
                return True
            except TypeError:
                return False

        # Broadcast a scalar seed to every device; a list is used as-is.
        self.seed = seed if is_iterable(seed) else [seed] * len(self.devices)

    def __enter__(self):
        # Capture the current RNG state of every tracked device.
        self.states = [(torch.random.get_rng_state() if d.type == 'cpu'
                        else torch.cuda.get_rng_state(d.index))
                       for d in self.devices]
        for i, d in enumerate(self.devices):
            if self.seed[i] is None:
                continue
            if d.type == 'cpu':
                torch.default_generator.manual_seed(self.seed[i])
            else:
                # torch.cuda.manual_seed() seeds the current device only.
                with torch.cuda.device(d.index):
                    torch.cuda.manual_seed(self.seed[i])

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the captured RNG states, device by device.
        for device, state in zip(self.devices, self.states):
            if device.type == 'cpu':
                torch.random.set_rng_state(state)
            else:
                torch.cuda.set_rng_state(state, device.index)
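

# A minimal usage sketch (hypothetical setup assuming two visible CUDA
# devices): a list of seeds applies one seed per device, and a None entry
# leaves that device unseeded while its state is still restored on exit.
if __name__ == '__main__':
    with RNG(seed=[0, 1, None], devices=['cpu', 'cuda:0', 'cuda:1']):
        print(torch.rand(1).item())                   # CPU stream seeded with 0
        print(torch.rand(1, device='cuda:0').item())  # cuda:0 seeded with 1
        print(torch.rand(1, device='cuda:1').item())  # unseeded, state restored on exit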


class DelayedRNGFork:
    """Capture PyTorch's RNG state on initialization for delayed use.

    This context manager captures the torch random number generator states
    on instantiation. Then, it can be reused many times later in
    with-statements.

    Source: https://gist.github.com/xmodar/2328b13bdb11c6309ba449195a6b551a

    Example:
        rng = DelayedRNGFork(devices=[0])
        print('1 outside', torch.randn(3, device='cuda:0'))
        with rng:
            print('1 inside ', torch.randn(3, device='cuda:0'))
            print('2 inside ', torch.randn(3, device='cuda:0'))
        print('2 outside', torch.randn(3, device='cuda:0'))
    """

    def __init__(self, devices=None, enabled=True):
        self.cpu_state = torch.get_rng_state() if enabled else None
        self.gpu_states = {}
        if self.enabled:
            if devices is None:
                devices = range(torch.cuda.device_count())
            # Normalize every entry to a device index and capture its state.
            for device in map(torch.cuda.device, set(devices)):
                self.gpu_states[device.idx] = torch.cuda.get_rng_state(device.idx)
        self._fork = None

    def __enter__(self):
        if self.enabled:
            # Fork so that the outer RNG streams are restored on __exit__.
            self._fork = torch.random.fork_rng(devices=list(self.gpu_states))
            self._fork.__enter__()  # pylint: disable=no-member
            # Replay the states that were captured at initialization.
            torch.set_rng_state(self.cpu_state)
            for device, state in self.gpu_states.items():
                with torch.cuda.device(device):
                    torch.cuda.set_rng_state(state)

    def __exit__(self, exc_type, exc_value, traceback):
        if self.enabled:
            fork, self._fork = self._fork, None
            # pylint: disable=no-member
            return fork.__exit__(exc_type, exc_value, traceback)
        return None

    @property
    def enabled(self):
        """Whether the fork is enabled."""
        return self.cpu_state is not None
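

# A usage sketch for DelayedRNGFork (assuming a single CUDA device):
# entering the fork replays the states captured at construction, so every
# with-block reproduces the same draws, e.g. for a fixed evaluation batch.
if __name__ == '__main__':
    delayed = DelayedRNGFork(devices=[0])
    for step in range(3):
        with delayed:
            noise = torch.randn(2, device='cuda:0')  # identical on every iteration
        print(step, noise)  # outside the block, the original RNG stream resumes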