Skip to content

Instantly share code, notes, and snippets.

@torridgristle
Last active February 28, 2022 14:25
Show Gist options
  • Save torridgristle/68ce54f562ff46fd0bbb0381ea4ff243 to your computer and use it in GitHub Desktop.
Save torridgristle/68ce54f562ff46fd0bbb0381ea4ff243 to your computer and use it in GitHub Desktop.
Kaiser-window low-pass filter module for PyTorch. Torchvision's Gaussian blur pads with the "reflect" mode, but I'm not sure that makes sense here, so this module defaults to "replicate" padding.
import torch
import torch.nn as nn
import torch.nn.functional as F
# Default compute device for this file: CUDA when it is available,
# otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class KaiserLowpass(nn.Module):
    """Separable 2D low-pass (blur) filter built from a 1D Kaiser window.

    The same 1D Kaiser window is applied along the width axis and then along
    the height axis of a 4D ``(batch, channels, height, width)`` tensor, with
    every channel filtered independently (depthwise convolution). The input is
    padded first, so the output has the same spatial size as the input.

    Args:
        width: Number of taps of the 1D Kaiser window (the kernel size).
        beta: Kaiser window shape parameter; larger values concentrate the
            weights more strongly toward the center tap.
        periodic: Passed to ``torch.kaiser_window``; ``False`` gives a
            symmetric window.
        padding_mode: Mode passed to ``torch.nn.functional.pad``
            (e.g. ``'replicate'``, ``'reflect'``, ``'circular'``).
        normalize: If True (default), scale the window to sum to 1 so the
            filter leaves constant regions unchanged, like a standard blur.
            If False, use the raw window (peak value 1), which scales constant
            regions by ``window.sum() ** 2``.
    """

    def __init__(self, width=7, beta=11, periodic=False, padding_mode='replicate',
                 normalize=True):
        super().__init__()
        self.padding_mode = padding_mode
        # F.pad on a 4D tensor takes (left, right, top, bottom). Splitting
        # width - 1 into left/right keeps the output size equal to the input
        # size for even widths as well as odd ones.
        left = (width - 1) // 2
        right = width - 1 - left
        self.padding = [left, right, left, right]

        kernel = torch.kaiser_window(width, periodic=periodic, beta=float(beta))
        if normalize:
            # Unit DC gain; otherwise the two passes multiply the image
            # brightness by sum(window) ** 2.
            kernel = kernel / kernel.sum()
        # A buffer follows module.to()/.cuda(); non-persistent because it is
        # fully determined by the constructor arguments and need not be
        # saved in state_dict.
        self.register_buffer('kernel', kernel.reshape(1, 1, 1, width), persistent=False)

    def forward(self, x):
        """Low-pass filter ``x`` and return a tensor of the same shape.

        Args:
            x: 4D tensor laid out as ``(batch, channels, height, width)``.

        Raises:
            ValueError: If ``x`` is not 4-dimensional.
        """
        if x.dim() != 4:
            raise ValueError(
                f"KaiserLowpass expects a 4D (batch, channels, height, width) "
                f"tensor, got shape {tuple(x.shape)}"
            )
        channels = x.shape[1]
        # Match the input's device and dtype so the module works regardless of
        # where it was constructed and for half/double precision inputs.
        kernel = self.kernel.to(device=x.device, dtype=x.dtype)
        x = F.pad(x, self.padding, mode=self.padding_mode)
        # Horizontal pass: (C, 1, 1, width) kernel, one group per channel.
        x = F.conv2d(x, kernel.expand(channels, -1, -1, -1), groups=channels)
        # Vertical pass: the same window transposed to (C, 1, width, 1).
        x = F.conv2d(x, kernel.transpose(2, 3).expand(channels, -1, -1, -1), groups=channels)
        return x
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment