Skip to content

Instantly share code, notes, and snippets.

@nzw0301
Last active April 27, 2022 06:34
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nzw0301/9ba3c837a02539e8bde194f4f2c9d0e5 to your computer and use it in GitHub Desktop.
Save nzw0301/9ba3c837a02539e8bde194f4f2c9d0e5 to your computer and use it in GitHub Desktop.
"""
Modified version of https://github.com/optuna/optuna/pull/2303 that uses the NCCL backend.
Optuna example that optimizes multi-layer perceptrons using PyTorch distributed.
In this example, we optimize the validation accuracy of hand-written digit recognition using
PyTorch distributed data parallel and MNIST. We optimize the neural network architecture as well
as the optimizer configuration. As it is too time consuming to use the whole MNIST dataset, we
here use a small subset of it.
You can execute this example with the torch.distributed.launch utility as follows:
$ python -m torch.distributed.launch --nproc_per_node=2 pytorch_distributed_simple.py
Please note that you need to install PyTorch from source if you switch the communication backend
of torch.distributed to "mpi". Please refer to the following document for further details:
https://pytorch.org/tutorials/intermediate/dist_tuto.html#communication-backends
"""
import argparse
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.optim as optim
import torch.utils.data
from torchvision import datasets
from torchvision import transforms
import optuna
# Mini-batch size used by both the training and the validation loaders.
BATCHSIZE = 128
# Size of the output layer: one log-probability per MNIST digit.
CLASSES = 10
# MNIST is downloaded to / read from the current working directory.
DIR = os.getcwd()
# Number of training epochs run for every trial.
EPOCHS = 30
# Not referenced elsewhere in this script.
LOG_INTERVAL = 10
# Only a small MNIST subset is used so each trial stays fast:
# 30 training batches and 10 validation batches.
N_TRAIN_EXAMPLES = 30 * BATCHSIZE
N_VALID_EXAMPLES = 10 * BATCHSIZE
def define_model(trial):
    """Build an MLP whose architecture is sampled from ``trial``.

    Sampled hyperparameters:
      * ``n_layers``: number of hidden layers, integer in [1, 3].
      * ``n_units_l{i}``: width of hidden layer ``i``, integer in [4, 128].
      * ``dropout_l{i}``: dropout probability after hidden layer ``i``, in [0.2, 0.5].

    Each hidden layer is ``Linear -> ReLU -> Dropout``. The head is a
    ``Linear`` layer onto ``CLASSES`` outputs followed by ``LogSoftmax``.

    Returns:
        An ``nn.Sequential`` that expects flattened 28x28 inputs.
    """
    modules = []
    width = 28 * 28  # flattened image size
    depth = trial.suggest_int("n_layers", 1, 3)
    for idx in range(depth):
        hidden = trial.suggest_int(f"n_units_l{idx}", 4, 128)
        modules += [nn.Linear(width, hidden), nn.ReLU()]
        drop_p = trial.suggest_float(f"dropout_l{idx}", 0.2, 0.5)
        modules.append(nn.Dropout(drop_p))
        width = hidden
    modules += [nn.Linear(width, CLASSES), nn.LogSoftmax(dim=1)]
    return nn.Sequential(*modules)
def get_mnist():
    """Create distributed data loaders over small MNIST subsets.

    The first ``N_TRAIN_EXAMPLES`` training images and the first
    ``N_VALID_EXAMPLES`` test images are used. Each subset is sharded across
    ranks with a ``DistributedSampler``. Only the training sampler shuffles;
    callers should invoke ``set_epoch`` on it every epoch.

    The MNIST files are expected to already be present in ``DIR``; this
    function does not request a download.

    Returns:
        ``(train_loader, valid_loader, train_sampler, valid_sampler)``.
    """

    def _sharded_loader(train, n_examples, shuffle):
        # Build subset -> per-rank sampler -> loader for one MNIST split.
        subset = torch.utils.data.Subset(
            datasets.MNIST(DIR, train=train, transform=transforms.ToTensor()),
            indices=range(n_examples),
        )
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset=subset, shuffle=shuffle
        )
        # The sampler already decides the order, so the loader itself must not shuffle.
        loader = torch.utils.data.DataLoader(
            subset,
            sampler=sampler,
            batch_size=BATCHSIZE,
            shuffle=False,
        )
        return loader, sampler

    train_loader, train_sampler = _sharded_loader(True, N_TRAIN_EXAMPLES, True)
    valid_loader, valid_sampler = _sharded_loader(False, N_VALID_EXAMPLES, False)
    return train_loader, valid_loader, train_sampler, valid_sampler
def _train_one_epoch(model, optimizer, train_loader, device):
    """Run one pass of NLL-loss training over ``train_loader`` on ``device``."""
    model.train()
    for data, target in train_loader:
        # Flatten each image into a vector for the MLP.
        data, target = data.view(data.size(0), -1).to(device), target.to(device)
        optimizer.zero_grad()
        loss = F.nll_loss(model(data), target)
        loss.backward()
        optimizer.step()


def _evaluate(model, valid_loader, device):
    """Return the validation accuracy aggregated over all ranks.

    The number of correct predictions and the number of evaluated samples are
    summed across ranks with a single ``all_reduce``. Their ratio is returned.
    ``DistributedSampler`` pads a shard with repeated samples when the dataset
    size is not divisible by the world size. Dividing by the evaluated count,
    not by ``len(dataset)``, keeps the result consistent and never above 1.
    Returns 0.0 if no samples were evaluated.
    """
    model.eval()
    correct = 0
    seen = 0
    with torch.no_grad():
        for data, target in valid_loader:
            data, target = data.view(data.size(0), -1).to(device), target.to(device)
            # Index of the max log-probability is the predicted class.
            pred = model(data).argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            seen += target.size(0)
    counts = torch.tensor([correct, seen], dtype=torch.long, device=device)
    dist.all_reduce(counts)
    total_correct, total_seen = counts.tolist()
    return total_correct / total_seen if total_seen else 0.0


def objective(single_trial):
    """Optuna objective, executed on every rank.

    On rank 0, ``single_trial`` is the trial passed by ``study.optimize``.
    On other ranks it is ``None``, and ``TorchDistributedTrial`` is expected to
    supply the values chosen on rank 0 (per Optuna's TorchDistributedTrial
    contract).

    Reads the module-level global ``rank``, set in the ``__main__`` block,
    as the CUDA device index.

    Returns:
        Validation accuracy after the final epoch.

    Raises:
        optuna.exceptions.TrialPruned: when the pruner stops the trial early.
    """
    trial = optuna.integration.TorchDistributedTrial(single_trial, rank)

    # Generate the model.
    model = DDP(define_model(trial).to(rank), device_ids=[rank])

    # Generate the optimizer.
    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "RMSprop", "SGD"])
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)

    # Get the MNIST dataset (the validation sampler is not needed here).
    train_loader, valid_loader, train_sampler, _ = get_mnist()

    accuracy = 0
    for epoch in range(EPOCHS):
        # Give the sampler the epoch number so each epoch draws a different shuffle.
        train_sampler.set_epoch(epoch)
        _train_one_epoch(model, optimizer, train_loader, rank)
        accuracy = _evaluate(model, valid_loader, rank)

        trial.report(accuracy, epoch)
        # Handle pruning based on the intermediate value.
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    return accuracy
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # torch.distributed.launch passes --local_rank (or --local-rank in newer
    # PyTorch releases). torchrun only exports LOCAL_RANK, hence the env fallback.
    parser.add_argument(
        "--local_rank",
        "--local-rank",
        dest="local_rank",
        type=int,
        default=int(os.environ.get("LOCAL_RANK", 0)),
    )
    args = parser.parse_args()

    # Fix the seed: every rank seeds its RNGs with the same value.
    seed = 7
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Take the world size from the launcher instead of hard-coding it. A
    # mismatch with --nproc_per_node makes init_process_group wait for ranks
    # that never join.
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    # ``rank`` is a module-level global read by ``objective``. It doubles as
    # the CUDA device index, so this script assumes a single node.
    rank = args.local_rank
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=world_size,
        rank=rank,
    )
    torch.cuda.set_device(rank)

    if rank == 0:
        # Download the dataset once before starting the optimization.
        datasets.MNIST(DIR, train=True, download=True)
    # Other ranks wait here until the download has finished.
    dist.barrier()

    study = None
    n_trials = 20
    if rank == 0:
        study = optuna.create_study(direction="maximize")
        study.optimize(objective, n_trials=n_trials)
    else:
        # Non-zero ranks take part in the same number of trials. A pruned
        # trial is expected here and simply moves on to the next one.
        for _ in range(n_trials):
            try:
                objective(None)
            except optuna.TrialPruned:
                pass

    if rank == 0:
        assert study is not None
        pruned_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]
        complete_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]
        print("Study statistics: ")
        print(" Number of finished trials: ", len(study.trials))
        print(" Number of pruned trials: ", len(pruned_trials))
        print(" Number of complete trials: ", len(complete_trials))

        print("Best trial:")
        trial = study.best_trial
        print(" Value: ", trial.value)
        print(" Params: ")
        for key, value in trial.params.items():
            print(" {}: {}".format(key, value))

    # Release NCCL communicator resources before exiting.
    dist.destroy_process_group()
@nzw0301
Copy link
Author

nzw0301 commented Feb 16, 2021

world_size = 2

time python -m torch.distributed.launch --nproc_per_node=2 pytorch_distributed_simple.py

Study statistics:
  Number of finished trials:  20
  Number of pruned trials:  7
  Number of complete trials:  13
Best trial:
  Value:  0.915625
  Params:
    n_layers: 1
    n_units_l0: 88
    dropout_l0: 0.45876637872650083
    optimizer: Adam
    lr: 0.003867116012603484

real    2m20.978s
user    4m34.364s
sys     0m5.414s

world_size = 4

time python -m torch.distributed.launch --nproc_per_node=4 pytorch_distributed_simple.py


Study statistics:
  Number of finished trials:  20
  Number of pruned trials:  6
  Number of complete trials:  14
Best trial:
  Value:  0.90234375
  Params:
    n_layers: 3
    n_units_l0: 62
    dropout_l0: 0.36208230683631953
    n_units_l1: 106
    dropout_l1: 0.3993010047710936
    n_units_l2: 67
    dropout_l2: 0.3332812025627451
    optimizer: Adam
    lr: 0.013772802610889121

real	1m25.108s
user	5m15.347s
sys	0m14.623s

@milliema
Copy link

How can we set up multiple GPUs in the above code optuna_with_pytorch_distributed.py?

@nzw0301
Copy link
Author

nzw0301 commented Apr 27, 2022

The code above used multiple GPUs.

@milliema
Copy link

Thanks for your quick reply.
I do see the settings of rank and device under multiple gpu settings.
But when running the code we need to pass "--local_rank", does it mean only one gpu is used with the given local rank?

@nzw0301
Copy link
Author

nzw0301 commented Apr 27, 2022

Well, it is a convention of PyTorch's distributed training https://pytorch.org/docs/stable/elastic/run.html#launcher-api.

@milliema
Copy link

That's why! I'm using openmpi and the training command differs from PyTorch's distributed training.
Have you ever encountered a pruning issue in DDP training? The process terminates and reports an error whenever pruning happens.

@nzw0301
Copy link
Author

nzw0301 commented Apr 27, 2022

I didn't when I ran the example code in this gist code. However, the last time when I ran the code was Feb. 2021, so PyTorch and Optuna behaviours might differ from the latest stable.

@milliema
Copy link

Thank you for your help. I'm able to run the above code using torchrun on multiple GPUs, and will check whether the issue still happens.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment