Skip to content

Instantly share code, notes, and snippets.

@moritzschaefer
Created January 17, 2018 10:15
Show Gist options
  • Save moritzschaefer/70aa5527fe64d746bf36044f43a45564 to your computer and use it in GitHub Desktop.
Save moritzschaefer/70aa5527fe64d746bf36044f43a45564 to your computer and use it in GitHub Desktop.
Combination of an LSTM and a convolutional layer
from torch import nn
import torch
from torch.nn.init import kaiming_normal, normal
def weights_init(m):
    '''
    Initialise the parameters of Conv1d and Linear layers in place.

    Intended to be passed to ``nn.Module.apply``. Modules of any other type
    are left untouched.

    Weights receive Kaiming (He) normal initialisation. Biases receive a
    standard normal N(0, 1) draw: Kaiming init needs fan-in/fan-out, which
    cannot be computed for a 1-D bias tensor, so a Kaiming call on the bias
    would raise ValueError before sampling anything.

    :param m: the module to initialise.
    '''
    if not isinstance(m, (nn.Conv1d, nn.Linear)):
        return
    # nn.init functions run under torch.no_grad() themselves, so no ``.data``
    nn.init.kaiming_normal_(m.weight)
    # layers constructed with bias=False expose ``bias`` as None
    if m.bias is not None:
        nn.init.normal_(m.bias)
class Deep1(nn.Module):
    '''
    A combination of an LSTM and a Conv layer over a one-hot encoded sequence.

    4 is the nucleotide encoding (4 bit per nucleotide)
    30 is the length of our input sequence
    120 is 4*30 the number of sequence features
    input_size is the number of total input_features.
    The first 120 have to be the sequence 1-hot-encodings; the remaining
    ``input_size - 120`` columns are concatenated, unchanged, with the LSTM
    and conv outputs before the final linear layer.

    The 120 sequence columns are read as 30 positions x 4 bits, i.e. the
    4 bits belonging to one position are adjacent columns.
    '''
    lstm_hidden = 50
    kernel_size = 4
    # sequence layout: seq_len positions, nuc_bits one-hot bits per position
    seq_len = 30
    nuc_bits = 4
    # number of filters (output channels) of the conv layer
    conv_channels = 4

    def __init__(self, input_size, lstm_dropout=0.0):
        '''
        :param input_size: total number of features per sample; must be at
            least ``seq_len * nuc_bits`` (120).
        :param lstm_dropout: dropout probability applied between the two
            LSTM layers. Defaults to 0.0 (disabled).
        :raises ValueError: if ``input_size`` is smaller than the number of
            sequence features.
        '''
        super(Deep1, self).__init__()
        n_seq_features = self.seq_len * self.nuc_bits
        if input_size < n_seq_features:
            raise ValueError(
                'input_size must be >= %d (the one-hot sequence features), got %d'
                % (n_seq_features, input_size))
        # dropout must be a number: recent PyTorch rejects ``dropout=False``
        # with a ValueError, so a float (0.0 = disabled) is passed instead.
        self.lstm = nn.LSTM(input_size=self.nuc_bits, hidden_size=self.lstm_hidden,
                            num_layers=2, dropout=lstm_dropout, bidirectional=False)
        self.conv1 = nn.Conv1d(
            in_channels=self.nuc_bits, out_channels=self.conv_channels,
            kernel_size=self.kernel_size)
        # fc1 input = last LSTM output + additional features + flattened conv output
        conv_out_len = self.seq_len - self.kernel_size + 1
        self.fc1 = nn.Linear(
            self.lstm_hidden + (input_size - n_seq_features)
            + self.conv_channels * conv_out_len, 1)
        self.apply(weights_init)

    def forward(self, x):
        '''
        :param x: tensor of shape (batch, input_size); columns [0, 120) are the
            one-hot sequence, columns [120, input_size) the additional features.
        :return: tensor of shape (batch, 1).
        '''
        n_seq_features = self.seq_len * self.nuc_bits
        # Slice rather than ``x.split(120, dim=1)``: an int split size yields a
        # single chunk when there are no additional features and 3+ chunks when
        # there are more than 120 of them, both of which break 2-way unpacking.
        nuc_features = x[:, :n_seq_features]
        additional_features = x[:, n_seq_features:]
        # (batch, seq_len, nuc_bits); reshape copes with the non-contiguous slice
        nucleotides = nuc_features.reshape(-1, self.seq_len, self.nuc_bits)
        # the LSTM is not batch_first, so it needs (seq_len, batch, input_size)
        lstm_input = nucleotides.permute(1, 0, 2)
        # keep only the last time step of the top layer: (batch, lstm_hidden)
        lstm_output = self.lstm(lstm_input)[0][-1, :, :]
        # Conv1d needs (batch, channels, length);
        # output is (batch, conv_channels, seq_len - kernel_size + 1)
        conv1_output = self.conv1(nucleotides.permute(0, 2, 1))
        # TODO add max-pooling
        conv1_output = conv1_output.flatten(1)
        return self.fc1(torch.cat([lstm_output, additional_features, conv1_output], 1))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment