PyTorch MedianPool (MedianFilter)
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair, _quadruple


class MedianPool2d(nn.Module):
    """Median pool (usable as median filter when stride=1) module.

    Args:
        kernel_size: size of pooling kernel, int or 2-tuple
        stride: pool stride, int or 2-tuple
        padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad
        same: override padding and enforce same padding, boolean
    """
    def __init__(self, kernel_size=3, stride=1, padding=0, same=False):
        super(MedianPool2d, self).__init__()
        self.k = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _quadruple(padding)  # convert to l, r, t, b
        self.same = same

    def _padding(self, x):
        if self.same:
            # compute padding so output spatial size matches ceil(input / stride)
            ih, iw = x.size()[2:]
            if ih % self.stride[0] == 0:
                ph = max(self.k[0] - self.stride[0], 0)
            else:
                ph = max(self.k[0] - (ih % self.stride[0]), 0)
            if iw % self.stride[1] == 0:
                pw = max(self.k[1] - self.stride[1], 0)
            else:
                pw = max(self.k[1] - (iw % self.stride[1]), 0)
            pl = pw // 2
            pr = pw - pl
            pt = ph // 2
            pb = ph - pt
            padding = (pl, pr, pt, pb)
        else:
            padding = self.padding
        return padding

    def forward(self, x):
        # using existing pytorch functions and tensor ops so that we get autograd,
        # would likely be more efficient to implement from scratch at C/CUDA level
        x = F.pad(x, self._padding(x), mode='reflect')
        # unfold into (N, C, H', W', kh, kw) sliding windows, then take the
        # median over each flattened kh*kw window
        x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1])
        x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0]
        return x
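
A minimal usage sketch (my own example, not part of the gist): with stride=1 and same=True the module acts as a 3x3 median filter that preserves spatial size.

    import torch

    # hypothetical example input: a batch of 2 single-channel 8x8 images
    x = torch.rand(2, 1, 8, 8)

    pool = MedianPool2d(kernel_size=3, stride=1, same=True)
    y = pool(x)
    print(y.shape)  # torch.Size([2, 1, 8, 8]), same spatial size as the input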
Coder-Liuu commented Mar 13, 2020

Great, that's a good example.
Thank you!

zhuogege1943 commented Apr 7, 2021

Great contribution! I would like to give it a try.

cipri-tom commented Apr 19, 2021

Works well! Thank you!

I noticed that it is quite slow on CPU, ~6 seconds for a 1024x1024 image, but moving to the GPU brings it down to ~50 ms. So nice work on using pytorch functions; we get the GPU version for free!
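
In case it helps others, a minimal sketch of that (assuming a CUDA device is available and img is a hypothetical 1x1x1024x1024 tensor):

    device = torch.device('cuda')
    pool = MedianPool2d(kernel_size=3, stride=1, same=True).to(device)
    out = pool(img.to(device))  # runs the unfold/median steps on the GPU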

jpcbertoldo commented May 19, 2021

Thank you for this @rwightman!

However, I'm having a memory issue when I try to use it.
I have a GeForce GTX TITAN X and I am using your code with pytorch 1.1.0 (I cannot upgrade due to some compatibility issues).

On line 49, the operation tries to allocate 9x as much memory as the data passed to the class.


Edit: I'm not very familiar with pytorch, but I guess the memory usage is multiplied by 9 because each 3x3 patch (the size of my kernel) gets copied, correct?

Do you have any idea for a solution?

Auzzer commented Sep 27, 2021

> Thank you for this @rwightman!
>
> However, I'm having a memory issue when I try to use it.
> I have a GeForce GTX TITAN X and I am using your code with pytorch 1.1.0 (I cannot upgrade due to some compatibility issues).
>
> On line 49, the operation tries to allocate 9x as much memory as the data passed to the class.
>
> Edit: I'm not very familiar with pytorch, but I guess the memory usage is multiplied by 9 because each 3x3 patch (the size of my kernel) gets copied, correct?
>
> Do you have any idea for a solution?

Maybe you need a smaller batch size, which means fewer images are sent to the GPU at the same time.
This may help you: broadinstitute/CellBender#67
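
If the batch size is already small, another possible workaround is to run the filter over the input in chunks, so the 9x unfold buffer only exists for one slice at a time. A rough sketch (my own, not from the gist), assuming x is an (N, C, H, W) tensor and chunk_size is a hypothetical knob to tune against the available memory:

    import torch

    # hypothetical helper: apply the pool over the batch dimension in chunks
    # so the peak memory of unfold + contiguous() stays bounded
    def median_pool_chunked(pool, x, chunk_size=4):
        outs = [pool(chunk) for chunk in x.split(chunk_size, dim=0)]
        return torch.cat(outs, dim=0)

For a single very large image one could tile spatially instead, but the tiles would need to overlap by kernel_size // 2 pixels to avoid artifacts at the tile borders.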

ingbeeedd commented Dec 15, 2021

It's slow on CPU because of the memory copy triggered by contiguous().

Timing: reshape is the same as view (~11 s for a (224, 224, 3) input).
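
For anyone curious where that copy comes from, a rough sketch (assuming a 3x3 kernel with stride 1): the unfold calls only create a strided view; it is contiguous() that materializes every overlapping window into new memory, which also explains the ~9x allocation reported above.

    import torch

    x = torch.rand(1, 3, 224, 224)
    windows = x.unfold(2, 3, 1).unfold(3, 3, 1)  # strided view, no large allocation yet
    print(windows.is_contiguous())               # False
    big = windows.contiguous()                   # the expensive copy happens here
    print(big.numel() / x.numel())               # ~8.8, i.e. roughly kernel_h * kernel_w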
