Skip to content

Instantly share code, notes, and snippets.

@samson-wang
Created February 9, 2022 02:04
Show Gist options
  • Save samson-wang/39f759f994e23ce20e4f9b0598e162e3 to your computer and use it in GitHub Desktop.
Save samson-wang/39f759f994e23ce20e4f9b0598e162e3 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import time
class Net(nn.Module):
    """Small convolutional network operating on channels-last images.

    Inputs are laid out as (batch, height, width, channels). ``forward``
    permutes them to the (batch, channels, height, width) layout expected by
    ``nn.Conv2d``, runs the convolution stack, permutes back to channels-last
    and clamps the result to [0, 1]. The output always has 3 channels.

    Args:
        ch: Number of input channels (size of the input's last dimension).
    """

    def __init__(self, ch=3):
        super().__init__()
        layers = [
            # 3x3 conv lifting the input to 8 channels.
            nn.Conv2d(ch, 8, 3, padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(),
            # Grouped 3x3 conv (4 groups) expanding 8 -> 32 channels.
            nn.Conv2d(8, 32, 3, padding=1, groups=4),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            # Grouped 1x1 conv (4 groups): channels mix only within a group.
            nn.Conv2d(32, 32, 1, padding=0, groups=4),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            # 1x1 projection down to 3 output channels.
            nn.Conv2d(32, 3, 1),
        ]
        # Kept as a single Sequential named ``model`` so parameter names
        # (``model.0.weight`` ...) stay stable for saved state dicts.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Map a channels-last batch to a 3-channel channels-last batch in [0, 1]."""
        channels_first = x.permute(0, 3, 1, 2)
        features = self.model(channels_first)
        channels_last = features.permute(0, 2, 3, 1)
        return channels_last.clamp(min=0, max=1)
def score(batch_size=32, num_batches=10, width=10, height=1024, *,
          dry_run=5, model=None):
    """Benchmark inference throughput on random channels-last input.

    Runs ``dry_run`` untimed warm-up batches, then ``num_batches`` timed
    batches of uniformly random input shaped
    ``(batch_size, height, width, 3)``. The model is switched to eval mode
    and run under ``torch.no_grad()``. Prints the model, the running average
    milliseconds per timed batch every 1000 timed batches, and the final
    output shape.

    Args:
        batch_size: Images per batch.
        num_batches: Number of timed batches; must be >= 1.
        width: Input width.
        height: Input height.
        dry_run: Number of untimed warm-up batches; must be >= 0.
        model: Module to benchmark. Defaults to a freshly built ``Net()``.
            It must accept 3-channel channels-last input and is left in
            eval mode afterwards.

    Returns:
        Throughput in images per second over the timed batches. The timing
        includes generating the random input of each batch.

    Raises:
        ValueError: If ``num_batches < 1`` or ``dry_run < 0``.
    """
    # Without at least one timed batch there is no start time and no output
    # to report; fail clearly instead of with an UnboundLocalError.
    if num_batches < 1:
        raise ValueError(f"num_batches must be >= 1, got {num_batches}")
    if dry_run < 0:
        raise ValueError(f"dry_run must be >= 0, got {dry_run}")
    if model is None:
        model = Net()
    print(model)
    # BatchNorm in training mode normalizes with per-batch statistics and
    # updates its running averages on every forward; an inference benchmark
    # must measure eval-mode behaviour.
    model.eval()
    shape = (batch_size, height, width, 3)
    with torch.no_grad():
        for _ in range(dry_run):
            model(torch.rand(shape))
        start = time.perf_counter()
        for timed in range(1, num_batches + 1):
            data = torch.rand(shape)
            out = model(data)
            if timed % 1000 == 0:
                # Average over the batches actually timed so far.
                ms_per_batch = (time.perf_counter() - start) / timed * 1000
                print(ms_per_batch, data.shape, out.shape)
        elapsed = time.perf_counter() - start
    print(out.shape)
    return num_batches * batch_size / elapsed
def main():
    """Run the default benchmark: 30000 timed batches of one image each.

    Image size comes from ``score``'s default ``width``/``height``.
    """
    batch_size, num_batches = 1, 30000
    print(f"Scoring batch = {batch_size} x {num_batches}")
    print(f"{score(batch_size, num_batches)} images / second")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment