@ihoromi4
Created February 5, 2020 11:49
pytorch - set seed everything
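def seed_everything(seed: int):
    import random, os
    import numpy as np
    import torch

    random.seed(seed)
    # Only affects Python processes started after this point; the hash
    # randomization of the running interpreter is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)  # current GPU; torch.cuda.manual_seed_all covers every GPU
    # Use only deterministic cuDNN convolution algorithms.
    torch.backends.cudnn.deterministic = True
    # Benchmarking may pick a different algorithm on each run, so disable it.
    torch.backends.cudnn.benchmark = False

seed_everything(42)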
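A quick way to sanity-check seeding is to seed, draw a few numbers from each generator, reseed, and compare. This is a minimal sketch assuming PyTorch and NumPy are installed; `draw_after_seeding` is a hypothetical helper that inlines the same three seeding calls used by `seed_everything` so the snippet runs on its own.

```python
import random

import numpy as np
import torch


def draw_after_seeding(seed: int):
    """Hypothetical helper: seed Python, NumPy and PyTorch, then draw from each."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    return (
        random.random(),
        np.random.rand(3).tolist(),
        torch.rand(3).tolist(),
    )


first = draw_after_seeding(42)
second = draw_after_seeding(42)
other = draw_after_seeding(7)

print(first == second)  # True: same seed, same draws
```

This only checks the CPU generators. GPU kernels can still produce run-to-run differences with identical seeds, which is what the cuDNN flags are for.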
@jerpint

jerpint commented Oct 3, 2022

Note that according to the PyTorch documentation, torch.backends.cudnn.benchmark should be set to False, not True:

https://pytorch.org/docs/stable/notes/randomness.html#cuda-convolution-benchmarking

Disabling the benchmarking feature with torch.backends.cudnn.benchmark = False causes cuDNN to deterministically select an algorithm, possibly at the cost of reduced performance.
However, if you do not need reproducibility across multiple executions of your application, then performance might improve if the benchmarking feature is enabled with torch.backends.cudnn.benchmark = True.
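The quoted passage describes a trade-off between two modes. A minimal sketch of switching between them, where `configure_cudnn` and its `reproducible` parameter are hypothetical names; the two flags themselves are the ones discussed above.

```python
import torch


def configure_cudnn(reproducible: bool) -> None:
    """Hypothetical helper choosing between the two cuDNN modes quoted above."""
    if reproducible:
        # Deterministic algorithm selection, possibly at some cost in speed.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    else:
        # Let cuDNN benchmark candidate algorithms; results may differ between runs.
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False


configure_cudnn(reproducible=True)
print(torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic)
```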

@elcolie

elcolie commented Apr 28, 2023

On Apple silicon (MPS backend), also seed the MPS generator:
torch.mps.manual_seed(seed)
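One way to cover CPU, CUDA and MPS in a single call is to guard each backend with its availability check. This is a sketch assuming a recent PyTorch release that provides `torch.mps.manual_seed`; `seed_accelerators` is a hypothetical name.

```python
import torch


def seed_accelerators(seed: int) -> None:
    """Hypothetical helper: seed the CPU generator plus any available GPU backends."""
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)  # every visible CUDA device
    if torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)  # Apple silicon GPU


seed_accelerators(42)
print(torch.initial_seed())  # 42
```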

@peymanrostami

You have forgotten to seed the train DataLoader's worker processes and its shuffling generator. Add this:

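import random

import numpy as np
import torch
from torchvision import transforms

# seed, dset_string, normalize and args come from the surrounding training script.

def seed_worker(worker_id):
    # Inside a worker, torch.initial_seed() returns the per-worker seed set by
    # the DataLoader; reuse it (truncated to 32 bits) for NumPy and random.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# Generator controlling the shuffle order; it must be passed to the DataLoader.
g = torch.Generator()
g.manual_seed(seed)

train_loader = torch.utils.data.DataLoader(
    eval(dset_string)(root='./data', train=True, transform=transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, 4),
        transforms.ToTensor(),
        normalize,
    ]), download=True),
    batch_size=args.batch_size, shuffle=True,
    num_workers=args.workers, pin_memory=True,
    worker_init_fn=seed_worker, generator=g)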
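To see the effect of the `generator` argument in isolation, the sketch below builds two shuffling loaders over a toy dataset with identically seeded generators and compares the sample order. It assumes only PyTorch; `shuffled_order` is a hypothetical helper, and it uses `num_workers=0` so it runs in-process (with workers you would also pass `worker_init_fn=seed_worker`).

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def shuffled_order(seed: int) -> list:
    """Hypothetical helper: sample order of a shuffling DataLoader seeded with `seed`."""
    dataset = TensorDataset(torch.arange(10))
    g = torch.Generator()
    g.manual_seed(seed)
    # num_workers=0 keeps everything in the main process, so no worker seeding is needed here.
    loader = DataLoader(dataset, batch_size=5, shuffle=True, generator=g)
    return [int(x) for (batch,) in loader for x in batch]


print(shuffled_order(0) == shuffled_order(0))  # True
```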
