Last active
June 11, 2022 21:38
-
-
Save IFeelBloated/17db77a4936a4cbdda91e52b40e598e9 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import numpy | |
from torch_utils.ops import upfirdn2d | |
# 1-D low-pass filter taps. Every entry has an odd number of taps and is
# symmetric about its centre tap. Values are rounded to three decimals, so
# some lists do not sum exactly to 1; CreateLowpassKernel renormalises the
# 2-D kernel it builds from them.
Box = [1.0]
Linear = [0.25, 0.50, 0.25]
Quadratic = [0.128, 0.235, 0.276, 0.235, 0.128]
Cubic = [0.058, 0.128, 0.199, 0.231, 0.199, 0.128, 0.058]
Gaussian = [0.008, 0.036, 0.110, 0.213, 0.267, 0.213, 0.110, 0.036, 0.008]
Mitchell1 = [
    -0.008, -0.011, 0.019, 0.115, 0.237,
    0.296,
    0.237, 0.115, 0.019, -0.011, -0.008,
]
Sinc = [
    -0.003, -0.013, 0.000, 0.094, 0.253,
    0.337,
    0.253, 0.094, 0.000, -0.013, -0.003,
]
Lanczos4 = [
    -0.008, 0.000, 0.095, 0.249,
    0.327,
    0.249, 0.095, 0.000, -0.008,
]
Lanczos5 = [
    -0.005, -0.022, 0.000, 0.108, 0.256,
    0.327,
    0.256, 0.108, 0.000, -0.022, -0.005,
]
def CreateLowpassKernel(Weights, Inplace):
    """Build a normalised, separable 2-D low-pass kernel from 1-D filter taps.

    Args:
        Weights: Non-empty 1-D sequence of filter taps.
        Inplace: If true, the kernel is ``outer(Weights, Weights)`` and keeps
            the length of ``Weights``. If false, ``Weights`` is first convolved
            with ``[1, 1]``, which adds one tap and moves the kernel centre
            halfway between two samples (the half-pixel alignment used by the
            Interpolative resamplers).

    Returns:
        A float32 ``torch`` tensor of shape ``(n, n)`` whose entries sum to 1,
        where ``n`` is ``len(Weights)`` or ``len(Weights) + 1``.

    Raises:
        ValueError: If ``Weights`` is empty or not one-dimensional.
    """
    taps = numpy.asarray(Weights, dtype=numpy.float64)
    if taps.ndim != 1 or taps.size == 0:
        raise ValueError('Weights must be a non-empty 1-D sequence of filter taps')
    if not Inplace:
        taps = numpy.convolve(taps, [1.0, 1.0])
    # Build the outer product in float64, then cast once to float32. This matches
    # the precision of the former torch.Tensor(...) conversion.
    kernel = torch.as_tensor(numpy.outer(taps, taps), dtype=torch.float32)
    return kernel / kernel.sum()
class InterpolativeUpsamplerReference(nn.Module):
    """Pure-PyTorch 2x upsampler with half-pixel (interpolative) alignment.

    The kernel is ``CreateLowpassKernel(Filter, Inplace=False)``: an even-sized
    ``(len(Filter) + 1)``-square kernel. It is scaled by 4 because a stride-2
    transposed convolution feeds each output pixel from only one of the four
    polyphase components. Because the 1-D taps were convolved with ``[1, 1]``,
    each component holds exactly a quarter of the kernel's total weight, so
    the scaled kernel has unit DC gain.

    Args:
        Filter: Odd-length sequence of 1-D low-pass filter taps.

    Raises:
        ValueError: If ``Filter`` is empty or has an even number of taps. The
            padding arithmetic in ``forward`` only yields a ``2H x 2W`` output
            for odd-length filters.
    """

    def __init__(self, Filter):
        super().__init__()
        if len(Filter) % 2 != 1:
            raise ValueError(f'Filter must have an odd number of taps, got {len(Filter)}')
        Kernel = 4 * CreateLowpassKernel(Filter, Inplace=False)
        # Shape (out_channels=1, in_channels=1, kH, kW) for a single-channel conv.
        self.register_buffer('Kernel', Kernel.view(1, 1, *Kernel.shape))
        self.FilterRadius = len(Filter) // 2

    def forward(self, x):
        """Upsample ``x`` of shape (N, C, H, W) to (N, C, 2H, 2W).

        Channels are processed independently: they are folded into the batch
        dimension and convolved with the single-channel kernel. PyTorch's
        reflect padding requires ``FilterRadius`` to be smaller than H and W.
        """
        r = self.FilterRadius
        # Reflect-pad so border outputs see mirrored content instead of zeros.
        # padding=3*r in the transposed conv then crops the result back to
        # exactly 2H x 2W.
        y = nn.functional.pad(x, (r, r, r, r), mode='reflect')
        y = nn.functional.conv_transpose2d(
            y.view(y.shape[0] * y.shape[1], 1, y.shape[2], y.shape[3]),
            self.Kernel,
            stride=2,
            padding=3 * r,
        )
        return y.view(x.shape[0], x.shape[1], y.shape[2], y.shape[3])
class InterpolativeDownsamplerReference(nn.Module):
    """Pure-PyTorch 2x downsampler with half-pixel (interpolative) alignment.

    The kernel is ``CreateLowpassKernel(Filter, Inplace=False)``: an even-sized
    ``(len(Filter) + 1)``-square kernel normalised to unit sum. Each output
    sample is therefore centred between a 2x2 block of input samples.

    Args:
        Filter: Odd-length sequence of 1-D low-pass filter taps.

    Raises:
        ValueError: If ``Filter`` is empty or has an even number of taps. An
            even filter would give an odd-sized kernel and silently shift the
            output sample phase by half a pixel.
    """

    def __init__(self, Filter):
        super().__init__()
        if len(Filter) % 2 != 1:
            raise ValueError(f'Filter must have an odd number of taps, got {len(Filter)}')
        Kernel = CreateLowpassKernel(Filter, Inplace=False)
        # Shape (out_channels=1, in_channels=1, kH, kW) for a single-channel conv.
        self.register_buffer('Kernel', Kernel.view(1, 1, *Kernel.shape))
        self.FilterRadius = len(Filter) // 2

    def forward(self, x):
        """Downsample ``x`` of shape (N, C, H, W) to (N, C, H // 2, W // 2).

        Channels are processed independently: they are folded into the batch
        dimension and convolved with the single-channel kernel. PyTorch's
        reflect padding requires ``FilterRadius`` to be smaller than H and W.
        """
        r = self.FilterRadius
        # Reflect-pad so border outputs see mirrored content instead of zeros.
        y = nn.functional.pad(x, (r, r, r, r), mode='reflect')
        y = nn.functional.conv2d(
            y.view(y.shape[0] * y.shape[1], 1, y.shape[2], y.shape[3]),
            self.Kernel,
            stride=2,
        )
        return y.view(x.shape[0], x.shape[1], y.shape[2], y.shape[3])
class InplaceUpsamplerReference(nn.Module):
    """Pure-PyTorch 2x upsampler built from ``pixel_shuffle`` plus smoothing.

    ``pixel_shuffle`` rearranges every group of 4 channels into a 2x2 spatial
    block. The result is then filtered with the odd-sized kernel
    ``CreateLowpassKernel(Filter, Inplace=True)``. That kernel is centred on a
    pixel, so filtering leaves the shuffled image's size unchanged.

    Args:
        Filter: Odd-length sequence of 1-D low-pass filter taps.

    Raises:
        ValueError: If ``Filter`` is empty or has an even number of taps. An
            even kernel would make the filtered image one pixel larger than
            the shuffled input, and the final ``view`` would fail.
    """

    def __init__(self, Filter):
        super().__init__()
        if len(Filter) % 2 != 1:
            raise ValueError(f'Filter must have an odd number of taps, got {len(Filter)}')
        Kernel = CreateLowpassKernel(Filter, Inplace=True)
        # Shape (out_channels=1, in_channels=1, kH, kW) for a single-channel conv.
        self.register_buffer('Kernel', Kernel.view(1, 1, *Kernel.shape))
        self.FilterRadius = len(Filter) // 2

    def forward(self, x):
        """Map ``x`` of shape (N, 4C, H, W) to (N, C, 2H, 2W)."""
        x = nn.functional.pixel_shuffle(x, 2)
        r = self.FilterRadius
        # Reflect-pad by the radius so the odd-sized kernel preserves the size.
        y = nn.functional.pad(x, (r, r, r, r), mode='reflect')
        y = nn.functional.conv2d(
            y.view(y.shape[0] * y.shape[1], 1, y.shape[2], y.shape[3]),
            self.Kernel,
            stride=1,
        )
        return y.view(*x.shape)
class InplaceDownsamplerReference(nn.Module):
    """Pure-PyTorch 2x downsampler built from smoothing plus ``pixel_unshuffle``.

    The input is first filtered with the odd-sized kernel
    ``CreateLowpassKernel(Filter, Inplace=True)``, which keeps its size. It is
    then split by ``pixel_unshuffle`` into four half-resolution phase channels
    per input channel.

    Args:
        Filter: Odd-length sequence of 1-D low-pass filter taps.

    Raises:
        ValueError: If ``Filter`` is empty or has an even number of taps. An
            even kernel would make the filtered image one pixel larger than
            the input, and the ``view`` back to ``x.shape`` would fail.
    """

    def __init__(self, Filter):
        super().__init__()
        if len(Filter) % 2 != 1:
            raise ValueError(f'Filter must have an odd number of taps, got {len(Filter)}')
        Kernel = CreateLowpassKernel(Filter, Inplace=True)
        # Shape (out_channels=1, in_channels=1, kH, kW) for a single-channel conv.
        self.register_buffer('Kernel', Kernel.view(1, 1, *Kernel.shape))
        self.FilterRadius = len(Filter) // 2

    def forward(self, x):
        """Map ``x`` of shape (N, C, H, W) to (N, 4C, H / 2, W / 2).

        ``pixel_unshuffle`` requires H and W to be even.
        """
        r = self.FilterRadius
        # Reflect-pad by the radius so the odd-sized kernel preserves the size.
        y = nn.functional.pad(x, (r, r, r, r), mode='reflect')
        y = nn.functional.conv2d(
            y.view(y.shape[0] * y.shape[1], 1, y.shape[2], y.shape[3]),
            self.Kernel,
            stride=1,
        ).view(*x.shape)
        return nn.functional.pixel_unshuffle(y, 2)
class InterpolativeUpsamplerCUDA(nn.Module):
    """``upfirdn2d``-based 2x upsampler with half-pixel (interpolative) alignment.

    This uses the same even-sized kernel as ``InterpolativeUpsamplerReference``.
    The kernel is left at unit sum rather than scaled by 4 because
    ``upfirdn2d.upsample2d`` applies the ``up**2`` gain itself (per the
    StyleGAN2-ADA/StyleGAN3 ``torch_utils`` implementation).

    Args:
        Filter: Odd-length sequence of 1-D low-pass filter taps.

    Raises:
        ValueError: If ``Filter`` is empty or has an even number of taps. The
            padding arithmetic in ``forward`` assumes an odd-length filter.
    """

    def __init__(self, Filter):
        super().__init__()
        if len(Filter) % 2 != 1:
            raise ValueError(f'Filter must have an odd number of taps, got {len(Filter)}')
        self.register_buffer('Kernel', CreateLowpassKernel(Filter, Inplace=False))
        self.FilterRadius = len(Filter) // 2

    def forward(self, x):
        """Upsample ``x`` of shape (N, C, H, W) to (N, C, 2H, 2W)."""
        r = self.FilterRadius
        y = nn.functional.pad(x, (r, r, r, r), mode='reflect')
        # upsample2d's padding is measured in output pixels. The r reflected
        # input pixels on each side become 2r output pixels, so -2r crops them.
        return upfirdn2d.upsample2d(y, self.Kernel, padding=-2 * r)
class InterpolativeDownsamplerCUDA(nn.Module):
    """``upfirdn2d``-based 2x downsampler with half-pixel (interpolative) alignment.

    This is the counterpart of ``InterpolativeDownsamplerReference`` and uses
    the same even-sized, unit-sum kernel.

    Args:
        Filter: Odd-length sequence of 1-D low-pass filter taps.

    Raises:
        ValueError: If ``Filter`` is empty or has an even number of taps. The
            padding arithmetic in ``forward`` assumes an odd-length filter.
    """

    def __init__(self, Filter):
        super().__init__()
        if len(Filter) % 2 != 1:
            raise ValueError(f'Filter must have an odd number of taps, got {len(Filter)}')
        self.register_buffer('Kernel', CreateLowpassKernel(Filter, Inplace=False))
        self.FilterRadius = len(Filter) // 2

    def forward(self, x):
        """Downsample ``x`` of shape (N, C, H, W) to (N, C, H // 2, W // 2)."""
        r = self.FilterRadius
        y = nn.functional.pad(x, (r, r, r, r), mode='reflect')
        # The negative padding offsets the reflect padding above, so the output
        # size matches the reference implementation.
        return upfirdn2d.downsample2d(y, self.Kernel, padding=-r)
class InplaceUpsamplerCUDA(nn.Module):
    """``upfirdn2d``-based 2x upsampler built from ``pixel_shuffle`` plus smoothing.

    This is the counterpart of ``InplaceUpsamplerReference``. The shuffled
    image is filtered with the odd-sized kernel
    ``CreateLowpassKernel(Filter, Inplace=True)``.

    Args:
        Filter: Odd-length sequence of 1-D low-pass filter taps.

    Raises:
        ValueError: If ``Filter`` is empty or has an even number of taps. An
            even kernel would make the output one pixel larger than 2H x 2W.
    """

    def __init__(self, Filter):
        super().__init__()
        if len(Filter) % 2 != 1:
            raise ValueError(f'Filter must have an odd number of taps, got {len(Filter)}')
        self.register_buffer('Kernel', CreateLowpassKernel(Filter, Inplace=True))
        self.FilterRadius = len(Filter) // 2

    def forward(self, x):
        """Map ``x`` of shape (N, 4C, H, W) to (N, C, 2H, 2W)."""
        x = nn.functional.pixel_shuffle(x, 2)
        r = self.FilterRadius
        # Reflect-pad by the radius. upfirdn2d with no extra padding then
        # returns the same size as the shuffled image.
        y = nn.functional.pad(x, (r, r, r, r), mode='reflect')
        return upfirdn2d.upfirdn2d(y, self.Kernel)
class InplaceDownsamplerCUDA(nn.Module):
    """``upfirdn2d``-based 2x downsampler built from smoothing plus ``pixel_unshuffle``.

    This is the counterpart of ``InplaceDownsamplerReference``. The input is
    filtered with the odd-sized kernel
    ``CreateLowpassKernel(Filter, Inplace=True)`` and then split into four
    half-resolution phase channels per input channel.

    Args:
        Filter: Odd-length sequence of 1-D low-pass filter taps.

    Raises:
        ValueError: If ``Filter`` is empty or has an even number of taps. An
            even kernel would make the filtered image one pixel larger than
            the input.
    """

    def __init__(self, Filter):
        super().__init__()
        if len(Filter) % 2 != 1:
            raise ValueError(f'Filter must have an odd number of taps, got {len(Filter)}')
        self.register_buffer('Kernel', CreateLowpassKernel(Filter, Inplace=True))
        self.FilterRadius = len(Filter) // 2

    def forward(self, x):
        """Map ``x`` of shape (N, C, H, W) to (N, 4C, H / 2, W / 2).

        ``pixel_unshuffle`` requires H and W to be even.
        """
        r = self.FilterRadius
        # Reflect-pad by the radius so filtering preserves the spatial size.
        y = nn.functional.pad(x, (r, r, r, r), mode='reflect')
        y = upfirdn2d.upfirdn2d(y, self.Kernel)
        return nn.functional.pixel_unshuffle(y, 2)
# Public names resolve to the upfirdn2d-based implementations. The *Reference
# classes defined earlier are their pure-PyTorch counterparts.
(InterpolativeUpsampler, InterpolativeDownsampler, InplaceUpsampler, InplaceDownsampler) = (
    InterpolativeUpsamplerCUDA,
    InterpolativeDownsamplerCUDA,
    InplaceUpsamplerCUDA,
    InplaceDownsamplerCUDA,
)
## quick test ##
if __name__ == '__main__':
    # Smoke test: pass an image through every resampler and display the result.
    # It runs only when this file is executed directly, so importing the module
    # no longer reads an image from disk or requires skimage/matplotlib.
    import sys

    import matplotlib.pyplot as plt
    from skimage import io, img_as_float32

    # The image path may be given on the command line. The default is the
    # original hard-coded path.
    ImagePath = sys.argv[1] if len(sys.argv) > 1 else r'C:\Users\Administrator\Desktop\mandril.png'

    with torch.no_grad():
        x = img_as_float32(io.imread(ImagePath))
        x = torch.Tensor(x).transpose(0, 2).transpose(1, 2)  # HWC -> CHW
        x = x.view(1, *x.shape)  # NCHW
        x = InterpolativeUpsampler(Lanczos5)(x)
        x = InterpolativeDownsampler(Gaussian)(x)
        # Replicate each channel 4 times. pixel_shuffle in InplaceUpsampler then
        # places the copies in a 2x2 block (a nearest-neighbour 2x image), which
        # its filter smooths.
        y = x.view(x.shape[0], x.shape[1], 1, x.shape[2], x.shape[3])
        x = torch.cat([y, y, y, y], 2).view(x.shape[0], 4 * x.shape[1], x.shape[2], x.shape[3])
        x = InplaceUpsampler(Sinc)(x)
        x = InplaceDownsampler(Cubic)(x)
        # Average the 4 phase channels back into one channel per input channel.
        x = x.view(x.shape[0], x.shape[1] // 4, 4, x.shape[2], x.shape[3]).mean(2)

    # Filters with negative taps (e.g. Sinc, Lanczos*) can overshoot [0, 1].
    # Clamp so imshow does not warn about clipping float data.
    plt.imshow(x[0].transpose(0, 2).transpose(0, 1).clamp(0, 1).numpy())  # CHW -> HWC
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment