language-setting in torchtext
# tested with python 3, spaCy 2.0.16, pytorch 0.4.1, torchtext 0.3.1
import os
from pathlib import Path
import urllib.request
import warnings

# silence harmless numpy binary-compatibility warnings emitted on import
warnings.filterwarnings("ignore", message="numpy.dtype size changed")
warnings.filterwarnings("ignore", message="numpy.ufunc size changed")

import pandas as pd
import spacy
import torch
import torch.nn as nn
import torch.optim as optim
import torchtext
from torchtext import data
IN_FILE = 'germeval2018.training.txt'
IN_FILE_ABS = os.path.abspath(IN_FILE)
IN_FILE_URL = 'https://raw.githubusercontent.com/uds-lsv/GermEval-2018-Data/master/germeval2018.training.txt'
IN_FILE_TEST = 'germeval2018.test.txt'
BATCH_SIZE = 16

# check for train data
in_file = Path(IN_FILE_ABS)
if not in_file.is_file():
    urllib.request.urlretrieve(IN_FILE_URL, IN_FILE_ABS)
# prepare train data; the third column ('detail') is not needed here
df_trn = pd.read_csv(IN_FILE_ABS, sep='\t', header=None,
                     names=['text', 'label', 'detail']).drop('detail', axis=1)
# binary target: the GermEval 2018 coarse labels are 'OFFENSE' and 'OTHER'
df_trn['label'] = df_trn['label'].apply(lambda s: s == 'OFFENSE')
# prepare spaCy (requires the German model: python -m spacy download de)
nlp = spacy.load('de')

def tokenize_fct(text):  # tokenizer function for torchtext
    return [tok.text for tok in nlp.tokenizer(text)]
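# Illustrative only: with the German model loaded, the spaCy tokenizer splits
# punctuation into separate tokens, e.g.
#   tokenize_fct('Das ist ein Test.')  # -> ['Das', 'ist', 'ein', 'Test', '.']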
# define Fields in torchtext
text_field = data.Field(sequential=True, use_vocab=True,
                        lower=True, tokenize=tokenize_fct)
label_field = data.LabelField(dtype=torch.float)
fields = [('text', text_field), ('label', label_field)]
class DataframeDataset(torchtext.data.Dataset):
    """Wraps a pandas DataFrame as a torchtext Dataset."""

    def __init__(self, df, fields, **kwargs):
        examples = []
        for _, row in df.iterrows():
            values = [row[f[0]] for f in fields]
            examples.append(torchtext.data.Example.fromlist(values, fields))
        super().__init__(examples, fields, **kwargs)

    @staticmethod
    def sort_key(ex):
        return len(ex.text)

    @classmethod
    def splits(cls, fields, df_trn, df_val=None, df_test=None, **kwargs):
        trn_data, val_data, tst_data = (None, None, None)
        if df_trn is not None:
            trn_data = cls(df_trn.copy(), fields, **kwargs)
        if df_val is not None:
            val_data = cls(df_val.copy(), fields, **kwargs)
        if df_test is not None:
            tst_data = cls(df_test.copy(), fields, **kwargs)
        result = tuple(d for d in (trn_data, val_data, tst_data) if d is not None)
        # like torchtext's own splits(), return a single dataset instead of a 1-tuple
        return result if len(result) > 1 else result[0]
full_ds = DataframeDataset.splits(fields, df_trn)
trn_ds, val_ds = full_ds.split(split_ratio=[0.8, 0.2], stratified=True)

trn_dl, val_dl = data.BucketIterator.splits(
    [trn_ds, val_ds], batch_size=BATCH_SIZE,
    sort_key=lambda t: len(t.text), sort_within_batch=False, repeat=False)
# pre-trained German twitter embeddings; download manually from
# https://www.spinningbytes.com/resources/wordembeddings/
# (cache must point to the directory the vector file was saved in)
vec = torchtext.vocab.Vectors('embed_tweets_de_100D_fasttext',
                              cache='/Users/michel/Downloads/')
# validation + test data should by no means influence the model, so build the vocab on trn only
text_field.build_vocab(trn_ds, vectors=vec)
# text_field.build_vocab(trn_ds, max_size=20000)  # alternative without pre-trained vectors
label_field.build_vocab(trn_ds)

print(f'text vocab size {len(text_field.vocab)}')
print(f'label vocab size {len(label_field.vocab)}')
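# Sanity checks (illustrative; exact values depend on the random split):
#   text_field.vocab.itos[:2]  # -> ['<unk>', '<pad>'] - specials come first
#   b = next(iter(trn_dl))
#   b.text.shape               # -> [<longest tweet in batch>, BATCH_SIZE]
#   b.label.shape              # -> [BATCH_SIZE]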
class SimpleRNN(nn.Module):
    def __init__(self, vocab_dim, emb_dim=100, hidden_dim=200):
        super().__init__()
        self.embedding = nn.Embedding(vocab_dim, emb_dim)
        self.rnn = nn.RNN(emb_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, 1)  # 1 is the output dim
    def forward(self, x):
        # x: Tensor[sentence len, batch size] of token ids; nn.Embedding looks the
        # ids up directly, so no one-hot encoding is ever materialized
        embedded = self.embedding(x)
        # embedded: Tensor[sentence len, batch size, emb dim]
        output, hidden_state = self.rnn(embedded)
        # output: Tensor[sentence len, batch size, hidden dim]
        # hidden_state: Tensor[1, batch size, hidden dim] - the last time step only
        return self.fc(hidden_state.squeeze(0))
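# Shape sanity check (hypothetical vocab of 100 ids, not used for training):
#   demo = SimpleRNN(100)
#   out = demo(torch.randint(0, 100, (7, 4), dtype=torch.long))  # 7 tokens, batch of 4
#   out.shape  # -> torch.Size([4, 1]) - one logit per example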
def binary_accuracy(preds, y):
    """Return accuracy per batch as the ratio of correct predictions."""
    # preds are raw logits; squash with sigmoid, then round to the closest integer
    rounded_preds = torch.round(torch.sigmoid(preds))
    # convert into float for division
    pred_is_correct = (rounded_preds == y).float()
    acc = pred_is_correct.sum() / len(pred_is_correct)
    return acc
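# Quick check with made-up logits: sigmoid(2.0) rounds to 1, sigmoid(-2.0) to 0,
# so against labels [1., 1.] exactly one of two predictions is correct:
#   binary_accuracy(torch.tensor([2.0, -2.0]), torch.tensor([1.0, 1.0]))  # -> tensor(0.5000)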
def train(model, iterator, optimizer, criterion, metric):
    epoch_loss = 0
    epoch_meter = 0
    model.train()
    for batch in iterator:
        optimizer.zero_grad()
        y_hat = model(batch.text).squeeze(1)
        loss = criterion(y_hat, batch.label)
        meter = metric(y_hat, batch.label)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_meter += meter.item()
    return epoch_loss / len(iterator), epoch_meter / len(iterator)

def evaluate(model, iterator, criterion, metric):
    epoch_loss = 0
    epoch_meter = 0
    model.eval()
    with torch.no_grad():
        for batch in iterator:
            y_hat = model(batch.text).squeeze(1)
            loss = criterion(y_hat, batch.label)
            meter = metric(y_hat, batch.label)
            epoch_loss += loss.item()
            epoch_meter += meter.item()
    return epoch_loss / len(iterator), epoch_meter / len(iterator)
EMB_SIZE = 100  # must match the 100D pre-trained vectors
HID_SIZE = 200
NUM_LIN = 3     # unused in this simple RNN variant
NUM_EPOCH = 5
# RNN variant SETUP
model = SimpleRNN(len(text_field.vocab), EMB_SIZE, HID_SIZE)
optimizer = optim.SGD(model.parameters(), lr=1e-3)
# BCEWithLogitsLoss applies the sigmoid itself, so the model outputs raw logits
criterion = nn.BCEWithLogitsLoss()
# initialize embeddings with the loaded vectors
model.embedding.weight.data.copy_(text_field.vocab.vectors)

# TRAINING
for epoch in range(NUM_EPOCH):
    train_loss, train_acc = train(
        model, trn_dl, optimizer, criterion, binary_accuracy)
    print(f'EPOCH: {epoch:02} - TRN_LOSS: {train_loss:.3f} - TRN_ACC: {train_acc*100:.2f}%')
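# The gist defines evaluate() but never calls it; a minimal held-out check could
# look like this (a sketch using the val_dl built above, not part of the original run):
val_loss, val_acc = evaluate(model, val_dl, criterion, binary_accuracy)
print(f'VAL_LOSS: {val_loss:.3f} - VAL_ACC: {val_acc*100:.2f}%')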