@AhmadMoussa
Last active March 20, 2020 19:25
Minimal Code to create a scalable UNet in PyTorch.
import torch
from torch import nn


# A single encoder (strided conv) or decoder (transposed conv) block;
# conv_or_deconv selects between downsampling and upsampling.
def convBlock(inc, outc, ksz, conv_or_deconv):
    return nn.Sequential(
        nn.Conv2d(in_channels=inc, out_channels=outc, kernel_size=ksz, stride=2)
        if conv_or_deconv
        else nn.ConvTranspose2d(in_channels=inc, out_channels=outc, kernel_size=ksz, stride=2),
        nn.LeakyReLU(),
        nn.BatchNorm2d(num_features=outc)
    )


class UNet(nn.Module):
    def __init__(self, number_of_layers=6, ksz=(3, 3)):
        super(UNet, self).__init__()
        self.ksz = ksz
        self.number_of_layers = number_of_layers
        # Channel counts double at every encoder stage: (1, 2), (2, 4), ..., (32, 64)
        self.sizes = [(2 ** i, 2 ** (i + 1)) for i in range(0, self.number_of_layers)]
        self.encoder_layers = nn.ModuleList(
            [convBlock(inc, outc, self.ksz, 1) for inc, outc in self.sizes])
        self.residuals = []
        # The decoder mirrors the encoder; its input channels are doubled because
        # the matching encoder activation is concatenated on before each block.
        self.decoder_layers = nn.ModuleList(
            [convBlock(2 * inc, outc, self.ksz, 0) for outc, inc in reversed(self.sizes)])

    def forward(self, x):
        for layer in self.encoder_layers:
            x = layer(x)
            print(x.shape)
            self.residuals.append(x)
        for residual, layer in zip(reversed(self.residuals[:]), self.decoder_layers):
            x = torch.cat((residual, x), 1)
            print(x.shape)
            x = layer(x)
        return x


unet = UNet(number_of_layers=6, ksz=(3, 3))

''' TEST:
import numpy as np
outputs = unet(torch.tensor(np.zeros((1, 1, 256, 256))).float())
print(outputs.shape)
'''
@gabrieldernbach

gabrieldernbach commented Mar 20, 2020

Thanks for sharing your idea. I like your approach of using zip(reversed(residuals), decoder_layers).

However, I don't think you should keep the residuals in self.residuals.
Each time you make a forward pass you are appending to the list, and it will keep growing.
Where do you reset it?

@AhmadMoussa
Author

Hey, thanks for the comment. Yeah, you're right, I don't reset it, though I should. I never actually used this as is; I originally wrote it as a personal memo on how to write a model in PyTorch.

Thanks for the tip, I will fix it when I have some time. Out of curiosity, how did you find this? And what are you using the UNet for?
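For reference, a minimal sketch of what that fix could look like (not the final version): keep the skip connections in a local list inside forward, so nothing accumulates on the module between passes.

    def forward(self, x):
        # Collect encoder activations locally instead of on self.residuals,
        # so each forward pass starts from an empty list.
        residuals = []
        for layer in self.encoder_layers:
            x = layer(x)
            residuals.append(x)
        # Walk the decoder, concatenating the matching encoder activation
        # (deepest first) onto the current feature map before each block.
        for residual, layer in zip(reversed(residuals), self.decoder_layers):
            x = torch.cat((residual, x), 1)
            x = layer(x)
        return x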

@gabrieldernbach

I was looking for flexible/scalable UNet implementations, as I will have to search for appropriate proportions for the task at hand. I am applying it to audio segmentation.

I will push it to my git in a week or two.

@AhmadMoussa
Author

I'm not an expert, but maybe there are better choices for this task than a UNet. Depending on your requirements, maybe a WaveNet?

@gabrieldernbach

gabrieldernbach commented Mar 20, 2020

Can you point me to a paper that elaborates on what you have in mind?
What makes you think that UNet is a bad choice for segmentation?

@AhmadMoussa
Author

Well, first I'd have to know what kind of segmentation you are doing. Are you separating different audio events from each other, or segmenting multiple speakers? And what does your data look like? As I said, I'm not an expert, but since a UNet uses regular convolutions, your data will need a specific shape, or a specific length in seconds if you're working in the time domain. If you're using spectrograms, then you'll have to pick a specific resolution and train on that. This limits the length (in seconds) of audio you can feed to your network. Also, regular convolutions (1D or 2D) generally struggle with audio.
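A quick sketch of that shape constraint with the UNet defined above (assuming the default six stride-2 layers and a single input channel): a 256x256 input makes it through, while for example a 250x250 input breaks at the skip concatenation, because the stored encoder activation and the upsampled decoder output no longer have the same spatial size.

    import torch

    unet = UNet(number_of_layers=6, ksz=(3, 3))

    # 256x256 survives six unpadded stride-2 convolutions and comes back as 255x255
    print(unet(torch.zeros(1, 1, 256, 256)).shape)  # torch.Size([1, 1, 255, 255])

    # 250x250 does not: the skip connection and the decoder feature map drift apart
    # in spatial size, so torch.cat raises a RuntimeError
    try:
        unet(torch.zeros(1, 1, 250, 250))
    except RuntimeError as err:
        print("skip connection shape mismatch:", err)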

If you're not trying to write something that competes with SOTA papers, then I think a UNet can be fine; otherwise, probably not.
Some papers you could check out:
WaveNet -> mainly for generative audio tasks
Adversarial Audio Synthesis -> they use UNets to generate audio

I wrote a paper recently on generating audio with conditional GANs, but it's still in the process of being published. If you'd like to talk more about this, shoot me an email at ahmad.moussa@fuji.waseda.jp.
