@etienne87
Last active March 25, 2018 21:11
torch bilateral rgb
#!/usr/bin/python
# torch_bilateral: bi/trilateral filtering in torch
import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
from torch.nn import Parameter
import numpy as np
import pdb
import time

def gkern2d(l=21, sig=3):
    """Returns an l x l 2D Gaussian kernel array with standard deviation sig."""
    ax = np.arange(-l // 2 + 1., l // 2 + 1.)
    xx, yy = np.meshgrid(ax, ax)
    kernel = np.exp(-(xx ** 2 + yy ** 2) / (2. * sig ** 2))
    return kernel
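
# Quick sanity check (hedged sketch, hypothetical names): the kernel should be
# l x l, symmetric, and peak at the centre with value 1.0.
#
#   _k = gkern2d(5, 1.0)
#   assert _k.shape == (5, 5)
#   assert _k[2, 2] == _k.max() == 1.0
#   assert np.allclose(_k, _k.T)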


class Shift(nn.Module):
    """Stacks the k*k spatially shifted copies of the input along dim 2, in row-major order."""

    def __init__(self, in_planes, kernel_size=3):
        super(Shift, self).__init__()
        self.in_planes = in_planes
        self.kernel_size = kernel_size
        self.channels_per_group = self.in_planes // (self.kernel_size ** 2)
        self.dilation = 1
        self.kernel_size = kernel_size + 2 * (self.dilation - 1)
        # "same" padding so every shifted copy keeps the original spatial size
        self.pad = self.kernel_size // 2

    def forward(self, x):
        # x: (n, c, 1, h, w) -> out: (n, c, k*k, h, w)
        n, c, _, h, w = x.size()
        x_pad = F.pad(x, (self.pad, self.pad, self.pad, self.pad))
        cat_layers = []
        # Each shift contains all channels, parsed in row-major order
        for dy in range(0, self.kernel_size):
            y2 = dy + h
            for dx in range(0, self.kernel_size):
                x2 = dx + w
                cat_layers += [x_pad[:, :, :, dy:y2, dx:x2]]
        out = torch.cat(cat_layers, 2)
        return out
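
# Usage sketch (hypothetical values): shifting a (1, 3, 1, 8, 8) tensor with a
# 3x3 window stacks the 9 shifted copies along dim 2.
#
#   _s = Shift(in_planes=3, kernel_size=3)
#   _y = _s(torch.zeros(1, 3, 1, 8, 8))
#   assert _y.size() == (1, 3, 9, 8, 8)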


class BilateralFilter(nn.Module):
    r"""BilateralFilter computes:
        I_f(x) = 1/W(x) * sum_{xi in Omega} I(xi) * f(||I(xi) - I(x)||) * g(||xi - x||)
    where f is the range (color) Gaussian, g the spatial Gaussian and W(x) the sum of the weights.
    """

    def __init__(self, channels=3, k=7, height=480, width=640, sigma_space=80, sigma_color=0.1):
        super(BilateralFilter, self).__init__()
        # shift operation
        self.shift = Shift(channels, k)
        self.k = self.shift.kernel_size
        # spatial Gaussian kernel, stored as (1, channels, k*k, 1, 1) for broadcasting
        self.g = Parameter(torch.Tensor(1, channels, self.k ** 2, 1, 1))
        self.gw = gkern2d(self.k, sigma_space)
        gw = self.gw.reshape(1, 1, self.k ** 2, 1, 1).repeat(channels, 1)  # 1 x c x k*k x 1 x 1
        self.g.data = torch.from_numpy(gw).float()
        self.sigma_color = 2 * sigma_color ** 2
        self.channels = channels

    def forward(self, I):
        # I: (1, c, h, w) -> (1, c, 1, h, w) so the k*k shifted copies live on dim 2
        I = I.view(1, self.channels, 1, I.size(2), I.size(3))
        Is = self.shift(I).data
        D = (Is - I) ** 2                      # per-channel color distance to each neighbour
        De = torch.exp(-D / self.sigma_color)  # range weights f
        Dd = De * self.g.data                  # range * spatial weights (broadcast over h, w)
        W_denom = torch.sum(Dd, dim=2)         # normalisation term W(x)
        If = torch.sum(Dd * Is, dim=2) / W_denom
        return If
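
# Behaviour sketch (hypothetical names): on a constant image every range weight is 1,
# so the filter reduces to normalised Gaussian smoothing and returns the input unchanged.
#
#   _bf = BilateralFilter(channels=3, k=3, height=8, width=8)
#   _x = torch.ones(1, 3, 8, 8) * 0.5
#   assert (_bf(_x) - _x).abs().max() < 1e-6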


if __name__ == '__main__':
    import matplotlib.pyplot as plt
    import cv2

    c, h, w = 3, 350, 640
    k = 5
    cuda = True
    bilat = BilateralFilter(c, k, h, w)
    if cuda:
        bilat.cuda()

    im = cv2.imread('/home/etienneperot/Images/Lena.jpg', cv2.IMREAD_GRAYSCALE if c == 1 else cv2.IMREAD_COLOR)
    im = cv2.resize(im, (w, h), interpolation=cv2.INTER_CUBIC)
    im_in = np.moveaxis(im, 2, 0)
    im_in = im_in.reshape(1, c, h, w)
    img = torch.from_numpy(im_in).float() / 255.0

    # start = time.time()
    # img_out2 = cv2.bilateralFilter(im.astype(np.float32)/255, k, 0.1, 3)
    # print('opencv = ', time.time()-start)

    start = time.time()
    if cuda:
        img_in = img.cuda()
    else:
        img_in = img
    # start = time.time()
    y = bilat(bilat(img_in))  # two filtering passes
    print(time.time() - start)
    img_out = y.cpu().numpy()[0]  # not counting the return transfer in the timing!
    img_out = np.moveaxis(img_out, 0, 2)

    # show_out2 = cv2.resize(img_out2, (640, 480))
    show_out = cv2.resize(img_out, (w, h))
    show_in = im

    # diff = np.abs(img_out - img[0, 0])
    # diff = (diff - diff.min()) / (diff.max() - diff.min())
    # cv2.namedWindow('diff')
    # cv2.moveWindow('diff', 50, 50)
    # cv2.imshow('diff', diff)

    cv2.namedWindow('kernel')
    cv2.moveWindow('kernel', 0, 400)
    cv2.imshow('kernel', bilat.gw)
    cv2.namedWindow('img_in')
    cv2.moveWindow('img_in', 600, 400)
    cv2.imshow('img_in', show_in)
    # cv2.namedWindow('img_out2')
    # cv2.moveWindow('img_out2', 1200, 400)
    # cv2.imshow('img_out2', show_out2)
    cv2.namedWindow('img_out')
    cv2.moveWindow('img_out', 1800, 400)
    cv2.imshow('img_out', show_out)
    cv2.waitKey(0)