Skip to content

Instantly share code, notes, and snippets.

@Dref360
Created June 16, 2017 15:08
Show Gist options
  • Save Dref360/3bf411b34a02301c3f43403130074ae0 to your computer and use it in GitHub Desktop.
Save Dref360/3bf411b34a02301c3f43403130074ae0 to your computer and use it in GitHub Desktop.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init
from torch.autograd import Variable
class Layer(nn.Module):
    """Conv unit applied as BatchNorm2d -> ReLU -> Conv2d -> Dropout2d.

    Args:
        in_: number of input channels.
        out: number of output channels produced by the convolution.
        kernel: convolution kernel size (int or tuple), passed to ``nn.Conv2d``.
        init_kernel: ``'uniform'`` initialises the conv weight with
            Xavier-uniform (gain sqrt(2)); any other value uses Kaiming-normal.
        padding: zero-padding for the convolution. The default of 1 keeps the
            spatial size unchanged for 3x3 kernels.
        dropout: probability for ``nn.Dropout2d``. It defaults to 0.25, the
            value that was previously hard-coded.
    """

    def __init__(self, in_, out, kernel, init_kernel='uniform', padding=1, dropout=0.25):
        super(Layer, self).__init__()
        self.bn = nn.BatchNorm2d(in_)
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(in_, out, kernel, padding=padding)
        # The trailing-underscore initialisers are the supported API; the
        # un-suffixed names are deprecated aliases that emit a warning.
        if init_kernel == 'uniform':
            init.xavier_uniform_(self.conv1.weight, gain=np.sqrt(2.0))
        else:
            init.kaiming_normal_(self.conv1.weight)
        self.drop = nn.Dropout2d(dropout)

    def forward(self, x):
        """Return ``drop(conv1(relu(bn(x))))``."""
        x = self.bn(x)
        x = self.relu(x)
        x = self.conv1(x)
        return self.drop(x)
class DenseBlock(nn.Module):
    """A stack of ``steps`` 3x3 Layers with concatenative (dense) connectivity.

    Each Layer receives the block input concatenated with every feature map
    produced by the earlier Layers of the block, and emits ``out`` new channels.

    Args:
        in_: channels of the block input.
        out: channels produced by each Layer (growth rate).
        steps: number of Layers in the block.
        init_kernel: weight-init selector forwarded to every Layer.
    """

    def __init__(self, in_, out, steps, init_kernel='uniform'):
        super(DenseBlock, self).__init__()
        # Layer k sees the input plus the k * out channels produced before it.
        self.layers = nn.ModuleList(
            Layer(in_ + k * out, out, (3, 3), init_kernel) for k in range(steps)
        )

    def forward(self, x):
        """Return the input concatenated (dim 1) with all newly produced maps.

        As a side effect, the list of maps produced by each Layer is stored in
        ``self.layers_out``; ``Tiramisu.forward`` reads it to build the input
        of the next upsampling step.
        """
        produced = []
        for layer in self.layers:
            new_maps = layer(x)
            produced.append(new_maps)
            x = torch.cat([x, new_maps], 1)
        self.layers_out = produced
        return x
class TransitionDown(nn.Module):
    """A 1x1 Layer followed by 2x2 max-pooling.

    The Layer maps ``in_`` to ``out`` channels. The pooling halves height and
    width, rounding down.
    """

    def __init__(self, in_, out, init_kernel='uniform'):
        super(TransitionDown, self).__init__()
        # A 1x1 kernel needs no padding to keep the spatial size.
        self.layer = Layer(in_, out, (1, 1), init_kernel=init_kernel, padding=0)
        self.max_pool = nn.MaxPool2d((2, 2))

    def forward(self, x):
        """Return ``max_pool(layer(x))``."""
        return self.max_pool(self.layer(x))
class TransitionUp(nn.Module):
    """Upsample concatenated feature maps by 2x and append a skip tensor.

    Args:
        in_: total channels of the tensors passed as ``blocks``.
        out: channels produced by the transposed convolution.
        init_kernel: ``'uniform'`` initialises the deconv weight with
            Xavier-uniform (gain sqrt(2)); any other value uses Kaiming-normal.
    """

    def __init__(self, in_, out, init_kernel='uniform'):
        super(TransitionUp, self).__init__()
        # kernel 3, stride 2, padding 1, output_padding 1 maps H -> exactly 2*H:
        # (H - 1) * 2 - 2 * 1 + (3 - 1) + 1 + 1 == 2 * H.
        self.deconv = nn.ConvTranspose2d(in_, out, (3, 3), stride=2, padding=1, output_padding=1)
        # Trailing-underscore initialisers; the un-suffixed names are deprecated.
        if init_kernel == 'uniform':
            init.xavier_uniform_(self.deconv.weight, gain=np.sqrt(2.0))
        else:
            init.kaiming_normal_(self.deconv.weight)

    def forward(self, skip, blocks):
        """Concatenate ``blocks`` on dim 1, upsample, then append ``skip``.

        Args:
            skip: tensor whose spatial size is exactly twice that of ``blocks``.
            blocks: iterable of tensors with equal spatial size whose channel
                counts sum to ``in_``.

        Returns:
            ``cat([deconv(cat(blocks)), skip], dim=1)``.
        """
        x = torch.cat(list(blocks), 1)
        x = self.deconv(x)
        return torch.cat([x, skip], 1)
class Tiramisu(nn.Module):
    """Encoder / bottleneck / decoder network built from DenseBlocks.

    The layout resembles FC-DenseNet ("The One Hundred Layers Tiramisu");
    that this is the intended reference is presumed from the class name.

    Args:
        in_: channels of the input image.
        first_out: channels produced by the initial 3x3 convolution.
        steps: number of Layers in each down-path DenseBlock. The up path
            mirrors it in reverse. Any non-empty sequence is accepted.
        out: channels added by every Layer (growth rate).
        third_output: unused; kept only for signature compatibility.
        init_kernel: weight-init selector forwarded to every sub-module.
            ``'uniform'`` selects Xavier-uniform; anything else selects
            Kaiming-normal.
        last_step: number of Layers in the bottleneck; must be >= 1.

    Raises:
        ValueError: if ``steps`` is empty or ``last_step`` < 1.

    The input height and width must be divisible by ``2 ** len(steps)``.
    Otherwise the rounding-down max-pools and the exact 2x transposed
    convolutions produce skips and upsampled maps of different sizes.
    """

    def __init__(self, in_, first_out, steps, out, third_output, init_kernel, last_step=15):
        super(Tiramisu, self).__init__()
        steps = list(steps)  # list concatenation below requires a list
        if not steps:
            raise ValueError("steps must contain at least one dense-block size")
        if last_step < 1:
            # With no bottleneck layers the first TransitionUp would receive
            # nothing to concatenate.
            raise ValueError("last_step must be >= 1")

        self.first_conv = nn.Conv2d(in_, first_out, (3, 3), padding=1)

        # Down path: each DenseBlock adds s * out channels. Each
        # TransitionDown keeps the channel count and halves the resolution.
        # (Previously the first width was hard-coded as 112, which only
        # equals first_out + steps[0] * out for 48 + 4 * 16.)
        channels = first_out
        acc_dense = []
        acc_down = []
        for s in steps:
            acc_dense.append(DenseBlock(channels, out, s, init_kernel))
            channels += s * out
            acc_down.append(TransitionDown(channels, channels, init_kernel))
        self.down_dense = nn.ModuleList(acc_dense)
        self.down_down = nn.ModuleList(acc_down)

        # Bottleneck: last_step Layers, densely connected in forward().
        self.middle = nn.ModuleList([])
        middle_out = channels
        for _ in range(last_step):
            self.middle.append(Layer(middle_out, out, (3, 3), init_kernel))
            middle_out += out

        # Up path. Up block i upsamples only the maps produced by the previous
        # stage: the bottleneck first, then the up DenseBlocks in turn.
        # middle_out tracks each up DenseBlock's input width
        # (upsampled maps + matching skip) and drops by the upsampled width
        # after each step.
        up_steps = steps[::-1]
        prev_layers = [last_step] + up_steps[:-1]
        acc_up = []
        acc_dense_up = []
        for n_layers, s in zip(prev_layers, up_steps):
            n_filters = out * n_layers
            acc_up.append(TransitionUp(n_filters, n_filters, init_kernel))
            acc_dense_up.append(DenseBlock(middle_out, out, s, init_kernel))
            middle_out -= n_filters
        self.up_up = nn.ModuleList(acc_up)
        self.up_dense = nn.ModuleList(acc_dense_up)

    def forward(self, x):
        """Run the encoder, bottleneck and decoder.

        Returns the output of the last up-path DenseBlock.
        """
        x = self.first_conv(x)

        # Encoder: remember each DenseBlock output (before pooling) as a skip.
        skips = []
        for dense, down in zip(self.down_dense, self.down_down):
            x = dense(x)
            skips.append(x)
            x = down(x)
        skips.reverse()

        # Bottleneck: collect the newly produced maps. Only these (not the
        # full concatenation in x) are fed to the first TransitionUp.
        upsample = []
        for mid in self.middle:
            new_maps = mid(x)
            upsample.append(new_maps)
            x = torch.cat([new_maps, x], 1)

        # Decoder: upsample the previous stage's new maps, append the skip,
        # then densify. Only the fresh maps of each up DenseBlock go forward.
        for up, dense, skip in zip(self.up_up, self.up_dense, skips):
            x = up(skip, upsample)
            x = dense(x)
            upsample = dense.layers_out
        return x
# Smoke test: build a 5-level model and run one forward pass on a random image.
# 224 is divisible by 2 ** 5, so the skip and upsampled sizes line up.
model = Tiramisu(
    in_=3,
    first_out=48,
    steps=[4, 5, 7, 10, 12],
    out=16,
    third_output=False,
    init_kernel='uniform',
)
# reduction='sum' is the exact, non-deprecated replacement for size_average=False.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
# Plain tensors track gradients since PyTorch 0.4; no Variable wrapper needed.
y_pred = model(torch.randn(1, 3, 224, 224))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment