Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
U-Net in PyTorch
"""
U-Net architecture in PyTorch (https://arxiv.org/abs/1505.04597)
Author: Jacob Reinhold (jacob.reinhold@jhu.edu)
"""
import torch
from torch import nn
from torch.nn import functional as F
class ConvLayer(nn.Sequential):
    """3x3 convolution -> batch norm -> in-place ReLU, preserving spatial size.

    The convolution uses ``padding=1`` so height and width are unchanged. It
    has no bias: BatchNorm2d (affine by default) subtracts the batch mean and
    applies its own learnable shift, so a conv bias would be redundant.

    Args:
        in_channels: number of channels fed to the convolution.
        out_channels: number of channels produced by the convolution.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # Assigning Module-valued attributes registers them as children in
        # assignment order, which is the order nn.Sequential applies them.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                              padding=1, bias=False)
        self.norm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
class UNetBlock(nn.Sequential):
    """Two ConvLayers applied in sequence: in_channels -> out_channels -> out_channels.

    Args:
        in_channels: number of channels entering the first ConvLayer.
        out_channels: number of channels produced by both ConvLayers.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # Attribute assignment registers the children in order, so block1
        # runs before block2 when the Sequential is called.
        self.block1 = ConvLayer(in_channels, out_channels)
        self.block2 = ConvLayer(out_channels, out_channels)
class UNet(nn.Module):
    """2D U-Net with bilinear upsampling and size-preserving convolutions.

    Encoder: ``depth`` UNetBlocks, each followed by 2x2 max pooling; its
    pre-pooling output is kept as a skip connection. A bottleneck UNetBlock
    follows. Decoder: ``depth`` UNetBlocks; before each one, the features are
    bilinearly resized to the matching skip's spatial size and concatenated
    with it along the channel dimension. The last decoder stage ends with a
    1x1 convolution mapping to ``out_channels``.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels of the output tensor.
        channel_base: width of the first level; level ``i`` uses
            ``channel_base * 2**i`` channels and the bottleneck uses
            ``channel_base * 2**depth``.
        depth: number of encoder/decoder levels (number of pooling steps).
            The default of 4 reproduces the original fixed architecture.

    Raises:
        ValueError: if ``depth`` is less than 1.
    """

    def __init__(self, in_channels: int, out_channels: int,
                 channel_base: int = 64, depth: int = 4):
        super().__init__()
        if depth < 1:
            raise ValueError(f"depth must be >= 1, got {depth}")
        # chans[i] is the width of level i; chans[depth] is the bottleneck width.
        chans = [channel_base * 2 ** i for i in range(depth + 1)]

        # Encoder: in_channels -> chans[0] -> chans[1] -> ... -> chans[depth - 1].
        encoder_inputs = [in_channels] + chans[:depth - 1]
        self.down_layers = nn.ModuleList(
            UNetBlock(c_in, c_out)
            for c_in, c_out in zip(encoder_inputs, chans[:depth])
        )
        self.bottleneck = UNetBlock(chans[depth - 1], chans[depth])

        # Decoder, deepest level first: each block consumes the upsampled
        # features (chans[i + 1]) concatenated with the level-i skip (chans[i]).
        self.up_layers = nn.ModuleList(
            UNetBlock(chans[i + 1] + chans[i], chans[i])
            for i in reversed(range(1, depth))
        )
        # The shallowest decoder stage also projects to out_channels (1x1 conv).
        self.up_layers.append(nn.Sequential(
            UNetBlock(chans[1] + chans[0], chans[0]),
            nn.Conv2d(chans[0], out_channels, 1)))

    @staticmethod
    def interp_cat(x, skip):
        """Resize ``x`` to ``skip``'s spatial size and concatenate on dim 1.

        Uses bilinear interpolation with ``align_corners=True``, so ``x`` and
        ``skip`` must be 4D (N, C, H, W) tensors.
        """
        x = F.interpolate(x, skip.shape[2:], mode='bilinear', align_corners=True)
        return torch.cat((x, skip), 1)

    def forward(self, x):
        """Run the U-Net on a 4D (N, in_channels, H, W) tensor.

        Returns an (N, out_channels, H, W) tensor: every decoder stage is
        resized back to its skip's size, so the input's H and W are
        recovered. Each of H and W should be at least ``2**depth`` so that
        every pooling step has something to pool.
        """
        skip_connections = []
        for down_layer in self.down_layers:
            x = down_layer(x)
            skip_connections.append(x)  # keep pre-pooling features for the decoder
            x = F.max_pool2d(x, 2)
        x = self.bottleneck(x)
        # Decoder stages are ordered deepest-first, so consume skips in reverse.
        for up_layer, skip in zip(self.up_layers, reversed(skip_connections)):
            x = up_layer(self.interp_cat(x, skip))
        return x
if __name__ == "__main__":
    # Smoke test: build the default model, show its layers, and check that a
    # forward pass preserves the input's shape (1 channel in, 1 channel out).
    model = UNet(1, 1)
    print(model)
    x = torch.randn(1, 1, 128, 128)
    # Gradients are not needed for a shape check; disabling autograd avoids
    # retaining every intermediate activation.
    with torch.no_grad():
        y = model(x)
    print(f"input shape: {tuple(x.shape)}, output shape: {tuple(y.shape)}")
    if y.shape != x.shape:
        raise RuntimeError(
            f"expected output shape {tuple(x.shape)}, got {tuple(y.shape)}")
@jcreinhold

This comment has been minimized.

Copy link
Owner Author

@jcreinhold jcreinhold commented Jul 9, 2020

Implementation of a 2D U-Net in PyTorch. Differences from the original paper: 1) it upsamples with bilinear interpolation instead of transposed convolutions, and 2) it preserves the input's spatial size by padding the convolutions. Not tested extensively.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment