Skip to content

Instantly share code, notes, and snippets.

@curegit
Last active November 1, 2023 01:38
Show Gist options
  • Save curegit/7e417b7a02d71f9b5f4a7c3835a2e7aa to your computer and use it in GitHub Desktop.
Save curegit/7e417b7a02d71f9b5f4a7c3835a2e7aa to your computer and use it in GitHub Desktop.
PyTorch における Lanczos 補間による 2D マップ二倍拡大層の実装
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
def lanczos(x: float, n: int) -> float:
    """Evaluate the Lanczos kernel of order *n* at position *x*.

    NumPy's ``np.sinc`` is the normalized sinc, ``sin(pi t) / (pi t)``.
    Inside the window the kernel is therefore ``sinc(x) * sinc(x / n)``.
    Outside ``[-n, n]`` it is exactly ``0.0``.
    """
    if abs(x) > n:
        # Outside the window's support: the kernel vanishes.
        return 0.0
    ideal = np.sinc(x)       # ideal (infinite) interpolation kernel
    window = np.sinc(x / n)  # stretched sinc acting as the window
    return ideal * window
class Lanczos2xUpsampler(nn.Module):
    """Upsample 2D maps by a factor of two using Lanczos interpolation.

    Along each axis, even output samples use the Lanczos taps
    ``lanczos(m + 0.25, n)`` and odd output samples use
    ``lanczos(m + 0.75, n)``, for ``m`` in ``range(-n, n)``. This matches
    output pixel ``2i`` / ``2i + 1`` lying at input coordinate
    ``i - 0.25`` / ``i + 0.25`` (half-pixel centres). Each set of taps is
    normalized to sum to one.

    The four phase combinations (even/odd row x even/odd column) are packed
    into a single ``(2n + 1) x (2n + 1)`` convolution with four output
    channels. ``pixel_shuffle`` then interleaves those channels into the
    2x-larger grid.

    Args:
        n: Lanczos order; must be a positive integer.

    Raises:
        ValueError: If ``n < 1``.
    """

    def __init__(self, n=3):
        super().__init__()
        if n < 1:
            # n == 0 would give empty tap lists and a 0/0 normalization.
            raise ValueError(f"n must be a positive integer, got {n!r}")
        # 1D taps for the even ("start") and odd ("end") output phases.
        start = np.array([lanczos(i + 0.25, n) for i in range(-n, n)])
        end = np.array([lanczos(i + 0.75, n) for i in range(-n, n)])
        # Normalize so that a constant input stays constant.
        s = start / np.sum(start)
        e = end / np.sum(end)
        # Separable 2D kernels: rows carry vertical taps, columns carry
        # horizontal taps. Each 2n-tap kernel sits inside a (2n + 1)-wide
        # window centred on the source pixel:
        #   - the even phase covers offsets -n..n-1, so it is padded after;
        #   - the odd phase covers offsets -n+1..n, so it is padded before.
        k1 = np.pad(np.outer(s, s), ((0, 1), (0, 1)))  # even row, even col
        k2 = np.pad(np.outer(s, e), ((0, 1), (1, 0)))  # even row, odd col
        k3 = np.pad(np.outer(e, s), ((1, 0), (0, 1)))  # odd row, even col
        k4 = np.pad(np.outer(e, e), ((1, 0), (1, 0)))  # odd row, odd col
        # Channel order k1..k4 matches pixel_shuffle(r=2). That function sends
        # channel (2 * row_phase + col_phase) to the matching sub-pixel.
        # Resulting weight shape: (4, 1, 2n + 1, 2n + 1).
        w = torch.tensor(np.array([[k1], [k2], [k3], [k4]], dtype=np.float32))
        self.register_buffer('w', w)
        self.n = n

    def forward(self, x):
        """Upsample ``x`` of shape ``(B, C, H, W)`` to ``(B, C, 2H, 2W)``.

        Each channel is filtered independently. Borders are extended by
        reflection. When a spatial side is too short for reflect padding
        (``H <= n`` or ``W <= n``), edge replication is used instead. The
        kernel is cast to ``x``'s dtype, so e.g. float64 inputs work
        without converting the module.
        """
        batch, channels, height, width = x.shape
        # reshape (not view) so non-contiguous inputs, e.g. channels_last, work.
        planes = x.reshape(batch * channels, 1, height, width)
        # Reflect padding requires the pad to be smaller than the dimension.
        mode = "reflect" if height > self.n and width > self.n else "replicate"
        padded = F.pad(planes, (self.n, self.n, self.n, self.n), mode=mode)
        weight = self.w.to(dtype=x.dtype)
        # (B*C, 4, H, W): one channel per output sub-pixel phase.
        phases = F.conv2d(padded, weight)
        upsampled = F.pixel_shuffle(phases, 2)
        return upsampled.reshape(batch, channels, height * 2, width * 2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment