import numpy as np
import torch
import torch.nn as nn
from sklearn import metrics
from skorch import NeuralNet
from skorch.callbacks import EarlyStopping, EpochScoring, LRScheduler
from torch.utils.data import DataLoader, Dataset


class AttentiveRecurrentDecoder(nn.Module):
    """Bidirectional RNN encoder followed by multi-head attention pooling."""

    def __init__(self,
                 input_dim,
                 output_dim,
                 d_model,
                 rnn_type="lstm",
                 num_layers=1,
                 rnn_dropout=0.1,
                 num_heads=8,
                 attn_dropout=0.1,
                 attn_bias=True):
        super(AttentiveRecurrentDecoder, self).__init__()
        # Resolve nn.LSTM / nn.GRU / nn.RNN from the rnn_type string.
        RNNClass = getattr(nn, rnn_type.upper())
        self.encoder = RNNClass(
            input_dim,
            d_model,
            num_layers,
            dropout=rnn_dropout,  # only applied between stacked layers (num_layers > 1)
            bidirectional=True,
            batch_first=True
        )
        # The bidirectional encoder doubles the feature dimension.
        self.attention = nn.MultiheadAttention(
            embed_dim=d_model * 2,
            num_heads=num_heads,
            dropout=attn_dropout,
            bias=attn_bias
        )
        self.output_layer = nn.Linear(d_model * 2, output_dim)

    def forward(self, inputs, input_lengths, labels=None, return_attention=False):
        # pack_padded_sequence requires the lengths on the CPU in recent PyTorch.
        packed_inputs = nn.utils.rnn.pack_padded_sequence(
            inputs, input_lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_outputs, hidden = self.encoder(packed_inputs)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs, batch_first=True)
        if isinstance(hidden, tuple):  # LSTM returns (hidden state, cell state)
            hidden = hidden[1]  # take the cell state
        # Concatenate the last layer's forward and backward states.
        hidden = torch.cat([hidden[-1], hidden[-2]], dim=1)
        # nn.MultiheadAttention expects (seq_len, batch, embed_dim): the pooled
        # state is a single query, the encoder outputs serve as keys and values.
        query = hidden.unsqueeze(1).transpose(0, 1).contiguous()
        key = outputs.transpose(0, 1).contiguous()
        linear_combination, energy = self.attention(query, key, key)
        linear_combination = linear_combination.squeeze(0)
        logits = self.output_layer(linear_combination)
        if return_attention:
            return logits, energy
        return logits
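

# A minimal smoke test (not part of the original gist), assuming d_model=64,
# 128-dimensional inputs and 32 labels: it runs random tensors through the
# decoder and checks the expected output shapes.
def _smoke_test_decoder():
    decoder = AttentiveRecurrentDecoder(input_dim=128, output_dim=32, d_model=64)
    inputs = torch.randn(4, 20, 128)         # (batch, time, features)
    lengths = torch.tensor([20, 15, 12, 7])  # true lengths before padding
    logits, attention = decoder(inputs, lengths, return_attention=True)
    assert logits.shape == (4, 32)           # one logit per label
    assert attention.shape == (4, 1, 20)     # one weight per time step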


class SequenceDataset(Dataset):
    """Wraps variable-length feature sequences with their multi-label targets."""

    def __init__(self, features, labels):
        self.features = features
        self.feature_lengths = [len(feats) for feats in self.features]
        self.labels = labels
        self.input_dim = self.features[0].shape[-1]
        self.output_dim = self.labels[0].shape[-1]

    def __getitem__(self, index):
        inputs = torch.tensor(self.features[index])
        input_lengths = torch.tensor(self.feature_lengths[index])
        labels = torch.tensor(self.labels[index])
        return inputs, input_lengths, labels

    def __len__(self):
        return len(self.features)

    @staticmethod
    def data_collator(batch):
        """
        batch is a list of (sequence, length, target) tuples as returned by
        __getitem__. Returns a dict with the zero-padded sequences and their
        original lengths, plus the stacked targets.
        """
        features, lengths, labels = map(list, zip(*batch))
        features = nn.utils.rnn.pad_sequence(features, batch_first=True, padding_value=0.)
        lengths = torch.stack(lengths, 0)
        labels = torch.stack(labels, 0)
        # See https://skorch.readthedocs.io/en/latest/user/neuralnet.html#multiple-input-arguments
        return {"inputs": features.float(), "input_lengths": lengths}, labels


def error_rate(y_true, y_pred):
    """Exact-match error rate: a sample counts as correct only if all labels match."""
    assert y_true is not None
    y_pred = y_pred > 0.  # the network outputs logits: logit > 0 <=> sigmoid > .5
    return 1 - (y_true == y_pred).all(-1).mean()


def accuracy(y_true, y_pred):
    """Element-wise accuracy over all (sample, label) pairs."""
    assert y_true is not None
    y_pred = y_pred > 0.
    return (y_true == y_pred).mean()
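

# A worked example (not part of the original gist): with two samples and three
# labels, one wrong label in the first sample makes that whole row count as an
# error for the exact-match rate, while element-wise accuracy still credits the
# five matching entries.
def _metric_example():
    y_true = np.array([[1, 0, 1], [0, 1, 0]])
    logits = np.array([[2., -3., -1.], [-2., 4., -3.]])  # sample 0 misses label 2
    assert error_rate(y_true, logits) == 0.5  # one of two samples fully correct
    assert accuracy(y_true, logits) == 5 / 6  # 5 of 6 individual labels correct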


def generate_random_data():
    # M samples with lengths drawn between s and S, D features per time step
    # and K candidate labels; each sample gets exactly 3 active labels.
    M, s, S, D, K = 1000, 10, 50, 128, 32
    input_lengths = np.random.randint(s, S, M)
    features = [np.random.randn(l, D) for l in input_lengths]
    target = np.zeros((M, K))
    for i, ks in enumerate([np.random.choice(np.arange(0, K), 3, replace=False) for _ in range(M)]):
        target[i, ks] = 1
    return features, target


def main():
    features, target = generate_random_data()
    dataset = SequenceDataset(features, target)
    net = NeuralNet(
        module=AttentiveRecurrentDecoder,
        module__input_dim=dataset.input_dim,
        module__output_dim=dataset.output_dim,
        module__d_model=64,
        criterion=nn.BCEWithLogitsLoss,
        iterator_train__collate_fn=dataset.data_collator,
        iterator_valid__collate_fn=dataset.data_collator,
        batch_size=64,
        max_epochs=50,
        lr=0.2,
        callbacks=[
            EpochScoring(scoring=metrics.make_scorer(error_rate), lower_is_better=True),
            EpochScoring(scoring=metrics.make_scorer(accuracy), lower_is_better=False),
            EarlyStopping(monitor="valid_loss", patience=5),
            LRScheduler(policy="ReduceLROnPlateau", patience=3)
        ],
        device="cuda"  # assumes a CUDA-capable GPU is available
    )
    net.fit(dataset)
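

# A hedged follow-up (not part of the original gist): skorch's predict returns
# the module's raw outputs, i.e. logits here, so applying a sigmoid recovers
# per-label probabilities consistent with the BCEWithLogitsLoss objective.
def predict_probabilities(net, dataset):
    logits = net.predict(dataset)     # shape (num_samples, output_dim)
    return 1 / (1 + np.exp(-logits))  # element-wise sigmoid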


if __name__ == "__main__":
    main()