Skip to content

Instantly share code, notes, and snippets.

@torridgristle
Created February 28, 2022 14:25
Show Gist options
  • Save torridgristle/cbc46cc94b8af7190d22dc0be3ab9a64 to your computer and use it in GitHub Desktop.
Save torridgristle/cbc46cc94b8af7190d22dc0be3ab9a64 to your computer and use it in GitHub Desktop.
Sobel and Farid edge-detection modules for PyTorch. The option to use the Scharr kernel instead of the Sobel kernel is enabled by default; the Scharr kernel has better rotational symmetry.
import torch
import torch.nn as nn
import torch.nn.functional as F
# Default compute device for this module: the first CUDA device when one is
# visible to PyTorch, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class Sobel(nn.Module):
    """Per-channel 3x3 image-gradient filter (Scharr kernel by default, Sobel optionally).

    Each input channel is filtered independently (depthwise convolution) with a
    derivative-along-width kernel and with its transpose (derivative along height).

    Args:
        structure: If true, return the three products ``gx*|gx|``, ``gy*|gy|``
            and ``gx*gy`` instead of the raw gradients.
        scharr: If true (default), use the Scharr kernel scaled by 1/16;
            otherwise use the classic, unscaled Sobel kernel. Both options use
            the same orientation, so the output channel order does not depend
            on this flag.
        padding_mode: Mode passed to ``F.pad``. One pixel is padded on each side,
            so the output keeps the input's spatial size.

    Input is expected as ``(N, C, H, W)`` (channels on dim 1). The output is
    ``(N, 2*C, H, W)``: the C along-width derivative maps followed by the C
    along-height derivative maps. With ``structure`` set, it is ``(N, 3*C, H, W)``.
    Derivatives are "previous minus next" (left minus right, top minus bottom).

    The kernel is a non-persistent buffer, so ``module.to(device)`` moves it.
    In any case it is cast to the input's device and dtype on every call.
    """

    def __init__(self, structure=False, scharr=True, padding_mode='reflect'):
        super().__init__()
        self.structure = structure
        self.padding_mode = padding_mode
        # 1-D smoothing taps, applied across the derivative direction.
        if scharr:
            smooth = torch.tensor([3., 10., 3.]) / 16
        else:
            # Built in the same orientation as the Scharr kernel. The old
            # hard-coded Sobel matrix was its transpose, which swapped the
            # x/y output channels whenever scharr=False.
            smooth = torch.tensor([1., 2., 1.])
        # kernel[i, j] = smooth[i] * diff[j]: rows smooth, columns differentiate,
        # i.e. a derivative along the width axis.
        kernel = torch.outer(smooth, torch.tensor([1., 0., -1.])).reshape(1, 1, 3, 3)
        # Non-persistent buffer: follows module.to()/.cuda() without adding a
        # state_dict entry, so previously saved checkpoints still load.
        self.register_buffer('kernel', kernel, persistent=False)

    def forward(self, x):
        channels = x.shape[1]
        # Match the input's device/dtype so CPU, GPU and non-float32 inputs work.
        kernel = self.kernel.to(device=x.device, dtype=x.dtype)
        x_pad = F.pad(x, [1, 1, 1, 1], mode=self.padding_mode)
        # groups=channels: each input channel is convolved with its own copy of the kernel.
        x_x = F.conv2d(x_pad, kernel.expand(channels, 1, 3, 3), groups=channels)
        x_y = F.conv2d(
            x_pad, kernel.permute(0, 1, 3, 2).expand(channels, 1, 3, 3), groups=channels
        )
        if self.structure:
            # NOTE(review): the first two terms are signed squares g*|g|, not the
            # g**2 of the textbook structure tensor -- confirm this is intended.
            return torch.cat([x_x * x_x.abs(), x_y * x_y.abs(), x_x * x_y], 1)
        return torch.cat([x_x, x_y], 1)
class Farid(nn.Module):
    """Per-channel 5x5 Farid derivative filter.

    The kernel is the outer product of a 5-tap derivative filter ``d1``, applied
    along rows (height), and a 5-tap smoothing filter ``p``, applied along
    columns (width). Its transpose gives the derivative along width.

    Args:
        padding_mode: Mode passed to ``F.pad``. Two pixels are padded on each side,
            so the output keeps the input's spatial size.

    Input is expected as ``(N, C, H, W)`` (channels on dim 1). The output is
    ``(N, 2*C, H, W)``: the C along-height derivative maps first, then the C
    along-width derivative maps. This is the opposite order from ``Sobel``'s
    default Scharr kernel, which puts the along-width derivative first.
    Derivatives are "previous minus next" (top minus bottom, left minus right).

    The kernel is a non-persistent buffer, so ``module.to(device)`` moves it.
    In any case it is cast to the input's device and dtype on every call.
    """

    def __init__(self, padding_mode='reflect'):
        super().__init__()
        self.padding_mode = padding_mode
        # Symmetric smoothing taps; they sum to 1.
        # NOTE(review): these appear to match scikit-image's 5-tap Farid weights -- confirm.
        p = torch.tensor([[0.0376593171958126, 0.249153396177344, 0.426374573253687,
                           0.249153396177344, 0.0376593171958126]])
        # Antisymmetric first-derivative taps.
        d1 = torch.tensor([[0.109603762960254, 0.276690988455557, 0, -0.276690988455557,
                            -0.109603762960254]])
        # (5, 1) * (1, 5) broadcast -> kernel[i, j] = d1[i] * p[j], shaped (1, 1, 5, 5).
        kernel = (d1.T * p).unsqueeze(0).unsqueeze(0)
        # Non-persistent buffer: follows module.to()/.cuda() without adding a
        # state_dict entry, so previously saved checkpoints still load.
        self.register_buffer('kernel', kernel, persistent=False)

    def forward(self, x):
        channels = x.shape[1]
        # Match the input's device/dtype so CPU, GPU and non-float32 inputs work.
        kernel = self.kernel.to(device=x.device, dtype=x.dtype)
        x_pad = F.pad(x, [2, 2, 2, 2], mode=self.padding_mode)
        # groups=channels: each input channel is convolved with its own copy of the kernel.
        x_x = F.conv2d(x_pad, kernel.expand(channels, -1, -1, -1), groups=channels)
        x_y = F.conv2d(
            x_pad, kernel.permute(0, 1, 3, 2).expand(channels, -1, -1, -1), groups=channels
        )
        return torch.cat([x_x, x_y], 1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment