Skip to content

Instantly share code, notes, and snippets.

Created June 26, 2021 14:57
Show Gist options
  • Save jonashaag/01f184d8ae2b67819164091815d75b19 to your computer and use it in GitHub Desktop.
Save jonashaag/01f184d8ae2b67819164091815d75b19 to your computer and use it in GitHub Desktop.
import torch
from torch.nn import *
def pointwise(in_channels, out_channels):
    """Build a 1x1 convolution block: Conv2d -> BatchNorm2d -> ReLU.

    The 1x1 kernel only mixes channels, so both spatial sizes are unchanged.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.

    Returns:
        A ``Sequential`` mapping ``(N, in_channels, H, W)`` to
        ``(N, out_channels, H, W)``.
    """
    return Sequential(
        Conv2d(in_channels, out_channels, 1, 1),
        # Same normalisation/activation that TRUNet's first encoder stage
        # applies after its convolution.
        BatchNorm2d(out_channels),
        ReLU(),
    )
def depthwise(in_channels, out_channels, kernel_size, stride):
    """Build a depthwise convolution block along dim 2: Conv2d -> BatchNorm2d -> ReLU.

    The ``(kernel_size, 1)`` kernel slides over dim 2 only (the axis TRUNet's
    ``forward`` calls ``freqs``); dim 3 is left untouched.
    ``groups=in_channels`` gives every input channel its own filters, so
    ``out_channels`` must be a multiple of ``in_channels``.

    With padding ``kernel_size // 2`` the size along dim 2 becomes
    ``(H + 2 * (kernel_size // 2) - kernel_size) // stride + 1``. For an odd
    kernel at stride 1 that is ``H``.

    Args:
        in_channels: number of input channels (also the number of groups).
        out_channels: number of output channels.
        kernel_size: kernel length along dim 2.
        stride: stride along dim 2.

    Returns:
        A ``Sequential`` implementing the block.
    """
    return Sequential(
        Conv2d(
            in_channels,
            out_channels,
            (kernel_size, 1),
            (stride, 1),
            padding=(kernel_size // 2, 0),
            groups=in_channels,
        ),
        BatchNorm2d(out_channels),
        ReLU(),
    )
class TRUNet(Module):
    """U-Net with recurrent bottleneck over a ``(batch, channels, freqs, time)`` input.

    Structure, as built in ``__init__``:

    * encoder: six stages. All kernels are ``(k, 1)``, so only the ``freqs``
      axis is resampled: for 256 input bins the stage outputs have
      128, 128, 64, 64, 32, 16 bins.
    * FGRU: a bidirectional GRU run along the ``freqs`` axis of every frame,
      followed by a pointwise projection from 128 to 64 channels.
    * TGRU: a unidirectional GRU run along the ``time`` axis of every bin,
      followed by a linear projection (128 -> 64) and BatchNorm + ReLU.
    * decoder: six stages ending in transposed convolutions that restore
      16 -> 256 bins. Every stage after the first concatenates the mirrored
      encoder output along the channel axis.

    Args:
        in_channels: number of input channels.
        out_channels: number of channels produced by the last decoder stage.
    """

    def __init__(self, in_channels=4, out_channels=10):
        # Module.__init__ must run before any submodule is assigned.
        super().__init__()
        # Every stage's output is also kept as a decoder skip connection.
        self.encoder = ModuleList([
            Sequential(Conv2d(in_channels, 64, (5, 1), (2, 1), padding=(2, 0)), BatchNorm2d(64), ReLU()),
            Sequential(pointwise(64, 128), depthwise(128, 128, 3, 1)),
            Sequential(pointwise(128, 128), depthwise(128, 128, 5, 2)),
            Sequential(pointwise(128, 128), depthwise(128, 128, 3, 1)),
            Sequential(pointwise(128, 128), depthwise(128, 128, 5, 2)),
            Sequential(pointwise(128, 128), depthwise(128, 128, 3, 2)),
        ])
        # Unpacked in _frequency_gru; not called as a Sequential (GRU returns a tuple).
        self.fgru = Sequential(
            GRU(128, 64, bidirectional=True, batch_first=True),
            pointwise(128, 64),
        )
        self.tgru = ModuleList([
            GRU(64, 128, batch_first=True),
            Linear(128, 64),
            Sequential(BatchNorm2d(64), ReLU()),
        ])
        # Input channels: 64 for the first stage (TGRU output), 64 + 128 skip
        # channels for the middle stages, and 64 + 64 for the last stage,
        # whose skip is encoder stage 0.
        self.decoder = Sequential(
            Sequential(pointwise(64, 64), ConvTranspose2d(64, 64, (3, 1), (2, 1), padding=(1, 0), output_padding=(1, 0))),
            Sequential(pointwise(192, 64), ConvTranspose2d(64, 64, (5, 1), (2, 1), padding=(2, 0), output_padding=(1, 0))),
            Sequential(pointwise(192, 64), ConvTranspose2d(64, 64, (3, 1), (1, 1), padding=(1, 0))),
            Sequential(pointwise(192, 64), ConvTranspose2d(64, 64, (5, 1), (2, 1), padding=(2, 0), output_padding=(1, 0))),
            Sequential(pointwise(192, 64), ConvTranspose2d(64, 64, (3, 1), (1, 1), padding=(1, 0))),
            Sequential(pointwise(128, out_channels), ConvTranspose2d(out_channels, out_channels, (5, 1), (2, 1), padding=(2, 0), output_padding=(1, 0))),
        )

    def forward(self, x: "(B, in_channels, 256, T)"):
        """Run the network.

        Args:
            x: 4-D tensor ``(batch, in_channels, freqs, time)`` with ``freqs``
                equal to 256 or 257. With 257 bins, the last bin is dropped
                before processing and a zero bin is appended to the output.

        Returns:
            Tensor of shape ``(batch, out_channels, freqs, time)``.

        Raises:
            ValueError: if ``freqs`` is neither 256 nor 257. Other sizes would
                otherwise fail inside the network or silently return a
                different number of bins.
        """
        batch, _, freqs, time = x.shape
        if freqs not in (256, 257):
            raise ValueError(f"TRUNet expects 256 or 257 frequency bins, got {freqs}")
        if freqs == 257:
            x = x[:, :, :256]

        x, skips = self._encode(x)
        assert x.shape == (batch, 128, 16, time), x.shape
        x = self._frequency_gru(x)
        x = self._time_gru(x)
        x = self._decode(x, skips)

        if freqs == 257:
            # Pad order starts at the last dim: (time_lo, time_hi, freq_lo, freq_hi).
            x = torch.nn.functional.pad(x, [0, 0, 0, 1])
        return x

    def _encode(self, x):
        """Apply all encoder stages; return the final output and every stage output."""
        skips = []
        for stage in self.encoder:
            x = stage(x)
            skips.append(x)
        return x, skips

    def _frequency_gru(self, x):
        """Run the FGRU along the freqs axis: ``(B, 128, F, T)`` -> ``(B, 64, F, T)``."""
        gru, project = self.fgru
        batch, _, freqs, time = x.shape
        # One independent sequence over frequency per (batch, frame) pair.
        seq = x.permute(0, 3, 2, 1).flatten(0, 1)                       # (B*T, F, 128)
        out, _ = gru(seq)                                               # (B*T, F, 2*64)
        out = out.reshape(batch, time, freqs, -1).permute(0, 3, 2, 1)   # (B, 128, F, T)
        return project(out)

    def _time_gru(self, x):
        """Run the TGRU along the time axis: ``(B, 64, F, T)`` -> ``(B, 64, F, T)``."""
        gru, linear, bn_act = self.tgru
        batch, _, freqs, time = x.shape
        # One independent sequence over time per (batch, bin) pair.
        seq = x.permute(0, 2, 3, 1).flatten(0, 1)                       # (B*F, T, 64)
        out, _ = gru(seq)                                               # (B*F, T, 128)
        out = linear(out.reshape(batch, freqs, time, -1))               # (B, F, T, 64)
        return bn_act(out.permute(0, 3, 1, 2))                          # (B, 64, F, T)

    def _decode(self, x, skips):
        """Apply the decoder stages with skip connections in mirrored order.

        The first stage is paired with the last encoder output, which is not
        concatenated. Each later stage concatenates its paired encoder output
        along the channel axis before running.
        """
        for i, (stage, skip) in enumerate(zip(self.decoder, reversed(skips))):
            if i:
                x = torch.cat([x, skip], dim=1)
            x = stage(x)
        return x
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment