# Benchmark a 1x1 nn.Conv2d against an equivalent nn.Linear on the GPU.
import time

import torch
import torch.nn as nn
device = 'cuda:0'
batch_size = 10
channels = 64
h, w = 128, 128
x = torch.randn(batch_size, channels, h, w, device=device)
# Set up a 1x1 convolution (channels -> 1, no bias)
conv = nn.Conv2d(channels, 1, kernel_size=1, bias=False).to(device)
torch.cuda.synchronize()
t0 = time.time()
for _ in range(500):
    output_conv = conv(x)
torch.cuda.synchronize()
t1 = time.time()
print('Conv2d took {} seconds'.format(t1 - t0))
# Set up an equivalent linear layer and copy the conv kernel into it
lin = nn.Linear(channels, 1, bias=False).to(device)
with torch.no_grad():
    lin.weight = nn.Parameter(conv.weight.view(1, channels))
torch.cuda.synchronize()
t0 = time.time()
for _ in range(500):
    # flatten the spatial dims and move channels last, so Linear is applied per pixel
    output_lin = lin(x.view(batch_size, channels, -1).transpose(-2, -1)).transpose(-1, -2)
    output_lin = output_lin.view(batch_size, 1, h, w)
torch.cuda.synchronize()
t1 = time.time()
print('Linear took {} seconds'.format(t1 - t0))
print('Maximal absolute error: {}'.format((output_conv - output_lin).abs().max()))
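
# The loops above fold one-time CUDA/cuDNN warmup into the first iterations.
# A minimal sketch of a sturdier measurement, reusing conv and x from above:
# warm up first, then time with CUDA events (elapsed_time returns milliseconds).
for _ in range(10):
    conv(x)
torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(500):
    conv(x)
end.record()
torch.cuda.synchronize()
print('Conv2d (event timing): {:.3f} ms'.format(start.elapsed_time(end)))

# Sanity check: a 1x1 convolution is a per-pixel matmul, so both outputs
# should agree up to floating point noise (the tolerance here is a guess).
print('allclose: {}'.format(torch.allclose(output_conv, output_lin, atol=1e-5)))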