Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
PyTorch implementation of Laplacian pyramid loss
import torch
def gauss_kernel(size=5, device=torch.device('cpu'), channels=3):
    """Build the 5x5 binomial (Gaussian-like) smoothing kernel.

    Returns a tensor of shape (channels, 1, 5, 5), one identical slice per
    channel, suitable for a depthwise conv2d (groups=channels). Each 5x5
    slice is normalized to sum to 1.
    """
    # Separable binomial weights: outer([1,4,6,4,1]) / 16^2.
    row = torch.tensor([1., 4., 6., 4., 1.])
    kernel_2d = (row[:, None] * row[None, :]) / 256.
    # One copy of the same slice per channel, laid out for grouped conv.
    kernel = kernel_2d.expand(channels, 1, 5, 5).contiguous()
    return kernel.to(device)
def downsample(x):
    """Decimate by 2: keep every other row and column (no pre-filtering)."""
    return x[..., ::2, ::2]
def upsample(x):
    """Upsample x to twice the spatial size by zero-stuffing + smoothing.

    A zero row/column is inserted after every pixel, so (B, C, H, W) becomes
    (B, C, 2H, 2W) with the original samples at the even positions, then the
    result is convolved with 4x the Gaussian kernel to fill in the holes.
    The factor 4 compensates for the 3/4 of samples that are zero.
    """
    b, c, h, w = x.shape
    # dtype=x.dtype so half/double inputs work; the previous implementation
    # concatenated with default-dtype (float32) zeros, which fails for
    # non-float32 tensors. Behavior for float32 input is unchanged.
    x_up = torch.zeros(b, c, h * 2, w * 2, device=x.device, dtype=x.dtype)
    x_up[:, :, ::2, ::2] = x
    return conv_gauss(x_up, 4 * gauss_kernel(channels=c, device=x.device))
def conv_gauss(img, kernel):
    """Depthwise-convolve img with kernel, preserving the spatial size.

    Reflect-pads by 2 on each side (matching the 5x5 kernel) and applies a
    grouped conv2d so each channel is filtered independently.
    """
    padded = torch.nn.functional.pad(img, (2, 2, 2, 2), mode='reflect')
    return torch.nn.functional.conv2d(padded, kernel, groups=padded.shape[1])
def laplacian_pyramid(img, kernel, max_levels=3):
    """Decompose img into a Laplacian pyramid of max_levels detail images.

    Each level stores the detail lost by a blur -> downsample -> upsample
    round trip at that scale. The final low-pass residual is NOT appended;
    the returned list holds exactly max_levels band-pass tensors.
    """
    current = img
    pyr = []
    for level in range(max_levels):
        filtered = conv_gauss(current, kernel)
        down = downsample(filtered)
        up = upsample(down)
        # Crop: when current has an odd height/width, downsample keeps
        # ceil(H/2) samples and upsample doubles that, making `up` one
        # pixel larger than `current`. No-op for even sizes.
        up = up[:, :, :current.shape[2], :current.shape[3]]
        diff = current - up
        pyr.append(diff)
        current = down
    return pyr
class LapLoss(torch.nn.Module):
    """Laplacian pyramid loss: sum of L1 distances between the pyramid
    levels of the input and the target.

    Args:
        max_levels: number of pyramid (band-pass) levels to compare.
        channels: number of image channels (kernel is replicated per channel).
        device: device the Gaussian kernel is created on.
    """
    def __init__(self, max_levels=3, channels=3, device=torch.device('cpu')):
        super(LapLoss, self).__init__()
        self.max_levels = max_levels
        # register_buffer (instead of a plain attribute) so the kernel
        # follows .to()/.cuda() and is saved in state_dict; attribute
        # access via self.gauss_kernel is unchanged.
        self.register_buffer(
            'gauss_kernel', gauss_kernel(channels=channels, device=device))

    def forward(self, input, target):
        """Return the summed per-level L1 loss between input and target."""
        pyr_input = laplacian_pyramid(
            img=input, kernel=self.gauss_kernel, max_levels=self.max_levels)
        pyr_target = laplacian_pyramid(
            img=target, kernel=self.gauss_kernel, max_levels=self.max_levels)
        return sum(torch.nn.functional.l1_loss(a, b)
                   for a, b in zip(pyr_input, pyr_target))
@mlizhardy

This comment has been minimized.

Copy link

@mlizhardy mlizhardy commented Nov 6, 2020

thank you!

@mlizhardy

This comment has been minimized.

Copy link

@mlizhardy mlizhardy commented Nov 20, 2020

I noticed this only worked for square images. I think you need to swap the shape[2] and shape[3] indices in the two reshaping lines of upsample (lines 21 and 22 of the gist) so it's like:

def upsample(x):
    cc = torch.cat([x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3], device=x.device)], dim=3)
    cc = cc.view(x.shape[0], x.shape[1], x.shape[2]*2, x.shape[3])
    cc = cc.permute(0,1,3,2)
    cc = torch.cat([cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2]*2, device=x.device)], dim=3)
    cc = cc.view(x.shape[0], x.shape[1], x.shape[3]*2, x.shape[2]*2)
    x_up = cc.permute(0,1,3,2)
    cv2.imwrite('test_Net2.png', x_up[0].permute(1, 2, 0).detach().cpu().numpy())
    input('break')
    return conv_gauss(x_up, 4*gauss_kernel(channels=x.shape[1], device=x.device))
@alper111

This comment has been minimized.

Copy link
Owner Author

@alper111 alper111 commented Nov 20, 2020

Ah, I see, you are right. I will update it as soon as possible. Thank you!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment