import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair, _quadruple


class MedianPool2d(nn.Module):
    """ Median pool (usable as median filter when stride=1) module.

    Args:
        kernel_size: size of pooling kernel, int or 2-tuple
        stride: pool stride, int or 2-tuple
        padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad
        same: override padding and enforce same padding, boolean
    """
    def __init__(self, kernel_size=3, stride=1, padding=0, same=False):
        super(MedianPool2d, self).__init__()
        self.k = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _quadruple(padding)  # convert to l, r, t, b
        self.same = same

    def _padding(self, x):
        if self.same:
            ih, iw = x.size()[2:]
            if ih % self.stride[0] == 0:
                ph = max(self.k[0] - self.stride[0], 0)
            else:
                ph = max(self.k[0] - (ih % self.stride[0]), 0)
            if iw % self.stride[1] == 0:
                pw = max(self.k[1] - self.stride[1], 0)
            else:
                pw = max(self.k[1] - (iw % self.stride[1]), 0)
            pl = pw // 2
            pr = pw - pl
            pt = ph // 2
            pb = ph - pt
            padding = (pl, pr, pt, pb)
        else:
            padding = self.padding
        return padding

    def forward(self, x):
        # using existing pytorch functions and tensor ops so that we get autograd,
        # would likely be more efficient to implement from scratch at C/Cuda level
        x = F.pad(x, self._padding(x), mode='reflect')
        x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1])
        x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0]
        return x
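A minimal usage sketch (the input shape and random tensor below are arbitrary, just for illustration): the module expects an NCHW tensor, and with same=True the spatial size of the output matches the input.

pool = MedianPool2d(kernel_size=3, stride=1, same=True)
img = torch.rand(1, 3, 64, 64)   # NCHW input
out = pool(img)                  # median-filtered output
print(out.shape)                 # torch.Size([1, 3, 64, 64])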
Great contribution! I would like to give it a try.
Works well! Thank you!
I noticed that it is quite slow on CPU, ~6 seconds for a 1024x1024 image. But moving it to the GPU brings that down to ~50 ms. So nice work on using pytorch functions, we get the GPU version for free!
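For reference, a rough sketch of how that comparison could be timed (numbers will vary by hardware; torch.cuda.synchronize is needed so the GPU timing isn't just measuring asynchronous kernel launches):

import time
x = torch.rand(1, 3, 1024, 1024)
pool = MedianPool2d(kernel_size=3, stride=1, same=True)
t0 = time.time()
pool(x)
print('cpu:', time.time() - t0)
if torch.cuda.is_available():
    x_gpu, pool_gpu = x.cuda(), pool.cuda()
    pool_gpu(x_gpu)              # warm-up pass
    torch.cuda.synchronize()
    t0 = time.time()
    pool_gpu(x_gpu)
    torch.cuda.synchronize()
    print('gpu:', time.time() - t0)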
Thank you for this @rwightman!
However, I'm having a memory issue when I try to use it.
I have a GeForce GTX TITAN X and I am using your code with pytorch 1.1.0
(I cannot upgrade due to some compatibility issues).
On line 49, the operation tries to allocate 9x as much memory as the data passed to the class.
Edit: I'm not very familiar with pytorch, but I guess the memory usage is being multiplied by 9 because each 3x3 patch (the size of my kernel) is being copied, correct?
Do you have any idea for a solution?
Maybe you need a smaller batch size, which means fewer images are sent to the GPU at the same time.
This may help you: broadinstitute/CellBender#67
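A hedged sketch of that suggestion: the unfold in forward materialises roughly kernel_h * kernel_w times the input once .contiguous() is called, so running the pool over batch slices (the chunk size of 4 below is arbitrary) caps the peak allocation at the cost of extra calls:

def median_pool_chunked(pool, x, chunk=4):
    # run the pool on batch slices and re-concatenate, trading speed for peak memory
    return torch.cat([pool(part) for part in x.split(chunk, dim=0)], dim=0)

pool = MedianPool2d(kernel_size=3, stride=1, same=True)
out = median_pool_chunked(pool, torch.rand(32, 3, 512, 512))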
It's slow on CPU because of the memory copy triggered by .contiguous().
Timing-wise, reshape and view come out about the same (~11 s for a tensor of size (224, 224, 3)).
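To illustrate where that copy happens (a rough sketch; timings depend on the machine): unfold returns a strided view with no extra memory, and it is the .contiguous() call in forward that materialises roughly kernel_h * kernel_w copies of the data.

x = torch.rand(1, 3, 224, 224)
patches = x.unfold(2, 3, 1).unfold(3, 3, 1)
print(patches.is_contiguous())    # False: still a view, nothing copied yet
flat = patches.contiguous().view(*patches.size()[:4], -1)  # the copy happens here
print(flat.numel() / x.numel())   # ~9x the elements for a 3x3 kernel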
Great, that's a good example. Thank you!