Last active
November 20, 2021 03:44
-
-
Save xmodar/2328b13bdb11c6309ba449195a6b551a to your computer and use it in GitHub Desktop.
A more flexible context manager than `torch.random.fork_rng()` to preserve the state of the random number generator in PyTorch for the desired devices.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
class RNG:
    """Preserve the state of the random number generators of torch

    https://gist.github.com/ModarTensai/2328b13bdb11c6309ba449195a6b551a
    Inspired by torch.random.fork_rng().

    On entering the with-block, the RNG state of every selected device is
    saved and, when a seed is given for that device, the device is re-seeded.
    On exit, every saved state is restored, so random draws made inside the
    block do not affect the draws made after it.

    Only the PyTorch generators are managed by this class. For reference,
    the usual ways to seed random number generators (RNGs) are:
        - (PyTorch) torch.manual_seed(seed)
        - (Numpy) numpy.random.seed(seed)
        - (Python) random.seed(seed)

    Example:
        seed = 0
        torch.manual_seed(seed)
        with RNG(seed, devices=['cpu', 'cuda:0']):
            print(torch.rand(1, device='cpu').item())
            print(torch.rand(1, device='cuda:0').item())
            print(torch.rand(1, device='cuda:1').item())
        print('........')
        print(torch.rand(1, device='cpu').item())
        print(torch.rand(1, device='cuda:0').item())
        print(torch.rand(1, device='cuda:1').item())
        # Outputs
        # 0.49625658988952637
        # 0.08403993397951126
        # 0.08403993397951126
        # ........
        # 0.49625658988952637
        # 0.08403993397951126
        # 0.41885504126548767
    """

    def __init__(self, seed=None, devices=None):
        """Seeding the random number generator of torch devices.

        Args:
            seed: The initial seed value, or an iterable with one value per
                device (a None entry leaves that device unseeded). If None,
                don't seed at all (states are still saved and restored).
            devices: List of devices to manage (anything accepted by
                `torch.device`). If None, use the CPU and all visible CUDA
                devices.

        Raises:
            ValueError: If `seed` is an iterable with fewer values than
                there are devices.
        """
        if devices is None:
            num_gpus = torch.cuda.device_count()
            devices = ['cpu'] + [f'cuda:{i}' for i in range(num_gpus)]
        self.devices = [torch.device(d) for d in devices]
        self.seed = self._per_device_seeds(seed, len(self.devices))
        self.states = []  # filled by __enter__, consumed by __exit__

    @staticmethod
    def _per_device_seeds(seed, count):
        """Return a list holding one seed (or None) for each of `count` devices."""
        try:
            # Materialize so generators and other one-shot iterables work.
            seeds = list(seed)
        except TypeError:
            # A scalar (or None) applies to every device.
            return [seed] * count
        if len(seeds) < count:
            raise ValueError(
                f'expected at least {count} seed values (one per device), '
                f'got {len(seeds)}')
        return seeds

    @staticmethod
    def _get_state(device):
        """Return the current RNG state of `device`."""
        if device.type == 'cpu':
            return torch.random.get_rng_state()
        # Pass the torch.device itself: a bare 'cuda' has index None, which
        # torch.cuda.get_rng_state resolves to the current device.
        return torch.cuda.get_rng_state(device)

    def __enter__(self):
        """Save every device's RNG state, then apply the requested seeds."""
        self.states = [self._get_state(d) for d in self.devices]
        for device, seed in zip(self.devices, self.seed):
            if seed is None:
                continue
            if device.type == 'cpu':
                torch.default_generator.manual_seed(seed)
            else:
                # torch.cuda.manual_seed seeds the *current* device only.
                with torch.cuda.device(device):
                    torch.cuda.manual_seed(seed)
        return self

    def __exit__(self, exce_type, exce_value, traceback):
        """Restore the RNG states saved by `__enter__`; never suppress errors."""
        for device, state in zip(self.devices, self.states):
            if device.type == 'cpu':
                torch.random.set_rng_state(state)
            else:
                # The device must be explicit; otherwise every state would be
                # written into the current CUDA device.
                torch.cuda.set_rng_state(state, device)
class DelayedRNGFork:
    """Captures PyTorch's RNG state on initialization with delayed use

    This context manager captures the torch random number generator states on
    instantiation. Then, it can be used many times later using with-statements:
    every with-block starts from the captured states and, on exit, the states
    that were active just before the block are restored. The same instance may
    also be used in nested with-blocks.

    Source: https://gist.github.com/xmodar/2328b13bdb11c6309ba449195a6b551a

    Example:
        rng = DelayedRNGFork(devices=[0])
        print('1 outside', torch.randn(3, device='cuda:0'))
        with rng:
            print('1 inside ', torch.randn(3, device='cuda:0'))
        print('2 inside ', torch.randn(3, device='cuda:0'))
        print('2 outside', torch.randn(3, device='cuda:0'))

    Args:
        devices: CUDA devices whose states are captured (anything accepted by
            `torch.cuda.device`, e.g. an int index or 'cuda:1'). If None, all
            visible CUDA devices are captured.
        enabled: If False, nothing is captured and the with-block is a no-op.
    """

    def __init__(self, devices=None, enabled=True):
        self.cpu_state = torch.get_rng_state() if enabled else None
        self.gpu_states = {}  # CUDA device index -> captured RNG state
        self._forks = []  # one active fork per (possibly nested) with-block
        if not self.enabled:
            return
        if devices is None:
            devices = range(torch.cuda.device_count())
        for device in devices:
            # torch.cuda.get_rng_state does not accept a torch.cuda.device
            # object, so normalize to a plain index first; duplicate devices
            # collapse onto the same dict key.
            index = torch.cuda.device(device).idx
            self.gpu_states[index] = torch.cuda.get_rng_state(index)

    def __enter__(self):
        """Fork the RNG, then load the captured states; returns `self`."""
        if self.enabled:
            # fork_rng snapshots the current CPU state and the listed CUDA
            # states, and restores them when the fork is exited.
            fork = torch.random.fork_rng(list(self.gpu_states))
            fork.__enter__()  # pylint: disable=no-member
            self._forks.append(fork)
            try:
                torch.set_rng_state(self.cpu_state)
                for index, state in self.gpu_states.items():
                    torch.cuda.set_rng_state(state, index)
            except BaseException:
                # __exit__ will not run if __enter__ raises: undo the fork here.
                self._forks.pop().__exit__(None, None, None)
                raise
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Restore the states that were active before the matching `__enter__`."""
        if not self.enabled:
            return None
        fork = self._forks.pop()  # innermost active with-block
        # pylint: disable=no-member
        return fork.__exit__(exc_type, exc_value, traceback)

    @property
    def enabled(self):
        """Whether the fork is enabled"""
        return self.cpu_state is not None
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment