Skip to content

Instantly share code, notes, and snippets.

@IFeelBloated
Last active June 11, 2022 21:38
Show Gist options
  • Save IFeelBloated/17db77a4936a4cbdda91e52b40e598e9 to your computer and use it in GitHub Desktop.
Save IFeelBloated/17db77a4936a4cbdda91e52b40e598e9 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import numpy
from torch_utils.ops import upfirdn2d
# Symmetric, odd-length 1-D filter tap lists consumed by CreateLowpassKernel.
# That function forms the outer product of the taps with themselves (after an
# extra [1, 1] convolution for the interpolative resamplers) and rescales the
# result to unit sum, so these lists do not need to sum to exactly 1.
Box = [1.0]
Linear = [0.25, 0.50, 0.25]
Quadratic = [0.128, 0.235, 0.276, 0.235, 0.128]
Cubic = [0.058, 0.128, 0.199, 0.231, 0.199, 0.128, 0.058]
Gaussian = [
    0.008, 0.036, 0.110, 0.213,
    0.267,
    0.213, 0.110, 0.036, 0.008,
]
Mitchell1 = [
    -0.008, -0.011, 0.019, 0.115, 0.237,
    0.296,
    0.237, 0.115, 0.019, -0.011, -0.008,
]
Sinc = [
    -0.003, -0.013, 0.000, 0.094, 0.253,
    0.337,
    0.253, 0.094, 0.000, -0.013, -0.003,
]
Lanczos4 = [
    -0.008, 0.000, 0.095, 0.249,
    0.327,
    0.249, 0.095, 0.000, -0.008,
]
Lanczos5 = [
    -0.005, -0.022, 0.000, 0.108, 0.256,
    0.327,
    0.256, 0.108, 0.000, -0.022, -0.005,
]
def CreateLowpassKernel(Weights, Inplace):
    """Build a unit-sum, separable 2-D low-pass kernel from 1-D taps.

    Args:
        Weights: sequence of 1-D filter taps.
        Inplace: if True the taps are used as given; if False they are first
            convolved with ``[1, 1]`` (making them one tap longer), which is
            the variant the interpolative (stride-2) resamplers use.

    Returns:
        A square ``torch.Tensor`` (legacy ``torch.Tensor`` constructor, i.e.
        the default tensor type) equal to the outer product of the taps with
        themselves, divided by its own sum.
    """
    if Inplace:
        Taps = numpy.asarray(Weights)
    else:
        Taps = numpy.convolve(Weights, [1, 1])
    Outer = torch.Tensor(numpy.outer(Taps, Taps))
    return Outer / Outer.sum()
class InterpolativeUpsamplerReference(nn.Module):
    """Pure-PyTorch 2x interpolative upsampler.

    Each channel is reflect-padded by ``FilterRadius`` and then run through a
    stride-2 transposed convolution with the kernel from
    ``CreateLowpassKernel(Filter, Inplace=False)`` scaled by 4. For an
    odd-length ``Filter`` the output height/width are exactly twice the
    input's. Reflect padding requires ``FilterRadius`` to be smaller than
    the input's spatial size.
    """

    def __init__(self, Filter):
        super().__init__()
        # The kernel is unit-sum; a stride-2 transposed conv splits it over
        # four output phases, so scaling by 2 * 2 keeps constants constant.
        Taps = CreateLowpassKernel(Filter, Inplace=False) * 4
        self.register_buffer('Kernel', Taps[None, None, :, :])
        self.FilterRadius = len(Filter) // 2

    def forward(self, x):
        Radius = self.FilterRadius
        Padded = nn.functional.pad(x, (Radius,) * 4, mode='reflect')
        # Fold channels into the batch so one single-channel kernel filters
        # every plane independently.
        Planes = Padded.view(-1, 1, Padded.shape[2], Padded.shape[3])
        # padding=3R trims both the reflect border and the transposed-conv
        # overhang, leaving exactly 2H x 2W samples.
        Upsampled = nn.functional.conv_transpose2d(Planes, self.Kernel, stride=2, padding=3 * Radius)
        return Upsampled.view(*x.shape[:2], *Upsampled.shape[2:])
class InterpolativeDownsamplerReference(nn.Module):
    """Pure-PyTorch 2x interpolative downsampler.

    Each channel is reflect-padded by ``FilterRadius`` and then convolved
    with stride 2 using the unit-sum kernel from
    ``CreateLowpassKernel(Filter, Inplace=False)``. For an odd-length
    ``Filter`` and even input height/width the output is exactly half size.
    """

    def __init__(self, Filter):
        super().__init__()
        Taps = CreateLowpassKernel(Filter, Inplace=False)
        self.register_buffer('Kernel', Taps[None, None, :, :])
        self.FilterRadius = len(Filter) // 2

    def forward(self, x):
        Radius = self.FilterRadius
        Padded = nn.functional.pad(x, (Radius,) * 4, mode='reflect')
        # One single-channel conv over all (batch * channel) planes.
        Planes = Padded.view(-1, 1, Padded.shape[2], Padded.shape[3])
        Downsampled = nn.functional.conv2d(Planes, self.Kernel, stride=2)
        return Downsampled.view(*x.shape[:2], *Downsampled.shape[2:])
class InplaceUpsamplerReference(nn.Module):
    """Pure-PyTorch 2x "in-place" upsampler.

    The input's channels are rearranged into 2x2 spatial blocks with
    ``pixel_shuffle`` (so it must have 4 * C channels), and each resulting
    channel is then smoothed at the new resolution with the stride-1 kernel
    from ``CreateLowpassKernel(Filter, Inplace=True)`` after reflect padding.
    The output has the same shape as the pixel-shuffled tensor.
    """

    def __init__(self, Filter):
        super().__init__()
        Taps = CreateLowpassKernel(Filter, Inplace=True)
        self.register_buffer('Kernel', Taps[None, None, :, :])
        self.FilterRadius = len(Filter) // 2

    def forward(self, x):
        Shuffled = nn.functional.pixel_shuffle(x, 2)
        Radius = self.FilterRadius
        Padded = nn.functional.pad(Shuffled, (Radius,) * 4, mode='reflect')
        Planes = Padded.view(-1, 1, Padded.shape[2], Padded.shape[3])
        Filtered = nn.functional.conv2d(Planes, self.Kernel, stride=1)
        return Filtered.view(*Shuffled.shape)
class InplaceDownsamplerReference(nn.Module):
    """Pure-PyTorch 2x "in-place" downsampler.

    Each channel is smoothed at full resolution with the stride-1 kernel
    from ``CreateLowpassKernel(Filter, Inplace=True)`` after reflect padding.
    Each 2x2 spatial block is then folded into channels with
    ``pixel_unshuffle``, giving 4 * C channels at half resolution.
    """

    def __init__(self, Filter):
        super().__init__()
        Taps = CreateLowpassKernel(Filter, Inplace=True)
        self.register_buffer('Kernel', Taps[None, None, :, :])
        self.FilterRadius = len(Filter) // 2

    def forward(self, x):
        Radius = self.FilterRadius
        Padded = nn.functional.pad(x, (Radius,) * 4, mode='reflect')
        Planes = Padded.view(-1, 1, Padded.shape[2], Padded.shape[3])
        Filtered = nn.functional.conv2d(Planes, self.Kernel, stride=1).view(*x.shape)
        return nn.functional.pixel_unshuffle(Filtered, 2)
class InterpolativeUpsamplerCUDA(nn.Module):
    """2x interpolative upsampler built on ``upfirdn2d.upsample2d``.

    Uses the same filter (``CreateLowpassKernel(Filter, Inplace=False)``)
    and reflect-padding radius as ``InterpolativeUpsamplerReference``. The
    explicit x4 gain of the reference version is absent here, presumably
    supplied by upsample2d's default gain; confirm against torch_utils.
    """

    def __init__(self, Filter):
        super().__init__()
        self.FilterRadius = len(Filter) // 2
        self.register_buffer('Kernel', CreateLowpassKernel(Filter, Inplace=False))

    def forward(self, x):
        Radius = self.FilterRadius
        Padded = nn.functional.pad(x, (Radius,) * 4, mode='reflect')
        # Negative padding presumably crops the (now doubled) reflect border
        # back off inside upsample2d; verify against torch_utils.ops.upfirdn2d.
        return upfirdn2d.upsample2d(Padded, self.Kernel, padding=-2 * Radius)
class InterpolativeDownsamplerCUDA(nn.Module):
    """2x interpolative downsampler built on ``upfirdn2d.downsample2d``.

    Uses the same filter (``CreateLowpassKernel(Filter, Inplace=False)``)
    and reflect-padding radius as ``InterpolativeDownsamplerReference``.
    """

    def __init__(self, Filter):
        super().__init__()
        self.FilterRadius = len(Filter) // 2
        self.register_buffer('Kernel', CreateLowpassKernel(Filter, Inplace=False))

    def forward(self, x):
        Radius = self.FilterRadius
        Padded = nn.functional.pad(x, (Radius,) * 4, mode='reflect')
        # Negative padding presumably offsets downsample2d's own internal
        # padding so only the reflect border is used; verify against torch_utils.
        return upfirdn2d.downsample2d(Padded, self.Kernel, padding=-Radius)
class InplaceUpsamplerCUDA(nn.Module):
    """2x "in-place" upsampler: ``pixel_shuffle`` followed by ``upfirdn2d``.

    The input must have 4 * C channels. After shuffling, each channel is
    reflect-padded and filtered with the kernel from
    ``CreateLowpassKernel(Filter, Inplace=True)`` (no resampling inside
    upfirdn2d, only filtering).
    """

    def __init__(self, Filter):
        super().__init__()
        self.FilterRadius = len(Filter) // 2
        self.register_buffer('Kernel', CreateLowpassKernel(Filter, Inplace=True))

    def forward(self, x):
        Shuffled = nn.functional.pixel_shuffle(x, 2)
        Radius = self.FilterRadius
        Padded = nn.functional.pad(Shuffled, (Radius,) * 4, mode='reflect')
        return upfirdn2d.upfirdn2d(Padded, self.Kernel)
class InplaceDownsamplerCUDA(nn.Module):
    """2x "in-place" downsampler: ``upfirdn2d`` filtering, then ``pixel_unshuffle``.

    Each channel is reflect-padded and filtered with the kernel from
    ``CreateLowpassKernel(Filter, Inplace=True)``. Each 2x2 block is then
    folded into channels, giving 4 * C channels at half resolution.
    """

    def __init__(self, Filter):
        super().__init__()
        self.FilterRadius = len(Filter) // 2
        self.register_buffer('Kernel', CreateLowpassKernel(Filter, Inplace=True))

    def forward(self, x):
        Radius = self.FilterRadius
        Padded = nn.functional.pad(x, (Radius,) * 4, mode='reflect')
        Filtered = upfirdn2d.upfirdn2d(Padded, self.Kernel)
        return nn.functional.pixel_unshuffle(Filtered, 2)
# Public resampler names, bound to the upfirdn2d-backed implementations.
# The *Reference classes take the same constructor argument and use only
# torch.nn.functional, so they can be swapped in here if needed.
InterpolativeUpsampler, InterpolativeDownsampler = InterpolativeUpsamplerCUDA, InterpolativeDownsamplerCUDA
InplaceUpsampler, InplaceDownsampler = InplaceUpsamplerCUDA, InplaceDownsamplerCUDA
## quick test ##
# Visual smoke test. It only runs when this file is executed directly, so
# importing the module for its resamplers has no side effects and does not
# need scikit-image or matplotlib.


def _ReplicateChannels(x, Copies):
    """Repeat every channel of an NCHW tensor ``Copies`` times, back to back.

    Channel ``c`` becomes channels ``c*Copies`` .. ``c*Copies + Copies - 1``.
    This is the layout ``pixel_shuffle(., 2)`` turns into a 2x
    nearest-neighbour enlargement when ``Copies == 4``.
    """
    return x.repeat_interleave(Copies, dim=1)


def _AverageChannelGroups(x, GroupSize):
    """Average each run of ``GroupSize`` consecutive channels of an NCHW tensor."""
    Batch, Channels, Height, Width = x.shape
    return x.reshape(Batch, Channels // GroupSize, GroupSize, Height, Width).mean(2)


def RunQuickTest(ImagePath=r'C:\Users\Administrator\Desktop\mandril.png'):
    """Push an image through all four resamplers and display the result.

    Pipeline: interpolative up (Lanczos5) -> interpolative down (Gaussian)
    -> channels replicated 4x and in-place up (Sinc) -> in-place down
    (Cubic) -> the four phase channels averaged back to one per channel.

    Args:
        ImagePath: path of the image file to load.
    """
    # Demo-only dependencies, imported lazily (see module comment above).
    from skimage import io, img_as_float32
    import matplotlib.pyplot as plt

    Image = img_as_float32(io.imread(ImagePath))
    if Image.ndim == 2:
        Image = Image[:, :, None]  # grayscale: add a channel axis
    x = torch.as_tensor(Image).permute(2, 0, 1).unsqueeze(0)  # HWC -> NCHW

    x = InterpolativeUpsampler(Lanczos5)(x)
    x = InterpolativeDownsampler(Gaussian)(x)
    x = _ReplicateChannels(x, 4)
    x = InplaceUpsampler(Sinc)(x)
    x = InplaceDownsampler(Cubic)(x)
    x = _AverageChannelGroups(x, 4)

    # Sinc/Lanczos taps are partly negative, so values can overshoot [0, 1];
    # clamp so imshow does not clip them with a warning.
    Display = x[0].permute(1, 2, 0).clamp(0, 1).detach().cpu().numpy()
    if Display.shape[2] == 1:
        plt.imshow(Display[:, :, 0], cmap='gray')
    else:
        plt.imshow(Display)
    plt.show()  # without this nothing is shown when run as a script


if __name__ == '__main__':
    import sys
    # Optional first command-line argument overrides the default image path.
    RunQuickTest(*sys.argv[1:2])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment