Skip to content

Instantly share code, notes, and snippets.

@zarzen
Last active October 19, 2020 02:05
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zarzen/3d0a8b99ea038a3544c91e3305004f7b to your computer and use it in GitHub Desktop.
Save zarzen/3d0a8b99ea038a3544c91e3305004f7b to your computer and use it in GitHub Desktop.
singleNodeTraining
from torchvision import datasets, transforms, models
import torch
import torchvision
from torch import optim
import os
import torch.nn.functional as F
# Thread budget used both for torch's CPU thread pool and for the number of
# DataLoader worker processes.
__n_threads = 4
print('torch num threads:', __n_threads)
torch.set_num_threads(__n_threads)

# Extra keyword arguments forwarded to torch.utils.data.DataLoader in main().
kwargs = dict(num_workers=__n_threads, pin_memory=True)
def main():
    """Train a torchvision VGG16-BN model on the CIFAR-10 training split.

    Downloads CIFAR-10 into ``/tmp/data`` when it is not already there, then
    runs 15 epochs of plain SGD (lr=0.01, batch size 32) with a cross-entropy
    loss, printing the loss every 100 batches.

    Runs on CUDA when a GPU is available and falls back to the CPU otherwise,
    so the script no longer crashes on machines without a GPU.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): vgg16_bn() is built with its default classifier head size
    # although CIFAR-10 has 10 classes -- confirm this is intentional (e.g.
    # for benchmarking) before passing num_classes=10.
    model = models.vgg16_bn().to(device)

    # ToTensor yields values in [0, 1]; normalizing each channel with
    # mean 0.5 / std 0.5 rescales them to [-1, 1].
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    train_dataset = torchvision.datasets.CIFAR10(
        root="/tmp/data", train=True, download=True, transform=transform)

    # Shuffle the training set every epoch; DataLoader defaults to
    # shuffle=False, which would feed SGD the same batch order each epoch.
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=32, shuffle=True, **kwargs)

    optimizer = optim.SGD(model.parameters(), lr=0.01)

    num_epochs = 15
    for epoch in range(num_epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            # non_blocking lets the host-to-device copy overlap with compute
            # when the loader hands out pinned memory; it is ignored on CPU.
            data = data.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            optimizer.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()

            if batch_idx % 100 == 0:
                print(f"epoch {epoch}, batch {batch_idx}, loss {loss.item()}")
# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment