@AhmadMoussa
Last active March 20, 2020 19:25
Minimal Code to create a scalable UNet in PyTorch.
import torch
from torch import nn


def convBlock(inc, outc, ksz, conv_or_deconv):
    # A single block: strided convolution (downsampling) if conv_or_deconv is truthy,
    # otherwise a transposed convolution (upsampling), followed by LeakyReLU and BatchNorm.
    return nn.Sequential(
        nn.Conv2d(in_channels=inc, out_channels=outc, kernel_size=ksz, stride=2)
        if conv_or_deconv
        else nn.ConvTranspose2d(in_channels=inc, out_channels=outc, kernel_size=ksz, stride=2),
        nn.LeakyReLU(),
        nn.BatchNorm2d(num_features=outc)
    )


class UNet(nn.Module):
    def __init__(self, number_of_layers=6, ksz=(3, 3)):
        super(UNet, self).__init__()
        self.ksz = ksz
        self.number_of_layers = number_of_layers
        # Channel counts double at each stage: (1, 2), (2, 4), ..., (2^(n-1), 2^n).
        self.sizes = [(2 ** i, 2 ** (i + 1)) for i in range(self.number_of_layers)]
        self.encoder_layers = nn.ModuleList(
            [convBlock(inc, outc, self.ksz, 1) for (inc, outc) in self.sizes])
        # The decoder mirrors the encoder; each block takes twice as many input channels
        # because the corresponding skip connection is concatenated before upsampling.
        self.decoder_layers = nn.ModuleList(
            [convBlock(2 * inc, outc, self.ksz, 0) for (outc, inc) in reversed(self.sizes)])

    def forward(self, x):
        residuals = []  # collected per call so repeated forward passes don't accumulate state
        for layer in self.encoder_layers:
            x = layer(x)
            print(x.shape)
            residuals.append(x)
        for residual, layer in zip(reversed(residuals), self.decoder_layers):
            # Concatenate the skip connection along the channel dimension.
            x = torch.cat((residual, x), 1)
            print(x.shape)
            x = layer(x)
        return x


unet = UNet(number_of_layers=6, ksz=(3, 3))

''' TEST:
outputs = unet(torch.zeros(1, 1, 256, 256))
print(outputs.shape)
'''
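The depth is configurable; as a quick sanity check, here is a usage sketch (the 4-layer setup and the 64x64 input are illustrative choices, not part of the original test) showing that a shallower instance works the same way on a smaller input:

''' USAGE SKETCH (illustrative):
shallow = UNet(number_of_layers=4, ksz=(3, 3))  # channels: 1 -> 2 -> 4 -> 8 -> 16 and back down
out = shallow(torch.zeros(1, 1, 64, 64))        # 64 px survives four stride-2 halvings
print(out.shape)                                # torch.Size([1, 1, 63, 63])
'''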
@AhmadMoussa (Author)

Well, first I'd have to know what kind of segmentation you are doing. Are you separating different audio events from each other, or segmenting multiple speakers? And what does your data look like? As I said, I'm not an expert, but since a UNet uses regular convolutions, your data will need to have a specific shape, or a specific length in seconds if you're working in the time domain. If you're using spectrograms, you'll have to pick a specific resolution and train on that, which limits the length (in seconds) of audio you can feed to your network. Also, regular convolutions (1D or 2D) generally struggle with audio.
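For example, here is a rough sketch of what fixing the spectrogram resolution looks like in practice; the STFT parameters, the dummy waveform, and the crop below are assumptions for illustration, chosen so the result matches the 256x256 input the 6-layer UNet above was tested with:

n_fft, hop = 510, 128                            # 510 // 2 + 1 = 256 frequency bins (assumed values)
waveform = torch.randn(1, 256 * hop)             # placeholder mono audio, roughly 256 STFT frames
spec = torch.stft(waveform, n_fft=n_fft, hop_length=hop,
                  window=torch.hann_window(n_fft), return_complex=True)
mag = spec.abs()[:, :256, :256]                  # magnitude spectrogram, cropped to a fixed 256x256 "image"
out = unet(mag.unsqueeze(1))                     # add a channel dimension -> (1, 1, 256, 256)
print(out.shape)                                 # (1, 1, 255, 255) with the architecture above

Anything longer than those ~256 frames would have to be chunked or cropped, which is what I mean by the length limit.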

If you're not trying to write something that competes with SOTA papers, then I think the UNet can be fine; otherwise, probably not.
Some papers you could check out:
WaveNet -> mainly for generative audio tasks
Adversarial Audio Synthesis -> they use UNets to generate audio

I wrote a paper recently on generating audio with conditional GANs, but it's still in the process of being published. If you would like to talk more about this, shoot me an email at ahmad.moussa@fuji.waseda.jp.
