Created
February 3, 2021 11:12
-
-
Save ilia-cher/8f655cf15beb1b11547fd3564a1c3958 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision
import torchvision.transforms as T
import torchvision.datasets as datasets
import torchvision.models as models
import torch.profiler

# Single source of truth for where the model, the loss and every batch live;
# defined first so the model is moved with it instead of a separate .cuda().
device = torch.device("cuda:0")

# Pretrained torchvision ResNet-50 weights were trained on ImageNet-normalized
# input (mean/std below); feeding raw [0, 1] tensors skews their activations.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32,
                                          shuffle=True, num_workers=0)

try:
    # torchvision >= 0.13 replaces the deprecated `pretrained` flag with
    # weight enums; IMAGENET1K_V1 is the checkpoint `pretrained=True` loaded.
    _weights = models.ResNet50_Weights.IMAGENET1K_V1
except AttributeError:
    # Older torchvision: no weight enums, the legacy flag is the only option.
    model = models.resnet50(pretrained=True)
else:
    model = models.resnet50(weights=_weights)

# Replace the 1000-way ImageNet classifier with one sized to CIFAR-10's labels
# (done before building the optimizer so the new head's params are included).
model.fc = nn.Linear(model.fc.in_features, len(trainset.classes))
model.to(device)
cudnn.benchmark = True

criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
model.train()
def output_fn(p):
    """Report the trace gathered by a completed profiler cycle and save it.

    Used as the ``on_trace_ready`` callback of ``torch.profiler.profile``.
    Prints the aggregated per-operator table, ordered by self CUDA time,
    then exports the raw trace in Chrome trace format.

    Args:
        p: the ``torch.profiler.profile`` object whose cycle just finished.
    """
    summary = p.key_averages().table(
        sort_by="self_cuda_time_total",
        row_limit=-1,  # no row cap: list every operator
    )
    print(summary)
    # Fixed output path, so each export replaces any previous trace file.
    p.export_chrome_trace("./worker0.pt.trace.json")
# Profiler schedule lengths. During the wait steps the profiler is idle,
# warmup steps are traced but discarded, and active steps are recorded.
# The loop length is derived from these, so editing the schedule can no
# longer leave the active window unfinished (no trace) or add extra steps.
WAIT_STEPS = 2
WARMUP_STEPS = 2
ACTIVE_STEPS = 6
TOTAL_STEPS = WAIT_STEPS + WARMUP_STEPS + ACTIVE_STEPS

with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    schedule=torch.profiler.schedule(
        wait=WAIT_STEPS,
        warmup=WARMUP_STEPS,
        active=ACTIVE_STEPS,
        repeat=1,  # run exactly one wait/warmup/active cycle
    ),
    on_trace_ready=output_fn,  # called when the active window completes
    record_shapes=True,
) as p:
    for step, (inputs, labels) in enumerate(trainloader):
        print(f"step:{step}")
        inputs, labels = inputs.to(device=device), labels.to(device=device)
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Advance the profiler schedule once per training step; the final
        # call closes the active window and triggers on_trace_ready.
        p.step()
        if step + 1 >= TOTAL_STEPS:
            break
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment