Skip to content

Instantly share code, notes, and snippets.

@Ed-Optalysys
Last active June 29, 2021 08:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Ed-Optalysys/4e8fefdc4d2de61bca1aac0a8f12041c to your computer and use it in GitHub Desktop.
Save Ed-Optalysys/4e8fefdc4d2de61bca1aac0a8f12041c to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorchlayer.opt_conv_layer import OptConvLayer
# Makes a tensor sparse by spreading its spatial elements apart with zeros, e.g. along one
# row with stride=2: [1, 2, 3] -> [1, 0, 2, 0, 3, 0].
def pad_within(x, stride=2):
    """Zero-interleave the last two (spatial) dimensions of ``x`` by ``stride``.

    Element ``x[..., i, j]`` is placed at ``out[..., i * stride, j * stride]``; every other
    position of the output is zero. The output keeps ``x``'s leading dimensions, dtype and
    device.

    Args:
        x: tensor whose last two dimensions are (H, W), e.g. (B, C, H, W).
        stride: spacing between consecutive original elements in the output.

    Returns:
        Tensor of shape ``(..., H * stride, W * stride)``.
    """
    height, width = x.shape[-2:]
    out = x.new_zeros((*x.shape[:-2], height * stride, width * stride))
    # A strided write (instead of a transposed convolution with a one-hot kernel) keeps the
    # padding exactly zero even when x contains inf (inf * 0 would give NaN there), avoids a
    # convolution, and still lets gradients flow back to x through the slice assignment.
    out[..., ::stride, ::stride] = x
    return out
# Takes a 4D tensor (B x C x H x W) whose channel count C is a multiple of 4 and returns a
# channel-upsampled tensor of shape (B x C/4 x 2H x 2W).
def channel_upsample2d(x):
    """Rearrange each group of 4 channels into 2x2 spatial blocks (2x depth-to-space).

    Input channel ``4*q + k`` fills offset ``(a, b)`` (row, column) of every 2x2 block of
    output channel ``q``, with ``k = a + 2*b``::

        k=0 -> (0, 0)   k=1 -> (1, 0)   k=2 -> (0, 1)   k=3 -> (1, 1)

    That is, ``out[:, q, 2*i + a, 2*j + b] == x[:, 4*q + a + 2*b, i, j]``. Note this
    ordering differs from ``F.pixel_shuffle(x, 2)``, which uses ``k = 2*a + b``.

    Args:
        x: 4D tensor of shape (B, C, H, W) with ``C % 4 == 0``.

    Returns:
        Tensor of shape (B, C // 4, 2H, 2W).

    Raises:
        ValueError: if ``x`` is not 4D or ``C`` is not a multiple of 4.
    """
    # Validate before touching x.shape[3]; explicit raises survive `python -O`.
    if x.ndim != 4:
        raise ValueError(f"expected a 4D (B, C, H, W) tensor, got {x.ndim}D")
    batch, channels, height, width = x.shape
    if channels % 4:
        raise ValueError(f"channel count must be a multiple of 4, got {channels}")
    groups = channels // 4
    # Split channels as (group q, column offset b, row offset a): channel = 4*q + 2*b + a.
    blocks = x.reshape(batch, groups, 2, 2, height, width)
    # Reorder to (batch, q, H, a, W, b) so rows interleave as 2*i + a and columns as 2*j + b.
    # This pure rearrangement avoids a zero-padded 4x intermediate and never mixes values,
    # so inf/NaN in one channel cannot leak into neighbouring output pixels.
    return blocks.permute(0, 1, 4, 3, 5, 2).reshape(batch, groups, height * 2, width * 2)
# Convolutional block that takes a 4D input (B x C x H x W), applies a stack of convolutions that
# internally raise the channel dimension to 4x the input depth, then channel-upsamples the result
# so the output has H and W doubled and C channels again.
class UpsampleBlock(nn.Module):
    """Learned 2x spatial upsampling: 4x channel expansion followed by channel_upsample2d.

    Args:
        depth: number of input channels (and of output channels).
        length: number of convolution layers; must be at least 1. The first layer maps
            ``depth -> 4 * depth`` channels, the others ``4 * depth -> 4 * depth``.
        kernel_size, padding: passed to every convolution. The residual connection in
            ``forward`` adds the block input to the conv output, so the convolutions must
            preserve H and W; choose ``padding`` accordingly.
        optical: if True, layers are ``OptConvLayer(..., perfect_gradient=True)`` instead of
            ``nn.Conv2d``; ``padding_mode`` is not forwarded in that case.
        padding_mode: passed to ``nn.Conv2d`` when ``optical`` is False.

    Raises:
        ValueError: if ``length < 1``.
    """

    def __init__(self, depth, length, kernel_size, padding, optical=False, padding_mode='zeros') -> None:
        super().__init__()
        if length < 1:
            # fix_initialisation and forward both need at least one layer.
            raise ValueError(f"length must be at least 1, got {length}")
        self.conv_layer_count = length
        self.depth = depth
        wide = depth * 4

        if optical:
            def make_layer(in_channels):
                return OptConvLayer(in_channels, wide, kernel_size=kernel_size, padding=padding,
                                    perfect_gradient=True)
        else:
            def make_layer(in_channels):
                return nn.Conv2d(in_channels, wide, kernel_size=kernel_size, padding=padding,
                                 padding_mode=padding_mode)

        self.conv_layers = nn.ModuleList(
            [make_layer(depth if idx == 0 else wide) for idx in range(length)])
        self.fix_initialisation()

    def forward(self, x):
        """Upsample ``x`` of shape (B, depth, H, W) to (B, depth, 2H, 2W).

        Each convolution is followed by an ELU. The block input is then added as a residual:
        input channel ``q`` is added to all four channels ``4q .. 4q+3``, which
        ``channel_upsample2d`` places in the same 2x2 output block of channel ``q``.
        """
        layer_input = x
        for layer in self.conv_layers:
            x = F.elu(layer(x))
        # Channel 4q+i receives layer_input[:, q]. One out-of-place broadcast add avoids
        # building an index tensor on every call and mutating the ELU output in place.
        x = x + layer_input.repeat_interleave(4, dim=1)
        return channel_upsample2d(x)

    def fix_initialisation(self):
        """Fix the checkerboard pattern at initialisation.

        Copies the weights and bias of output channel ``4q`` of the last layer onto channels
        ``4q+1 .. 4q+3``. Those four channels then see the same input, get the same residual
        in ``forward`` and become one 2x2 block in ``channel_upsample2d``, so the freshly
        initialised block produces identical values within each 2x2 output block.
        """
        last = self.conv_layers[self.conv_layer_count - 1]
        # NOTE(review): assumes OptConvLayer exposes weight/bias with output channels first,
        # like nn.Conv2d — confirm.
        with torch.no_grad():
            # Only complete groups of 4 output channels are tied.
            tied = (last.weight.shape[0] // 4) * 4
            last.weight[:tied].copy_(last.weight[:tied:4].repeat_interleave(4, dim=0))
            bias = getattr(last, 'bias', None)
            if bias is not None:
                bias[:tied].copy_(bias[:tied:4].repeat_interleave(4, dim=0))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment