Created
March 22, 2018 07:51
-
-
Save etienne87/9f903b2b16389f9fe98a18fade6df74b to your computer and use it in GitHub Desktop.
how to make a bilateral filter using torch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python | |
# torch_bilateral: bi/trilateral filtering in torch | |
import torch | |
from torch import nn | |
from torch.autograd import Variable | |
import torch.nn.functional as F | |
from torch.nn import Parameter | |
import numpy as np | |
import pdb | |
import time | |
def gkern2d(l=21, sig=3):
    """Return an ``l`` x ``l`` un-normalised 2D Gaussian kernel as a numpy array.

    Args:
        l: side length of the (square) kernel, in samples.
        sig: standard deviation of the Gaussian, in samples.

    Returns:
        A float array of shape ``(l, l)`` holding
        ``exp(-(x**2 + y**2) / (2 * sig**2))``. The values are not normalised
        to sum to one; for odd ``l`` the centre sample equals 1.
    """
    # Sample positions centred on zero. For odd ``l`` this is the integer grid
    # -(l//2) .. l//2; for even ``l`` it is the half-integer grid, so the
    # kernel stays symmetric instead of being shifted by half a sample.
    ax = np.arange(l, dtype=np.float64) - (l - 1) / 2.
    sq = ax ** 2
    # Outer sum of squared offsets gives x**2 + y**2 for every grid position.
    return np.exp(-(sq[:, None] + sq[None, :]) / (2. * sig ** 2))
class Shift(nn.Module):
    """Stack every spatially shifted copy of each input channel along dim 1.

    For an input of shape ``(n, c, h, w)`` and kernel size ``k`` the output has
    shape ``(n, in_planes * k * k, h, w)``. Output channel
    ``i * k * k + dy * k + dx`` holds input channel ``i`` sampled at the
    neighbour offset ``(dy - k // 2, dx - k // 2)``; positions falling outside
    the image read the zero padding added by ``F.pad``.

    Args:
        in_planes: number of input channels to expand.
        kernel_size: side length of the square neighbourhood.
    """

    def __init__(self, in_planes, kernel_size=3):
        super(Shift, self).__init__()
        self.in_planes = in_planes
        self.kernel_size = kernel_size
        # Not used by forward(); kept so existing attribute access keeps working.
        self.channels_per_group = self.in_planes // (self.kernel_size ** 2)
        # "Same" padding for any kernel size (the original only defined it for
        # 3, 5 and 7 and failed with AttributeError otherwise).
        self.pad = self.kernel_size // 2

    def forward(self, x):
        """Return the shifted stack for a 4D tensor ``x`` of shape (n, c, h, w)."""
        _, _, h, w = x.size()
        x_pad = F.pad(x, (self.pad, self.pad, self.pad, self.pad))
        k = self.kernel_size
        # Channel-major, then row-major over the neighbourhood.
        shifted = [
            x_pad[:, i:i + 1, dy:dy + h, dx:dx + w]
            for i in range(self.in_planes)
            for dy in range(k)
            for dx in range(k)
        ]
        return torch.cat(shifted, 1)
class BilateralFilter(nn.Module):
    r"""Bilateral filter.

    Computes, for every pixel x,
        If(x) = 1/W(x) * Sum_{xi in Omega}(I(xi) * f(||I(xi)-I(x)||) * g(||xi-x||))
    where Omega is the k x k neighbourhood of x, g is the spatial Gaussian,
    f is the range (colour) Gaussian and W(x) is the sum of the weights.

    Args:
        channels: number of image channels.
        k: side length of the neighbourhood (must be supported by ``Shift``).
        height, width: accepted for backward compatibility only; the spatial
            kernel is broadcast over the image, so any image size works.
        sigma_space: standard deviation of the spatial Gaussian, in pixels.
        sigma_color: standard deviation of the range Gaussian, in intensity units.
    """

    def __init__(self, channels=3, k=7, height=480, width=640, sigma_space=5, sigma_color=0.1):
        super(BilateralFilter, self).__init__()
        # Spatial Gaussian, kept as a numpy array for inspection/plotting.
        self.gw = gkern2d(k, sigma_space)
        # One weight per neighbourhood position, shared by all channels and
        # pixels. Shape (1, k*k, 1, 1) broadcasts over batch and image size,
        # and no longer depends on ``channels`` (the old reshape only worked
        # for channels == 1). It is a constant, hence requires_grad=False.
        g = torch.from_numpy(self.gw).float().reshape(1, k * k, 1, 1)
        self.g = Parameter(g, requires_grad=False)
        # Neighbourhood extraction.
        self.shift = Shift(channels, k)
        # Note: stores the range-Gaussian denominator 2 * sigma**2, not sigma.
        self.sigma_color = 2 * sigma_color ** 2

    def forward(self, I):
        """Filter a batch of images.

        Args:
            I: tensor of shape (n, channels, h, w).

        Returns:
            Tensor of shape (n, h, w) when ``channels == 1`` (as before), and
            (n, channels, h, w) otherwise.
        """
        n, c, h, w = I.size()
        if c != self.shift.in_planes:
            raise ValueError(
                "expected %d input channels, got %d" % (self.shift.in_planes, c))
        # (n, c*k*k, h, w) -> (n, c, k*k, h, w): neighbours of every channel.
        # No detaching here, so gradients flow through every term.
        Is = self.shift(I).reshape(n, c, -1, h, w)
        # Squared colour distance to the centre pixel, summed over channels so
        # that all channels of a pixel share one weight.
        D = ((Is - I.unsqueeze(2)) ** 2).sum(dim=1, keepdim=True)
        weights = torch.exp(-D / self.sigma_color) * self.g.unsqueeze(1)
        If = (weights * Is).sum(dim=2) / weights.sum(dim=2)
        # Single-channel callers historically received (n, h, w).
        return If.squeeze(1) if c == 1 else If
if __name__ == '__main__':
    # Demo: filter a grayscale image, time one forward pass and display the
    # spatial kernel, the input and the output in OpenCV windows.
    import cv2

    # Integer division: sizes must be ints for np.tile, cv2.resize and reshape
    # (plain "/" yields floats on Python 3).
    c, h, w = 1, 480 // 2, 640 // 2
    k = 5
    cuda = False

    bilat = BilateralFilter(c, k, h, w)
    if cuda:
        bilat.cuda()

    image_path = '/home/eperot/Pictures/Lena.png'
    im = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if im is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError('could not read image: ' + image_path)
    im = cv2.resize(im, (w, h), interpolation=cv2.INTER_CUBIC)
    img = torch.from_numpy(im.reshape(1, 1, h, w)).float() / 255.0
    img_in = img.cuda() if cuda else img

    # Warm-up pass so one-off initialisation cost is not included in the timing.
    y = bilat(img_in)
    if cuda:
        torch.cuda.synchronize()
    start = time.time()
    y = bilat(img_in)
    if cuda:
        # CUDA kernels run asynchronously; wait for them before stopping the clock.
        torch.cuda.synchronize()
    print(time.time() - start)

    # Not counting the return transfer in timing!
    img_out = y.detach().cpu().numpy()[0]
    show_out = cv2.resize(img_out, (640, 480))
    show_in = cv2.resize(img[0, 0].numpy(), (640, 480))

    cv2.namedWindow('kernel')
    cv2.imshow('kernel', bilat.gw)
    cv2.namedWindow('img_out')
    cv2.moveWindow('img_out', 200, 200)
    cv2.imshow('img_out', show_out)
    cv2.namedWindow('img_in')
    cv2.imshow('img_in', show_in)
    cv2.waitKey(0)
I created a slightly updated version of this gist (everything is done in torch, no OpenCV, and reduced use of Parameter and Variable). The images look as expected and everything looks good. @etienne87 Do you agree?
import torch
from torch import nn
import torch.nn.functional as F
def gkern2d(l=21, sig=3, device='cpu'):
    """Return an ``l`` x ``l`` un-normalised 2D Gaussian kernel as a tensor.

    Args:
        l: side length of the (square) kernel, in samples.
        sig: standard deviation of the Gaussian, in samples.
        device: device on which the kernel is created.

    Returns:
        A tensor of shape ``(l, l)`` in the default floating dtype holding
        ``exp(-(x**2 + y**2) / (2 * sig**2))``. The values are not normalised
        to sum to one; for odd ``l`` the centre sample equals 1.
    """
    # Sample positions centred on zero. For odd ``l`` this is the integer grid
    # -(l//2) .. l//2; for even ``l`` it is the half-integer grid, so the
    # kernel stays symmetric instead of being shifted by half a sample.
    ax = torch.arange(l, dtype=torch.get_default_dtype(), device=device) - (l - 1) / 2.
    sq = ax ** 2
    # Broadcasting replaces torch.meshgrid, which warns when called without
    # an explicit ``indexing`` argument on recent PyTorch versions.
    return torch.exp(-(sq[:, None] + sq[None, :]) / (2. * sig ** 2))
class Shift(nn.Module):
    """Stack every spatially shifted copy of each input channel along dim 1.

    For an input of shape ``(n, c, h, w)`` and kernel size ``k`` the output has
    shape ``(n, in_planes * k * k, h, w)``. Output channel
    ``i * k * k + dy * k + dx`` holds input channel ``i`` sampled at the
    neighbour offset ``(dy - k // 2, dx - k // 2)``; positions falling outside
    the image read the zero padding added by ``F.pad``.

    Args:
        in_planes: number of input channels to expand.
        kernel_size: side length of the square neighbourhood.
    """

    def __init__(self, in_planes, kernel_size=3):
        super(Shift, self).__init__()
        self.in_planes = in_planes
        self.kernel_size = kernel_size
        # Not used by forward(); kept so existing attribute access keeps working.
        self.channels_per_group = self.in_planes // (self.kernel_size ** 2)
        # "Same" padding for any kernel size (the original only defined it for
        # 3, 5 and 7 and failed with AttributeError otherwise).
        self.pad = self.kernel_size // 2

    def forward(self, x):
        """Return the shifted stack for a 4D tensor ``x`` of shape (n, c, h, w)."""
        _, _, h, w = x.size()
        x_pad = F.pad(x, (self.pad, self.pad, self.pad, self.pad))
        k = self.kernel_size
        # Channel-major, then row-major over the neighbourhood.
        shifted = [
            x_pad[:, i:i + 1, dy:dy + h, dx:dx + w]
            for i in range(self.in_planes)
            for dy in range(k)
            for dx in range(k)
        ]
        return torch.cat(shifted, 1)
class BilateralFilter(nn.Module):
    """Bilateral filter.

    Computes, for every pixel x,
        If(x) = 1/W(x) * Sum_{xi in Omega}(I(xi) * f(||I(xi)-I(x)||) * g(||xi-x||))
    where Omega is the k x k neighbourhood of x, g is the spatial Gaussian,
    f is the range (colour) Gaussian and W(x) is the sum of the weights.

    Args:
        channels: number of image channels.
        k: side length of the neighbourhood (must be supported by ``Shift``).
        height, width: accepted for backward compatibility only; the spatial
            kernel is broadcast over the image, so any image size works.
        sigma_space: standard deviation of the spatial Gaussian, in pixels.
        sigma_color: standard deviation of the range Gaussian, in intensity units.
        device: device on which the module's tensors are created.
    """

    def __init__(self, channels=3, k=7, height=480, width=640, sigma_space=5, sigma_color=0.1, device='cpu'):
        super(BilateralFilter, self).__init__()
        # Spatial Gaussian as a plain (k, k) tensor, kept for inspection.
        self.gw = gkern2d(k, sigma_space, device=device)
        # One weight per neighbourhood position, shared by all channels and
        # pixels. Shape (1, k*k, 1, 1) broadcasts over batch and image size,
        # and no longer depends on ``channels`` (the old reshape only worked
        # for channels == 1). Registered as a buffer so that .to()/.cuda()
        # move it together with the module.
        self.register_buffer('g', self.gw.reshape(1, k * k, 1, 1).clone())
        # Neighbourhood extraction.
        self.shift = Shift(channels, k)
        # Note: stores the range-Gaussian denominator 2 * sigma**2, not sigma.
        self.sigma_color = 2 * sigma_color ** 2
        self.to(device=device)

    def forward(self, I):
        """Filter a batch of images.

        Args:
            I: tensor of shape (n, channels, h, w).

        Returns:
            Tensor of shape (n, h, w) when ``channels == 1`` (as before), and
            (n, channels, h, w) otherwise.
        """
        n, c, h, w = I.size()
        if c != self.shift.in_planes:
            raise ValueError(
                "expected %d input channels, got %d" % (self.shift.in_planes, c))
        # (n, c*k*k, h, w) -> (n, c, k*k, h, w): neighbours of every channel.
        # No detaching here, so gradients flow through every term.
        Is = self.shift(I).reshape(n, c, -1, h, w)
        # Squared colour distance to the centre pixel, summed over channels so
        # that all channels of a pixel share one weight.
        D = ((Is - I.unsqueeze(2)) ** 2).sum(dim=1, keepdim=True)
        weights = torch.exp(-D / self.sigma_color) * self.g.unsqueeze(1)
        If = (weights * Is).sum(dim=2) / weights.sum(dim=2)
        # Single-channel callers historically received (n, h, w).
        return If.squeeze(1) if c == 1 else If
if __name__ == '__main__':
    # Demo: filter the scikit-image "astronaut" picture (converted to
    # grayscale) and print how long one forward pass takes.
    import time

    from skimage import data as ski_data
    from skimage.color import rgb2gray

    k = 5
    # Fall back to the CPU when no GPU is present instead of failing.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    img = rgb2gray(ski_data.astronaut())
    # Add batch and channel axes -> (1, 1, H, W). The tensor is created
    # directly on the chosen device; the old ``if device:`` test was true for
    # any non-empty string and called .cuda() even for 'cpu'. float32 matches
    # the dtype of the filter's spatial kernel.
    img = torch.tensor(img[None, None, ...], dtype=torch.float32, device=device)

    bilat = BilateralFilter(img.shape[1], k, img.shape[2], img.shape[3], device=device)

    if device == 'cuda':
        torch.cuda.synchronize()
    start_time = time.time()
    img_filtered = bilat(img)
    if device == 'cuda':
        # CUDA kernels run asynchronously; wait for them before stopping the clock.
        torch.cuda.synchronize()
    print("Duration: ", time.time() - start_time)

    img_out = img_filtered.detach().cpu().numpy().squeeze()
looks fine :)
coming back to this thing several years later i'm wondering if pytorch's unfold is perhaps simpler than this "Shift" operator
how to change it for the rgb image case?
You would change the channel parameter to channels=3. However, I noticed that it will then throw an exception in both the original version and my modified version. Sadly, I have no understanding of how a bilateral filter works and am therefore unable to fix this error. I needed it myself only to implement a baseline.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
thanks :) is this code useful for you?