# Source: GitHub gist dkurt/e37f59467926caca342033119cfc76a0 by @dkurt
# (last active February 9, 2021 21:35).
import torch
import torch.nn as nn
from torch.autograd import Variable
# Fix the global RNG seed before anything random happens below (the Conv2d
# parameter initialisation and the torch.randn input), so runs are repeatable.
SEED = 1412
torch.manual_seed(SEED)
def fake_unpool(x, pool_inp, pool_out, indices, kernel_size=2):
    """Emulate ``nn.MaxUnpool2d`` using upsample, compare and multiply ops.

    Each value of ``x`` is replicated over its pooling window by
    nearest-neighbour upsampling. It is then kept only at the positions where
    ``pool_inp`` held the window maximum, and zeroed elsewhere. The result is
    checked against the real max-unpool op before being returned.
    NOTE(review): presumably this exists so the exported ONNX graph contains
    no MaxUnpool node; confirm.

    Args:
        x: Tensor shaped like ``pool_out`` whose values are to be unpooled.
        pool_inp: Tensor that was fed into the max-pooling layer.
        pool_out: Output of that max-pooling layer.
        indices: Indices returned by the pooling layer
            (``return_indices=True``). Used only for the sanity check.
        kernel_size: Kernel size of the pooling layer, which must equal its
            stride. Defaults to 2, the value previously hard-coded here.

    Returns:
        Tensor with the spatial size of ``pool_inp``. It holds the values of
        ``x`` at the max locations and zeros everywhere else.

    Raises:
        AssertionError: if the emulation differs from max-unpool. For
            example, when a window holds several values within 1e-6 of its
            maximum, all of them are kept here, while max-unpool keeps only
            one.

    NOTE(review): assumes the spatial dims of ``pool_inp`` are divisible by
    ``kernel_size``. Otherwise the upsampled tensors do not line up with
    ``pool_inp`` and the subtraction below fails.
    """
    functional = nn.functional
    # Reference result from the real op. It is used only for the check below.
    ref = functional.max_unpool2d(x, indices, kernel_size, stride=kernel_size)
    # Nearest upsampling copies every pooled value over its k x k window...
    pool_out_resized = functional.interpolate(
        pool_out, scale_factor=kernel_size, mode='nearest')
    x_resized = functional.interpolate(
        x, scale_factor=kernel_size, mode='nearest')
    # ...so the window maxima are exactly where pool_inp matches it.
    mask = torch.abs(pool_out_resized - pool_inp) < 1e-6
    out = x_resized * mask
    assert torch.max(torch.abs(ref - out)) == 0, \
        'fake_unpool result differs from max_unpool2d'
    return out
class MyModel(nn.Module):
    """Small network: 1x1 conv, 2x2 max-pool, 1x1 conv, then an emulated unpool.

    The unpooling step goes through ``fake_unpool`` instead of
    ``nn.MaxUnpool2d``. Submodules are created in a fixed order (pool, conv1,
    conv2). That order determines how parameter initialisation consumes the
    RNG and how the ``state_dict`` keys are ordered.
    """

    def __init__(self):
        super().__init__()
        self.pool = nn.MaxPool2d(2, stride=2, return_indices=True)
        self.conv1 = nn.Conv2d(3, 4, kernel_size=1, stride=1)
        self.conv2 = nn.Conv2d(4, 4, kernel_size=1, stride=1)

    def forward(self, x):
        """Apply conv1, pool, conv2, then unpool to conv1's output size."""
        features = self.conv1(x)
        pooled, argmax = self.pool(features)
        return fake_unpool(self.conv2(pooled), features, pooled, argmax)
# Random input, created before the model so the seeded RNG is consumed in the
# same order on every run. Plain tensors work with autograd directly, so the
# deprecated Variable wrapper (a no-op since PyTorch 0.4) is not needed.
inp = torch.randn([1, 3, 6, 8])
model = MyModel()
model.eval()
# Eager run: exercises the sanity check inside fake_unpool. No gradients are
# needed, so skip building the autograd graph.
with torch.no_grad():
    out = model(inp)
print(out.shape)
# Export by tracing forward() on the same input.
torch.onnx.export(model, inp, 'model.onnx', input_names=['input'], output_names=['output'])
# End of gist.