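# Distributed MNIST training with pytorch-pfn-extras (ppe): an
# ExtensionsManager drives the training loop, reporting, triggers, and a
# distributed snapshot extension on top of plain PyTorch DDP. The script
# reads Open MPI's OMPI_COMM_WORLD_* environment variables, so it is meant
# to be launched with mpirun; the exact launch command shown at the bottom
# is an assumption, not part of the original gist.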
import os

import pytorch_pfn_extras as ppe
import pytorch_pfn_extras.training.extensions as extensions
import pytorch_pfn_extras.training.triggers as triggers
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms


# A small LeNet-style convolutional network for MNIST.
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def train(manager, model, device, train_loader, optimizer):
    # Loop until the manager's stop trigger fires (here: 4 epochs).
    while not manager.stop_trigger:
        model.train()
        for data, target in train_loader:
            # run_iteration() advances the manager's iteration count and
            # runs any extensions whose triggers fire on this iteration.
            with manager.run_iteration():
                data, target = data.to(device), target.to(device)
                optimizer.zero_grad()
                output = model(data)
                loss = F.nll_loss(output, target)
                ppe.reporting.report({"train/loss": loss.item()})
                loss.backward()
                optimizer.step()


def main():
    # Training settings. The OMPI_COMM_WORLD_* variables are set by Open
    # MPI's mpirun, with one process per GPU.
    device = torch.device(f"cuda:{os.environ['OMPI_COMM_WORLD_LOCAL_RANK']}")
    torch.cuda.set_device(device)
    # Translate Open MPI's environment variables into the ones expected by
    # torch.distributed's env:// initialization.
    os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
    os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "1234"
    torch.distributed.init_process_group(backend="nccl", init_method="env://")
    # Note: no DistributedSampler is used, so every rank iterates over the
    # full dataset.
    kwargs = {"num_workers": 1, "pin_memory": True}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "../data", train=True, download=True, transform=transforms.ToTensor(),
        ),
        batch_size=256,
        **kwargs,
    )
    # test_loader is built but never used in this snippet.
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "../data", train=False, download=True, transform=transforms.ToTensor(),
        ),
        batch_size=1000,
        **kwargs,
    )
    model = Net()
    model.to(device)
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])]
    )
    optimizer = optim.SGD(model.parameters(), lr=1e-3)

    # Train for 4 epochs; the manager drives triggers and extensions.
    manager = ppe.training.ExtensionsManager(
        model, optimizer, 4, iters_per_epoch=len(train_loader),
    )
    # Report a fake, rank-dependent loss once per epoch so that the
    # MinValueTrigger below has a value to monitor (hard-coded for a
    # two-process run).
    def dummy_loss(manager):
        # losses[rank][epoch]: index 0 is a placeholder so that
        # manager.epoch (1-based at each epoch trigger) indexes directly.
        losses = [["dummy", 1.0, 2.0, 1.0, 1.0], ["dummy", 1.1, 0.1, 1.1, 1.1]]
        ppe.reporting.report(
            {
                "dummy/loss": losses[int(os.environ["OMPI_COMM_WORLD_RANK"])][
                    manager.epoch
                ]
            }
        )

    manager.extend(dummy_loss, trigger=(1, "epoch"))
    # Save a snapshot whenever dummy/loss reaches a new minimum;
    # saver_rank=0 makes only rank 0 write the file.
    manager.extend(
        ppe.training.extensions.snapshot(filename="snapshot_best", saver_rank=0),
        trigger=triggers.MinValueTrigger("dummy/loss"),
    )
    manager.extend(extensions.ProgressBar())
    train(manager, model, device, train_loader, optimizer)


if __name__ == "__main__":
    main()
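# A plausible launch command (assumed, not from the gist): two processes
# on one node, one GPU each, with Open MPI providing the OMPI_COMM_WORLD_*
# environment variables read above. "train.py" is a placeholder for
# wherever this file is saved:
#
#   mpirun -np 2 python train.py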