@louity
Last active April 19, 2022 13:34
Gaussian filtering in PyTorch
import torch
import matplotlib.pyplot as plt
# random test image, shape (N, C, H, W) = (1, 1, 32, 32), values uniform in [-1, 1)
inp = torch.empty(1, 1, 32, 32).uniform_(-1, 1)
plt.imshow(inp[0,0])
plt.show()
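# Gaussian kernel: 7x7 grid of integer offsets in [-3, 3], sigma = 1
# conv2d expects weights of shape (out_channels, in_channels, kH, kW) = (1, 1, 7, 7)
gauss_ker_7 = torch.empty(1, 1, 7, 7)
x, y = torch.meshgrid(torch.linspace(-3, 3, 7), torch.linspace(-3, 3, 7), indexing='xy')
gauss_ker_7[0, 0] = torch.exp(-0.5 * (x**2 + y**2))
# normalize so the weights sum to 1 (the filter preserves mean intensity)
gauss_ker_7 /= gauss_ker_7.sum()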
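# filtering without padding: a 7x7 kernel only fits at 32 - 7 + 1 = 26 positions
# per axis, so the output shrinks to 26x26
inp_filt = torch.nn.functional.conv2d(inp, gauss_ker_7)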
plt.imshow(inp_filt[0,0])
plt.show()
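# filtering with reflect padding: 3 pixels (= 7 // 2) are mirrored on each side,
# so the output keeps the input size, 32x32
inp_pad_filt = torch.nn.functional.conv2d(
    torch.nn.functional.pad(inp, (3, 3, 3, 3), mode='reflect'),
    gauss_ker_7)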
plt.imshow(inp_pad_filt[0,0])
plt.show()
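The script hard-codes a 7x7 kernel with sigma = 1 and a single channel. Below is a minimal sketch of a reusable version, assuming an (N, C, H, W) input and an odd kernel size; the helper names gaussian_kernel_2d and gaussian_filter are hypothetical, not part of PyTorch. Each channel is blurred independently by giving torch.nn.functional.conv2d one kernel copy per channel and groups=C.

```python
import torch
import torch.nn.functional as F


def gaussian_kernel_2d(size: int, sigma: float) -> torch.Tensor:
    """Hypothetical helper: normalized (size x size) Gaussian sampled on integer offsets."""
    if size % 2 != 1:
        raise ValueError("size must be odd so the kernel has a center pixel")
    half = size // 2
    coords = torch.arange(-half, half + 1, dtype=torch.float32)
    y, x = torch.meshgrid(coords, coords, indexing="ij")
    kernel = torch.exp(-(x**2 + y**2) / (2 * sigma**2))
    return kernel / kernel.sum()  # weights sum to 1


def gaussian_filter(img: torch.Tensor, size: int = 7, sigma: float = 1.0,
                    pad_mode: str = "reflect") -> torch.Tensor:
    """Hypothetical helper: blur an (N, C, H, W) batch per channel, keeping H and W."""
    n_channels = img.shape[1]
    kernel = gaussian_kernel_2d(size, sigma).to(dtype=img.dtype, device=img.device)
    # One copy of the kernel per channel: weight shape (C, 1, size, size)
    weight = kernel.view(1, 1, size, size).repeat(n_channels, 1, 1, 1)
    half = size // 2
    padded = F.pad(img, (half, half, half, half), mode=pad_mode)
    # groups=C applies each kernel copy to its own channel only (depthwise convolution)
    return F.conv2d(padded, weight, groups=n_channels)


# Usage: blur a random 3-channel image
img = torch.empty(1, 3, 32, 32).uniform_(-1, 1)
blurred = gaussian_filter(img, size=7, sigma=1.0)
print(blurred.shape)  # torch.Size([1, 3, 32, 32])
```

With size=7 and sigma=1.0 the kernel is the same as gauss_ker_7 in the script, so on a single-channel input this reproduces the reflect-padded result.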
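Since exp(-(x^2 + y^2) / (2 sigma^2)) = exp(-x^2 / (2 sigma^2)) * exp(-y^2 / (2 sigma^2)), the 2D Gaussian kernel is the outer product of a 1D Gaussian with itself. The same blur can therefore be computed as a horizontal pass followed by a vertical pass, costing about 2k instead of k^2 multiply-adds per pixel for a k x k kernel. This is a sketch under the same assumptions as the script (odd size, reflect padding, single channel); gaussian_kernel_1d and separable_gaussian_filter are hypothetical names.

```python
import torch
import torch.nn.functional as F


def gaussian_kernel_1d(size: int, sigma: float) -> torch.Tensor:
    """Hypothetical helper: normalized 1D Gaussian on integer offsets -size//2 .. size//2."""
    half = size // 2
    t = torch.arange(-half, half + 1, dtype=torch.float32)
    g = torch.exp(-t**2 / (2 * sigma**2))
    return g / g.sum()


def separable_gaussian_filter(img: torch.Tensor, size: int = 7, sigma: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: blur a (N, 1, H, W) image with two 1D passes (reflect padding)."""
    g = gaussian_kernel_1d(size, sigma).to(dtype=img.dtype, device=img.device)
    half = size // 2
    padded = F.pad(img, (half, half, half, half), mode="reflect")
    # Horizontal pass: kernel of shape (1, 1, 1, size) slides along W
    tmp = F.conv2d(padded, g.view(1, 1, 1, size))
    # Vertical pass: kernel of shape (1, 1, size, 1) slides along H
    return F.conv2d(tmp, g.view(1, 1, size, 1))


# Compare against a single 2D convolution with the outer-product kernel
img = torch.empty(1, 1, 32, 32).uniform_(-1, 1)
g = gaussian_kernel_1d(7, 1.0)
kernel_2d = torch.outer(g, g)  # the normalized 7x7 Gaussian kernel
full = F.conv2d(F.pad(img, (3, 3, 3, 3), mode="reflect"), kernel_2d.view(1, 1, 7, 7))
sep = separable_gaussian_filter(img, 7, 1.0)
print(torch.allclose(full, sep, atol=1e-6))  # True
```

The two results agree up to float32 rounding; the separable form only pays off for larger kernels, where k^2 grows much faster than 2k.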