Skip to content

Instantly share code, notes, and snippets.

@Ed-Optalysys
Last active June 29, 2021 09:05
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Ed-Optalysys/3bb98736008b8ae668b0bc8d4636ff76 to your computer and use it in GitHub Desktop.
Save Ed-Optalysys/3bb98736008b8ae668b0bc8d4636ff76 to your computer and use it in GitHub Desktop.
import torch.nn as nn
import torch.nn.functional as F
from upsample import UpsampleBlock
from pytorchlayer.opt_conv_layer import OptConvLayer
class SuperResolutionModel(nn.Module):
    """Stem convolution + ELU, a chain of ``UpsampleBlock`` stages, 1x1 head.

    The network is assembled from the class-level configuration below:

    * ``c1``     -- conv from ``in_channels`` to ``upsample_block_depth[0]``
      channels, followed by ELU in :meth:`forward`.
    * ``c_stem`` -- ``upsample_block_count`` ``UpsampleBlock`` modules applied
      one after another.
    * ``cf``     -- 1x1 conv from ``upsample_block_depth[-1]`` channels to
      ``out_channels``; no activation is applied after it.

    NOTE(review): ``UpsampleBlock`` and ``OptConvLayer`` are imported from
    other modules; presumably each block raises the spatial resolution and
    keeps ``upsample_block_depth[i]`` channels -- confirm against ``upsample``.

    Args:
        optical: When true, ``c1`` and ``cf`` are built with ``OptConvLayer``
            (``perfect_gradient=True``) instead of ``nn.Conv2d``. The flag is
            also forwarded to every ``UpsampleBlock``.
    """

    # Channels of the tensor fed to ``c1`` / produced by ``cf``.
    in_channels = 3
    out_channels = 3
    # Per-stage settings: entry ``i`` of each list is handed to the i-th
    # UpsampleBlock as its first / second positional argument respectively.
    upsample_block_depth = [4, 4, 4]
    upsample_block_length = [4, 2, 1]
    # Number of UpsampleBlock stages to build; must not exceed the length of
    # the two lists above because it is used to index them.
    upsample_block_count = 3
    # Kernel size and padding shared by ``c1`` and every UpsampleBlock.
    upsample_kernel_size = 3
    default_padding = 1

    def __init__(self, optical=False):
        super().__init__()

        # Keep the creation order c1 -> cf -> c_stem: it fixes the order in
        # which sub-modules (and therefore parameters) are registered.
        self.c1 = self._make_conv(
            self.in_channels,
            self.upsample_block_depth[0],
            self.upsample_kernel_size,
            self.default_padding,
            optical,
        )
        self.cf = self._make_conv(
            self.upsample_block_depth[-1],
            self.out_channels,
            1,
            0,
            optical,
        )
        self.c_stem = nn.ModuleList([
            UpsampleBlock(
                self.upsample_block_depth[i],
                self.upsample_block_length[i],
                self.upsample_kernel_size,
                self.default_padding,
                optical=optical,
                padding_mode='replicate',
            )
            for i in range(self.upsample_block_count)
        ])

    @staticmethod
    def _make_conv(in_channels, out_channels, kernel_size, padding, optical):
        """Build one conv layer, optical or standard, with the given geometry."""
        if optical:
            return OptConvLayer(
                in_channels, out_channels, kernel_size=kernel_size,
                padding=padding, perfect_gradient=True)
        return nn.Conv2d(
            in_channels, out_channels, kernel_size=kernel_size,
            padding=padding, padding_mode='replicate')

    def forward(self, x):
        """Run ``x`` through ``c1`` + ELU, every stem block, then ``cf``.

        Args:
            x: Input accepted by ``c1`` (for the ``nn.Conv2d`` variant, a
                tensor with ``in_channels`` channels).

        Returns:
            The output of ``cf`` applied to the last stem block's result.
        """
        x = F.elu(self.c1(x))
        for block in self.c_stem:
            x = block(x)
        return self.cf(x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment