Last active
June 29, 2021 08:55
-
-
Save Ed-Optalysys/4e8fefdc4d2de61bca1aac0a8f12041c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from pytorchlayer.opt_conv_layer import OptConvLayer | |
# makes a matrix sparse by padding within: (stride - 1) zeros are inserted after every element along H and W,
# e.g. with stride=2 a row [1,2,3] -> [1,0,2,0,3,0]
def pad_within(x, stride=2):
    """Spread the elements of ``x`` apart by inserting zeros between them.

    ``x`` is a 4-d tensor (BxCxHxW). Each channel is passed through its own
    transposed convolution (``groups`` equals the channel count) whose
    ``stride`` x ``stride`` kernel is zero except for a one in its top-left
    corner. Every input value therefore lands at ``(h * stride, w * stride)``
    of the output and all other output positions are zero.

    Returns a tensor of shape B x C x (H * stride) x (W * stride).
    """
    channels = x.size(1)
    # One single-input kernel per channel, allocated with x's dtype and device.
    kernel = x.new_zeros(channels, 1, stride, stride)
    kernel[:, :, 0, 0] = 1
    return F.conv_transpose2d(x, kernel, stride=stride, groups=channels)
# takes a 4d array (BxCxHxW) where channels % 4 == 0
# returns a channel upsampled array with H and W x2 and C /4
def channel_upsample2d(x):
    """Trade channels for spatial resolution: (B, C, H, W) -> (B, C/4, 2H, 2W).

    Each run of four consecutive channels ``4*g .. 4*g + 3`` fills the 2x2
    output cells of output channel ``g``. Within the cell whose top-left
    corner is ``(2*h, 2*w)``:

        channel 4*g + 0 -> (2*h,     2*w)
        channel 4*g + 1 -> (2*h + 1, 2*w)
        channel 4*g + 2 -> (2*h,     2*w + 1)
        channel 4*g + 3 -> (2*h + 1, 2*w + 1)

    i.e. channel ``k`` goes to row offset ``k % 2`` and column offset
    ``k // 2``. Note this is the transpose of the sub-pixel order used by
    ``torch.nn.functional.pixel_shuffle``, so the two are not interchangeable.

    The result is a pure permutation of the input values, returned as a new
    contiguous tensor; gradients flow straight back to ``x``.

    Raises:
        AssertionError: if ``x`` is not 4-d or its channel count is not a
            multiple of 4.
    """
    # Validate before touching the shape so a wrong rank reports the real
    # problem instead of an index error on x.shape.
    assert x.ndim == 4, "channel_upsample2d expects a 4-d (BxCxHxW) input"
    batch, channels, height, width = x.shape
    assert channels % 4 == 0, "channel count must be a multiple of 4"
    groups = channels // 4

    # Split C into (group, column offset, row offset): channel 4*g + 2*j + i
    # carries the value for row offset i and column offset j.
    cells = x.reshape(batch, groups, 2, 2, height, width)
    # Reorder to (B, group, H, row offset, W, column offset) so that merging
    # neighbouring axes yields rows 2*h + i and columns 2*w + j.
    interleaved = cells.permute(0, 1, 4, 3, 5, 2)
    # clone() always materialises a fresh contiguous tensor, so the result
    # never aliases the caller's input.
    fresh = interleaved.clone(memory_format=torch.contiguous_format)
    return fresh.view(batch, groups, height * 2, width * 2)
# convolutional block that takes a 4d input (BxCxHxW), does multi channel-convolutions, internally increases the channel | |
# dimension to 4x the input channel dimension, then multichannel upsamples to output something with H & W 2x increased | |
class UpsampleBlock(nn.Module):
    """Convolution stack that ends in a 2x channel-to-space upsample.

    The block applies ``length`` convolutions, each followed by ELU. The first
    maps ``depth`` input channels to ``depth * 4``; the remaining ones keep
    ``depth * 4`` channels. The block input is then added to each of the four
    channels belonging to the same channel group, and ``channel_upsample2d``
    converts the (B, depth * 4, H, W) result into (B, depth, 2H, 2W).

    Args:
        depth: number of input (and output) channels.
        length: number of convolution layers in the stack.
        kernel_size: kernel size forwarded to every convolution.
        padding: padding forwarded to every convolution. The residual add
            requires the convolutions to leave H and W unchanged.
        optical: when true the layers are ``OptConvLayer`` instances (built
            with ``perfect_gradient=True``) instead of ``nn.Conv2d``.
        padding_mode: forwarded to ``nn.Conv2d`` only; the optical layer is
            constructed without it.
    """

    def __init__(self, depth, length, kernel_size, padding, optical=False, padding_mode='zeros') -> None:
        super().__init__()
        self.conv_layer_count = length
        self.depth = depth
        expanded = depth * 4

        def build_layer(in_channels):
            # Both layer kinds share the channel layout; only the constructor
            # and its extra keyword differ.
            if optical:
                return OptConvLayer(in_channels, expanded, kernel_size=kernel_size, padding=padding,
                                    perfect_gradient=True)
            return nn.Conv2d(in_channels, expanded, kernel_size=kernel_size, padding=padding,
                             padding_mode=padding_mode)

        self.conv_layers = nn.ModuleList(
            [build_layer(depth if index == 0 else expanded) for index in range(length)])
        self.fix_initialisation()

    def forward(self, x):
        """Run the convolution stack, add the residual and upsample by 2."""
        block_input = x
        for layer in self.conv_layers:
            x = F.elu(layer(x))
        # Residual connection: input channel q is added to output channels
        # 4*q .. 4*q + 3, i.e. to all four sub-pixels it will be upsampled to.
        # NOTE(review): the source had lost its indentation; the residual is
        # taken to be applied once, after the last convolution - confirm.
        x = x + block_input.repeat_interleave(4, dim=1)
        return channel_upsample2d(x)

    # fix checkerboard pattern
    def fix_initialisation(self):
        """Give the last layer identical parameters within each channel group.

        For every group of four consecutive output channels of the final
        convolution, the weights and bias of the first channel are copied to
        the other three. The four sub-pixels produced from one group therefore
        start out equal, which is what the original comment refers to as
        fixing the checkerboard pattern.
        """
        final_layer = self.conv_layers[self.conv_layer_count - 1]
        with torch.no_grad():
            self._tie_channel_groups(final_layer.weight)
            bias = getattr(final_layer, "bias", None)
            if bias is not None:  # layers built without a bias have nothing to tie
                self._tie_channel_groups(bias)

    @staticmethod
    def _tie_channel_groups(param):
        """Copy every 4th slice along dim 0 over the three slices after it, in place."""
        tied = (param.shape[0] // 4) * 4  # trailing channels outside a full group stay untouched
        if tied == 0:
            return
        # repeat_interleave allocates a new tensor, so the copy below never
        # reads from memory it is writing to.
        leaders = param[0:tied:4].repeat_interleave(4, dim=0)
        param[:tied].copy_(leaders)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment