
@rwightman
Last active March 15, 2024 12:55
PyTorch MedianPool (MedianFilter)
```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair, _quadruple


class MedianPool2d(nn.Module):
    """ Median pool (usable as median filter when stride=1) module.

    Args:
         kernel_size: size of pooling kernel, int or 2-tuple
         stride: pool stride, int or 2-tuple
         padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad
         same: override padding and enforce same padding, boolean
    """
    def __init__(self, kernel_size=3, stride=1, padding=0, same=False):
        super(MedianPool2d, self).__init__()
        self.k = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _quadruple(padding)  # convert to l, r, t, b
        self.same = same

    def _padding(self, x):
        if self.same:
            ih, iw = x.size()[2:]
            if ih % self.stride[0] == 0:
                ph = max(self.k[0] - self.stride[0], 0)
            else:
                ph = max(self.k[0] - (ih % self.stride[0]), 0)
            if iw % self.stride[1] == 0:
                pw = max(self.k[1] - self.stride[1], 0)
            else:
                pw = max(self.k[1] - (iw % self.stride[1]), 0)
            pl = pw // 2
            pr = pw - pl
            pt = ph // 2
            pb = ph - pt
            padding = (pl, pr, pt, pb)
        else:
            padding = self.padding
        return padding

    def forward(self, x):
        # using existing pytorch functions and tensor ops so that we get autograd,
        # would likely be more efficient to implement from scratch at C/Cuda level
        x = F.pad(x, self._padding(x), mode='reflect')
        # (B, C, H, W) -> (B, C, H_out, W_out, kH, kW): one kH x kW window per output pixel
        x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1])
        # flatten each window, then take its median (torch.median returns (values, indices))
        x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0]
        return x
```
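The `same=True` branch of `_padding` appears to follow the same rule TensorFlow uses for `padding='SAME'`: pad just enough that each spatial axis yields `ceil(n / stride)` windows, with any odd leftover pixel going to the right/bottom edge. The arithmetic can be checked without PyTorch. The sketch below restates that branch as standalone functions; `same_padding` and `output_size` are hypothetical helper names, not part of the gist.

```python
import math


def same_padding(ih, iw, k, stride):
    """Restates MedianPool2d._padding for same=True; returns (l, r, t, b).

    k and stride are (height, width) pairs, as produced by _pair().
    """
    if ih % stride[0] == 0:
        ph = max(k[0] - stride[0], 0)
    else:
        ph = max(k[0] - (ih % stride[0]), 0)
    if iw % stride[1] == 0:
        pw = max(k[1] - stride[1], 0)
    else:
        pw = max(k[1] - (iw % stride[1]), 0)
    # Odd totals put the extra pixel on the right/bottom edge.
    pl, pt = pw // 2, ph // 2
    return (pl, pw - pl, pt, ph - pt)


def output_size(n, pad_total, k, s):
    # Number of windows unfold(dim, k, s) yields on a padded axis of length n + pad_total.
    return (n + pad_total - k) // s + 1


# 3x3 median filter, stride 1, on a 224x224 image: one pixel of padding per side.
print(same_padding(224, 224, (3, 3), (1, 1)))  # (1, 1, 1, 1)

# Stride 2 on a 7-pixel axis: 2 pixels of padding in total, ceil(7 / 2) = 4 windows.
l, r, t, b = same_padding(7, 7, (3, 3), (2, 2))
print((l, r), output_size(7, l + r, 3, 2), math.ceil(7 / 2))  # (1, 1) 4 4
```

With `stride=1` this reduces to `k - 1` pixels of total padding per axis, which is why the module only behaves as a size-preserving median filter when `same=True` (or an equivalent explicit `padding`) is used.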

ingbeeedd commented Dec 15, 2021

It's slow on CPU because `.contiguous()` makes a full memory copy of the unfolded windows.

Timing: replacing `view` with `reshape` gives the same time (11 s for an input of size (224, 224, 3)).
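That timing result is expected: after the two `unfold` calls the tensor is a strided view that is not contiguous, and its two window dimensions cannot be merged in place, so `view` needs `.contiguous()` first, and `reshape` silently falls back to making the same copy. The small PyTorch check below illustrates this on a toy 4x4 input; it is a sketch of the observation, not a benchmark.

```python
import torch

# Toy input shaped (batch, channels, height, width).
x = torch.arange(16.0).view(1, 1, 4, 4)

# Same windowing as MedianPool2d.forward with a 3x3 kernel and stride 1.
u = x.unfold(2, 3, 1).unfold(3, 3, 1)  # shape (1, 1, 2, 2, 3, 3)
print(u.shape, u.is_contiguous())  # torch.Size([1, 1, 2, 2, 3, 3]) False

# view() cannot flatten the two window dims without copying, so it raises.
try:
    u.view(u.size()[:4] + (-1,))
except RuntimeError as err:
    print("view failed:", type(err).__name__)  # view failed: RuntimeError

# Both of these copy the window data and produce identical results.
a = u.contiguous().view(u.size()[:4] + (-1,))
b = u.reshape(u.size()[:4] + (-1,))
print(torch.equal(a, b), a.median(dim=-1).values.tolist())
# True [[[[5.0, 6.0], [9.0, 10.0]]]]
```

So swapping `view` for `reshape` only changes who performs the copy; avoiding it requires a different algorithm rather than a different reshaping call.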


fingertap commented Jul 7, 2022

In any case, unfold will not help with convolution- or pooling-like operations when the stride is small (e.g., 1), because it expands the raw data in memory before the actual calculation starts. The only real fix is to implement a CUDA extension to torch yourself, which is not very hard if you use the official implementations of ops like max pooling as a reference. It is very inconvenient, though.
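The size of that expansion is easy to estimate: each output position materializes its own k×k window, so after `.contiguous()` the buffer holds roughly k²/s² times as many elements as the input (k² at stride 1). A back-of-the-envelope sketch in plain Python, using the 224×224×3 size from the previous comment and assuming "same" padding on a size divisible by the stride (total padding `k - s` per axis); `unfolded_numel` is a hypothetical helper, not from the gist:

```python
def unfolded_numel(c, h, w, k, s, pad):
    """Elements in the per-image (C, H_out, W_out, k*k) window buffer from unfold + contiguous.

    pad is the total padding added along each spatial axis.
    """
    h_out = (h + pad - k) // s + 1
    w_out = (w + pad - k) // s + 1
    return c * h_out * w_out * k * k


c, h, w = 3, 224, 224
base = c * h * w  # 150528 input elements
for k, s in [(3, 1), (5, 1), (3, 2)]:
    n = unfolded_numel(c, h, w, k, s, pad=k - s)
    print(f"k={k} s={s}: {n} elements, {n / base:.2f}x the input")
# k=3 s=1: 1354752 elements, 9.00x the input
# k=5 s=1: 3763200 elements, 25.00x the input
# k=3 s=2: 338688 elements, 2.25x the input
```

This per-image figure is multiplied again by the batch size, which is why the copy dominates for stride-1 median filtering.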

@GuillaumeTong

For the reference of any future people looking into this, here is a slightly different implementation from Kornia:
https://kornia.readthedocs.io/en/latest/_modules/kornia/filters/median.html
From what I understand, the main difference is that they use a convolution to do the unfolding instead. In my rough test, both methods currently perform about the same in terms of speed and memory.
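The convolution trick can be sketched as follows: a bank of k·k one-hot k×k kernels makes `conv2d` copy each neighbor of a pixel into its own output channel, which is the same data `unfold` produces, just laid out along the channel axis. This is a minimal sketch of the idea, not Kornia's actual code, and it assumes reflect padding (to match the gist), stride 1 and an odd kernel size; `one_hot_kernels`, `median_filter_conv` and `median_filter_unfold` are hypothetical names.

```python
import torch
import torch.nn.functional as F


def one_hot_kernels(k):
    # (k*k, 1, k, k): kernel i has a single 1 at flattened position i.
    return torch.eye(k * k).view(k * k, 1, k, k)


def median_filter_conv(x, k=3):
    b, c, h, w = x.shape
    p = k // 2  # assumes an odd kernel size
    x = F.pad(x, (p, p, p, p), mode="reflect")
    # Fold channels into the batch so one kernel bank serves every channel.
    x = x.reshape(b * c, 1, h + 2 * p, w + 2 * p)
    patches = F.conv2d(x, one_hot_kernels(k).to(x.dtype))  # (b*c, k*k, h, w)
    return patches.median(dim=1).values.view(b, c, h, w)


def median_filter_unfold(x, k=3):
    # The gist's approach (stride 1): unfold both spatial dims, median over each window.
    p = k // 2
    x = F.pad(x, (p, p, p, p), mode="reflect")
    x = x.unfold(2, k, 1).unfold(3, k, 1)
    return x.reshape(x.shape[:4] + (-1,)).median(dim=-1).values


torch.manual_seed(0)
img = torch.randint(0, 256, (2, 3, 16, 16)).float()  # 8-bit-like test image
print(torch.allclose(median_filter_conv(img), median_filter_unfold(img), atol=1e-6))
```

Both variants still materialize a k²-times-larger buffer (`patches` here, the contiguous window copy in the gist), which is consistent with the observation that their speed and memory use end up similar.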
