# zou3519/bench.py: benchmark functorch vmap model ensembling against a
# plain Python for-loop.
import torch
import torch.nn as nn
import torch.nn.functional as F
from functorch import vmap, combine_state_for_ensemble
from torch.utils.benchmark import Timer


class LeNet5(nn.Module):
    """LeNet-5-style CNN for 3-channel 32x32 inputs (e.g. CIFAR-10)."""

    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(400, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 400)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

device = 'cuda'

# Build an ensemble of 5 independently initialized models and stack their
# parameters and buffers so a single vmapped forward pass can run all of them.
models = [LeNet5().to(device) for _ in range(5)]
fmodel, params, buffers = combine_state_for_ensemble(models)

# CIFAR-10-sized input: a batch of 128 3-channel 32x32 images.
data = torch.randn(128, 3, 32, 32).to(device)


def vmap_inference():
    # Run all 5 models on the same batch in one vectorized call:
    # map over dim 0 of params and buffers, broadcast the data.
    results = vmap(fmodel, (0, 0, None))(params, buffers, data)
    return results


def forloop_inference():
    # Baseline: run each model sequentially and stack the outputs.
    results = []
    for model in models:
        results.append(model(data))
    return torch.stack(results)

# Time both paths; torch.utils.benchmark.Timer handles CUDA synchronization.
t0 = Timer('vmap_inference()', setup='from __main__ import vmap_inference')
t1 = Timer('forloop_inference()', setup='from __main__ import forloop_inference')
print(t0.timeit(1000))
print(t1.timeit(1000))
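
# Optional sanity check (not part of the original gist): the vmap-ensembled
# path should agree with the per-model for-loop up to floating-point
# tolerance. A minimal sketch, assuming the definitions above are in scope;
# vmap may dispatch to different kernels (e.g. grouped convolution), so a
# loose tolerance is used.
with torch.no_grad():
    vmap_out = vmap_inference()      # shape: (5, 128, 10)
    loop_out = forloop_inference()   # shape: (5, 128, 10)
assert vmap_out.shape == loop_out.shape == (5, 128, 10)
assert torch.allclose(vmap_out, loop_out, atol=1e-3, rtol=1e-5)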