Last active
March 25, 2018 21:11
-
-
Save etienne87/af4210586a1b5316e287479d512fc5e5 to your computer and use it in GitHub Desktop.
torch bilateral rgb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python | |
# torch_bilateral: bi/trilateral filtering in torch | |
import torch | |
from torch import nn | |
from torch.autograd import Variable | |
import torch.nn.functional as F | |
from torch.nn import Parameter | |
import numpy as np | |
import pdb | |
import time | |
def gkern2d(l=21, sig=3):
    """Return an unnormalised l x l 2-D Gaussian kernel as a float64 array.

    Sample offsets run from ``-l // 2 + 1`` to ``l // 2`` along each axis, so
    for odd ``l`` the window is centred on zero and the centre entry is
    exactly 1. For even ``l`` the window is shifted by half a sample.
    The kernel is NOT normalised to sum to 1.

    Args:
        l: kernel width and height in samples.
        sig: standard deviation of the Gaussian, in samples.
    """
    offsets = np.arange(-l // 2 + 1., l // 2 + 1.)
    # Squared distance from the origin: column offset^2 + row offset^2.
    sq_dist = offsets[np.newaxis, :] ** 2 + offsets[:, np.newaxis] ** 2
    return np.exp(-sq_dist / (2. * sig ** 2))
class Shift(nn.Module):
    """Stack every spatial shift of a 5-D tensor along its depth axis.

    For a ``kernel_size`` of k, ``forward`` zero-pads the two trailing
    (spatial) dimensions by ``k // 2`` on each side. It then concatenates the
    k*k shifted views along dim 2, visiting offsets in row-major order
    (vertical offset in the outer loop, horizontal offset in the inner loop).
    The centre view, index ``(k*k - 1) // 2``, equals the input.

    Args:
        in_planes: number of input channels. It is only used to derive
            ``channels_per_group``, which ``forward`` does not read.
        kernel_size: odd neighbourhood width k (1, 3, 5, 7, 9, ...).

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, in_planes, kernel_size=3):
        super(Shift, self).__init__()
        # Validate first: kernel_size == 0 would otherwise hit a
        # ZeroDivisionError below. Even sizes have no centre tap and would
        # misalign with a centred spatial kernel.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                "kernel_size must be a positive odd integer, got %r" % (kernel_size,))
        self.in_planes = in_planes
        self.kernel_size = kernel_size
        self.channels_per_group = self.in_planes // (self.kernel_size ** 2)
        self.dilation = 1
        self.kernel_size = kernel_size + 2 * (self.dilation - 1)
        # Half-width of the window. Previously only sizes 3/5/7 set this, so
        # any other size failed with AttributeError inside forward().
        self.pad = self.kernel_size // 2

    def forward(self, x):
        """Return all k*k zero-padded shifts of ``x`` concatenated on dim 2.

        Args:
            x: 5-D tensor (n, c, d, h, w). Only h and w are shifted.

        Returns:
            Tensor of shape (n, c, d * k * k, h, w).
        """
        _, _, _, h, w = x.size()
        p = self.pad
        padded = F.pad(x, (p, p, p, p))  # pads w, then h
        shifted = [
            padded[:, :, :, dy:dy + h, dx:dx + w]
            for dy in range(self.kernel_size)
            for dx in range(self.kernel_size)
        ]
        return torch.cat(shifted, 2)
class BilateralFilter(nn.Module):
    r"""Edge-preserving bilateral filter built from :class:`Shift`.

    Computes, independently for every channel::

        If(x) = 1/W(x) * Sum_{xi in Omega(x)} I(xi) * f(I(xi) - I(x)) * g(xi - x)
        W(x)  =          Sum_{xi in Omega(x)}         f(I(xi) - I(x)) * g(xi - x)

    Here Omega(x) is the k x k window centred on x, and
    f(d) = exp(-d**2 / (2 * sigma_color**2)) is the range (colour) weight.
    g is the unnormalised spatial kernel ``gkern2d(k, sigma_space)``.

    Notes:
        * The colour distance is computed per channel, not as a joint norm
          over channels, so each channel is smoothed on its own.
        * Pixels outside the image are treated as zeros, because Shift pads
          with zeros.

    Args:
        channels: number of channels per image.
        k: window width; must be odd (enforced by Shift).
        height, width: unused; kept for backward compatibility.
        sigma_space: std dev of the spatial Gaussian, in pixels.
        sigma_color: std dev of the range Gaussian, in the same units as the
            input intensities (the demo below feeds values in [0, 1]).
            Must be > 0.
    """

    def __init__(self, channels=3, k=7, height=480, width=640, sigma_space=80, sigma_color=0.1):
        super(BilateralFilter, self).__init__()
        # Produces all k*k neighbours of every pixel along dim 2.
        self.shift = Shift(channels, k)
        self.k = self.shift.kernel_size
        # Spatial weights, repeated per channel: shape (1, channels, k*k, 1, 1),
        # ordered row-major like Shift's output so they broadcast against it.
        self.gw = gkern2d(self.k, sigma_space)
        gw = self.gw.reshape(1, 1, self.k ** 2, 1, 1).repeat(channels, 1)
        self.g = Parameter(torch.from_numpy(gw).float())
        # NB: stores the exponent denominator 2*sigma_color**2, not sigma itself.
        self.sigma_color = 2 * sigma_color ** 2
        self.channels = channels

    def forward(self, I):
        """Filter ``I`` and return a tensor of shape (batch, channels, H, W).

        Args:
            I: 4-D tensor (batch, channels, H, W). With ``-1`` in the view below,
                a batch of several images is accepted; a single image behaves
                exactly as before.
        """
        # Insert a singleton depth axis so Shift can stack neighbours on dim 2.
        I = I.view(-1, self.channels, 1, I.size(2), I.size(3))
        # NOTE(review): neighbours are taken via .data, so no gradient flows
        # through them (only through the centre term I); confirm intended.
        Is = self.shift(I).data
        D = (Is - I) ** 2
        De = torch.exp(-D / self.sigma_color)
        Dd = De * self.g.data  # broadcast spatial weights over batch, H and W
        # The centre tap has D == 0 and g == 1, so W_denom >= 1 (sigma_color > 0).
        W_denom = torch.sum(Dd, dim=2)
        If = torch.sum(Dd * Is, dim=2) / W_denom
        return If
if __name__ == '__main__':
    # Demo: run the bilateral filter twice on an image, print the elapsed time
    # and display the spatial kernel, the input and the output.
    # Usage: python torch_bilateral.py [image_path]
    import sys
    import cv2

    c, h, w = 3, 350, 640
    k = 5
    # Fall back to the CPU instead of crashing on machines without CUDA.
    cuda = torch.cuda.is_available()
    path = sys.argv[1] if len(sys.argv) > 1 else '/home/etienneperot/Images/Lena.jpg'

    bilat = BilateralFilter(c, k, h, w)
    if cuda:
        bilat.cuda()

    im = cv2.imread(path, cv2.IMREAD_GRAYSCALE if c == 1 else cv2.IMREAD_COLOR)
    if im is None:  # cv2.imread signals failure by returning None
        sys.exit('could not read image: %s' % path)
    im = cv2.resize(im, (w, h), interpolation=cv2.INTER_CUBIC)
    # Grayscale images are 2-D (h, w); colour ones are (h, w, c). Bring both
    # to channel-first layout so the c == 1 path no longer crashes.
    im_chw = im[np.newaxis] if im.ndim == 2 else np.moveaxis(im, 2, 0)
    img = torch.from_numpy(im_chw.reshape(1, c, h, w)).float() / 255.0

    # Timing includes the host-to-device upload but not the download.
    start = time.time()
    img_in = img.cuda() if cuda else img
    with torch.no_grad():
        y = bilat(bilat(img_in))
    if cuda:
        # CUDA kernels run asynchronously; wait for them before reading the clock.
        torch.cuda.synchronize()
    print(time.time() - start)

    img_out = np.moveaxis(y.cpu().numpy()[0], 0, 2)
    show_out = cv2.resize(img_out, (w, h))
    show_in = im

    for name, image, x_pos in (('kernel', bilat.gw, 0),
                               ('img_in', show_in, 600),
                               ('img_out', show_out, 1800)):
        cv2.namedWindow(name)
        cv2.moveWindow(name, x_pos, 400)
        cv2.imshow(name, image)
    cv2.waitKey(0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment