Skip to content

Instantly share code, notes, and snippets.

@VictorSanh
Created August 14, 2018 15:36
Show Gist options
  • Save VictorSanh/c26a6e164b79be17cce452734c4c9d64 to your computer and use it in GitHub Desktop.
RE model - Reimplementation from G. Bekoulis
# coding: utf-8
import logging
import math
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter, Variable
from overrides import overrides
from allennlp.common import Params
from allennlp.data import Vocabulary
from allennlp.models.model import Model
from allennlp.modules.token_embedders import Embedding
from allennlp.modules import FeedForward
from allennlp.modules import Seq2SeqEncoder, TimeDistributed, TextFieldEmbedder, SpanPruner
from allennlp.modules.span_extractors import SelfAttentiveSpanExtractor, EndpointSpanExtractor
from allennlp.nn import util, InitializerApplicator, RegularizerApplicator
from allennlp.training.metrics import MentionRecall, ConllCorefScores
from jointmodelmd.training.metrics import RelationF1Measure
# Module-level logger named after this module.
logger = logging.getLogger(__name__) # pylint: disable=invalid-name

# Closed set of ACE relation types, listed in class-index order.
_ACE_RELATION_TYPES = ("ORG-AFF", "PHYS", "ART", "PER-SOC", "PART-WHOLE", "GEN-AFF")
# Relation type name <-> class index mappings, shared by the scorer
# (encoding gold labels) and the decoder (pretty-printing predictions).
rel_type_2_idx = {name: index for index, name in enumerate(_ACE_RELATION_TYPES)}
idx_2_rel_type = dict(enumerate(_ACE_RELATION_TYPES))
@Model.register("relation_extraction")
class RelationExtractor(Model):
    """
    A class containing the scoring model for relation extraction.

    It is derived from the model proposed by Bekoulis G. in
    "Joint entity recognition and relation extraction as a multi-head selection problem"
    https://bekou.github.io/papers/eswa2018b/bekoulis_eswa_2018b.pdf

    For every ordered pair of words (i, j) and every relation type k, an
    independent sigmoid decides whether relation k holds with ARG1 == i
    and ARG2 == j ("multi-head selection").
    """

    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 context_layer: Seq2SeqEncoder,
                 d: int,
                 l: int,
                 n_classes: int,
                 activation: str = "relu",
                 label_namespace: str = "relation_ace_labels") -> None:
        """
        Parameters
        ----------
        vocab : ``Vocabulary``
            Vocabulary holding, among others, the relation label namespace.
        text_field_embedder : ``TextFieldEmbedder``
            Maps token indexes to their embedding vectors.
        context_layer : ``Seq2SeqEncoder``
            Contextualizes the embeddings (usually an LSTM). Its output size
            is expected to be 2*d (the weight matrices are 2*d x l) — TODO
            confirm against the training configuration.
        d : ``int``
            Half of the context-layer output size (see above).
        l : ``int``
            Hidden size of the relation scorer.
        n_classes : ``int``
            Number of relation types to predict.
        activation : ``str``, optional (default = "relu")
            Non-linearity of the scorer: "relu" or "tanh".
        label_namespace : ``str``, optional
            Namespace of the "ARG{1,2}_<REL-TYPE>" labels in ``vocab``.
        """
        super(RelationExtractor, self).__init__(vocab)
        # Fail fast on a misconfigured activation instead of crashing with
        # a NameError in the middle of the first forward pass.
        if activation not in ("relu", "tanh"):
            raise ValueError("Unknown activation '%s': expected 'relu' or 'tanh'." % activation)
        # Scorer weights: score(i, j) = f(U.h_i + W.h_j + b) . V
        self._U = Parameter(torch.Tensor(2 * d, l))
        self._W = Parameter(torch.Tensor(2 * d, l))
        self._V = Parameter(torch.Tensor(l, n_classes))
        self._b = Parameter(torch.Tensor(l))
        self._init_weights()
        self._n_classes = n_classes
        self._activation = activation
        self._text_field_embedder = text_field_embedder
        self._context_layer = context_layer
        self._label_namespace = label_namespace
        self._relation_metric = RelationF1Measure()
        # Each (i, j, class) cell is an independent binary decision,
        # hence sigmoid scores + binary cross-entropy on logits.
        self._loss_fn = nn.BCEWithLogitsLoss()

    def _init_weights(self):
        """
        Initialization for the weights of the model: Xavier-uniform for the
        matrices, normal for the bias vector.
        """
        nn.init.xavier_uniform(self._U)
        nn.init.xavier_uniform(self._W)
        nn.init.xavier_uniform(self._V)
        nn.init.normal(self._b)

    def _multi_class_cross_entropy_loss(self, scores, labels, mask):
        """
        Compute the masked multi-label binary cross-entropy loss.

        Parameters
        ----------
        scores : ``torch.Tensor``
            Raw (pre-sigmoid) relation scores.
            Size: n_batches x padded_document_length x padded_document_length x n_classes.
        labels : ``torch.Tensor``
            Binary gold relation indicators, same size as ``scores``.
        mask : ``torch.Tensor``
            Sentence-level padding mask. Size: n_batches x padded_document_length.
        """
        labels = Variable(labels, requires_grad=False)
        # Expand the sentence-level mask to the (word, word) pair level so
        # that padded positions contribute nothing to the loss:
        # squared_mask[b, i, j, k] == mask[b, i] * mask[b, j].
        padded_document_length = mask.size(1)
        mask = mask.float()  # Size: n_batches x padded_document_length
        squared_mask = torch.stack([e.view(padded_document_length, 1) * e for e in mask], dim=0)
        squared_mask = squared_mask.unsqueeze(-1).repeat(1, 1, 1, self._n_classes)
        # Size: n_batches x padded_document_length x padded_document_length x n_classes

        # Zero out padded cells, then flatten the last two dimensions before
        # feeding the binary cross-entropy loss.
        flat_size = scores.size()
        scores = scores * squared_mask
        scores_flat = scores.view(flat_size[0], flat_size[1], flat_size[2] * self._n_classes)
        labels = labels * squared_mask
        labels_flat = labels.view(flat_size[0], flat_size[1], flat_size[2] * self._n_classes)

        loss = self._loss_fn(scores_flat, labels_flat)
        # Amplify the loss to actually see the figures move; this only
        # rescales the gradients by a constant factor.
        return 100 * loss

    @overrides
    def forward(self,  # type: ignore
                text: Dict[str, torch.LongTensor],
                relations: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Forward pass of the model.
        Compute the predictions and the loss (if gold labels are available).

        Parameters
        ----------
        text : Dict[str, torch.LongTensor]
            The input sentences which have been transformed into indexes
            (integers) according to a mapping token:str -> token:int.
        relations : torch.IntTensor, optional
            The gold relations to predict.

        Returns
        -------
        Dict with the sigmoid scores, the rounded (0/1) predictions, the
        padding mask and, when ``relations`` is given, the loss.
        """
        # Map every token index to its word embedding representation.
        text_embeddings = self._text_field_embedder(text)
        # 1 where there is an actual word, 0 where the sentence was padded.
        mask = util.get_text_field_mask(text)  # Size: batch_size x padded_document_length
        # Contextualize the word embeddings; _context_layer is usually a
        # Seq2Seq model such as an LSTM.
        encoded_text = self._context_layer(text_embeddings, mask)
        # Size: batch_size x padded_document_length x lstm_output_size

        ###### Relation Scorer ##############
        # Outer-sum trick: scores[b, i, j] = f(U.h_i + W.h_j + b) . V
        left = torch.matmul(encoded_text, self._U)   # Size: batch_size x padded_document_length x l
        right = torch.matmul(encoded_text, self._W)  # Size: batch_size x padded_document_length x l
        left = left.permute(1, 0, 2)
        left = left.unsqueeze(3)
        right = right.permute(0, 2, 1)
        right = right.unsqueeze(0)
        B = left + right
        B = B.permute(1, 0, 3, 2)  # Size: batch_size x padded_document_length x padded_document_length x l
        outer_sum_bias = B + self._b
        if self._activation == "relu":
            activated_outer_sum_bias = F.relu(outer_sum_bias)
        elif self._activation == "tanh":
            activated_outer_sum_bias = F.tanh(outer_sum_bias)
        else:
            # Defensive: __init__ validates, but fail loudly rather than
            # with a NameError if the attribute was mutated afterwards.
            raise ValueError("Unknown activation '%s': expected 'relu' or 'tanh'." % self._activation)
        relation_scores = torch.matmul(activated_outer_sum_bias, self._V)
        # Size: batch_size x padded_document_length x padded_document_length x n_classes
        #################################################################

        batch_size, padded_document_length = mask.size()
        relation_sigmoid_scores = F.sigmoid(relation_scores)
        # predicted_relations[l, i, j, k] == 1 iff we predict a relation k
        # with ARG1 == i, ARG2 == j in the l-th sentence of the batch.
        predicted_relations = torch.round(relation_sigmoid_scores)

        output_dict = {
            "relation_sigmoid_scores": relation_sigmoid_scores,
            "predicted_relations": predicted_relations,
            "mask": mask
        }

        if relations is not None:
            # Reformat the gold relations into the same one-hot layout as
            # the predictions before computing the loss:
            # gold_relations[l, i, j, k] == 1 iff there is a gold relation k
            # with ARG1 == i, ARG2 == j in the l-th sentence of the batch.
            gold_relations = torch.zeros(batch_size, padded_document_length, padded_document_length, self._n_classes)
            for exple_idx, exple_tags in enumerate(relations):  # going through the batch
                # exple_tags holds the relations of the current sentence;
                # each sublist is of size padded_document_length and encodes
                # one relation: its two non-zero elements indicate the two
                # word arguments AND the relation type between them.
                for rel in exple_tags:
                    # Relations have been padded; padded relations are simply
                    # rows of size padded_document_length filled with 0.
                    if rel.sum().data[0] == 0:
                        continue
                    # Collect BOTH argument positions before marking the gold
                    # cell. (The original wrote inside the index loop, which
                    # read an unbound — or stale, from a previous relation —
                    # coordinate when only one argument had been seen yet.)
                    x = y = rel_type = None
                    for idx in rel.nonzero():
                        label_srt = self.vocab.get_token_from_index(rel[idx].data[0], self._label_namespace)
                        arg, rel_type = label_srt.split("_")
                        if arg == "ARG1":
                            x = idx.data[0]
                        else:
                            y = idx.data[0]
                    if x is not None and y is not None:
                        gold_relations[exple_idx, x, y, rel_type_2_idx[rel_type]] = 1

            # GPU support
            if text_embeddings.is_cuda:
                gold_relations = gold_relations.cuda()

            # Compute the loss
            output_dict["loss"] = self._multi_class_cross_entropy_loss(
                scores=relation_scores, labels=gold_relations, mask=mask)

            # Compute the metrics with the predictions.
            self._relation_metric(predictions=predicted_relations, gold_labels=gold_relations, mask=mask)

        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]):
        """
        Decode the 0/1 prediction tensor back into the "ARG{1,2}_<REL-TYPE>"
        string encoding: one tag list per predicted relation, with "*" for
        the words that take no part in the relation.
        """
        decoded_predictions = []
        for instance_tags in output_dict["predicted_relations"]:
            sentence_length = instance_tags.size(0)
            decoded_relations = []
            for arg1, arg2, rel_type_idx in instance_tags.nonzero().data:
                relation = ["*"] * sentence_length
                rel_type = idx_2_rel_type[rel_type_idx]
                relation[arg1] = "ARG1_" + rel_type
                relation[arg2] = "ARG2_" + rel_type
                decoded_relations.append(relation)
            decoded_predictions.append(decoded_relations)
        output_dict["decoded_predictions"] = decoded_predictions
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Compute the metrics for relation: precision, recall and f1.
        A relation is considered correct if we can correctly predict the last
        word of ARG1, the last word of ARG2 and the relation type.
        """
        metric_dict = self._relation_metric.get_metric(reset=reset)
        # Only expose the aggregate ("overall") figures.
        return {x: y for x, y in metric_dict.items() if "overall" in x}
from typing import Dict, List, Optional, Set
from collections import defaultdict
import torch
from allennlp.common.checks import ConfigurationError
from allennlp.nn.util import get_lengths_from_binary_sequence_mask, ones_like
from allennlp.data.vocabulary import Vocabulary
from allennlp.training.metrics.metric import Metric
@Metric.register("relation_f1")
class RelationF1Measure(Metric):
    """
    A class for computing the metrics specific to relation extraction:
    precision, recall and F1.

    A relation is counted as correct when the (ARG1 position, ARG2 position,
    relation type) triplet exactly matches a gold triplet.
    """

    def __init__(self) -> None:
        # Running counts, accumulated across calls until ``reset``.
        self._true_positives: int = 0
        self._false_positives: int = 0
        self._false_negatives: int = 0

    def __call__(self,
                 predictions: torch.Tensor,
                 gold_labels: torch.Tensor,
                 mask: Optional[torch.Tensor] = None):
        """
        Update the TP, FP and FN counters.

        Parameters
        ----------
        predictions : ``torch.Tensor``, required.
            Binary (0/1) relation predictions of shape
            (batch_size, padded_document_length, padded_document_length, n_classes).
        gold_labels : ``torch.Tensor``, required.
            Binary gold indicators, same shape as ``predictions``.
        mask : ``torch.Tensor``, optional (default = None).
            Sentence-level padding mask of shape
            (batch_size, padded_document_length).
        """
        if mask is None:
            mask = ones_like(gold_labels)
        # Get the data out of the (possibly autograd-tracked) Variables.
        predictions, gold_labels, mask = self.unwrap_to_tensors(predictions,
                                                                gold_labels,
                                                                mask)
        if gold_labels.size() != predictions.size():
            raise ConfigurationError("Predictions and gold labels don't have the same size.")

        # Expand the sentence-level mask to the (word, word, class) level:
        # squared_mask[b, i, j, k] == mask[b, i] * mask[b, j].
        _, padded_document_length, _, n_classes = predictions.size()
        mask = mask.float()
        squared_mask = torch.stack([e.view(padded_document_length, 1) * e for e in mask], dim=0)
        squared_mask = squared_mask.unsqueeze(-1).repeat(1, 1, 1, n_classes)
        # Size: n_batches x padded_document_length x padded_document_length x n_classes

        gold_labels = gold_labels.cpu()
        # Zero out every padded cell so it cannot count as a (wrong) match.
        predictions = predictions * squared_mask
        gold_labels = gold_labels * squared_mask

        # Per instance, compare the SETS of non-zero flat indexes: each index
        # identifies one (ARG1, ARG2, class) triplet. Set intersection and
        # differences count exactly the same triplets as the original nested
        # "x in array" scans, but in linear rather than quadratic time.
        batch_size = gold_labels.size(0)
        for i in range(batch_size):
            predicted_idx = {int(idx) for idx in predictions[i].view(-1).nonzero().view(-1)}
            gold_idx = {int(idx) for idx in gold_labels[i].view(-1).nonzero().view(-1)}
            self._true_positives += len(predicted_idx & gold_idx)
            self._false_positives += len(predicted_idx - gold_idx)
            self._false_negatives += len(gold_idx - predicted_idx)

    def get_metric(self, reset: bool = False):
        """
        Get the metrics and reset the counters if necessary.
        """
        all_metrics = {}
        # Compute the precision, recall and f1 for all spans jointly.
        precision, recall, f1_measure = self._compute_metrics(self._true_positives,
                                                              self._false_positives,
                                                              self._false_negatives)
        all_metrics["precision-overall"] = precision
        all_metrics["recall-overall"] = recall
        all_metrics["f1-measure-overall"] = f1_measure
        if reset:
            self.reset()
        return all_metrics

    @staticmethod
    def _compute_metrics(true_positives: int, false_positives: int, false_negatives: int):
        """
        Turn raw TP/FP/FN counts into precision, recall and F1.
        The 1e-13 terms guard against division by zero when a count is 0.
        """
        precision = float(true_positives) / float(true_positives + false_positives + 1e-13)
        recall = float(true_positives) / float(true_positives + false_negatives + 1e-13)
        f1_measure = 2. * ((precision * recall) / (precision + recall + 1e-13))
        return precision, recall, f1_measure

    def reset(self):
        """
        Zero all the running counters.
        """
        self._true_positives = 0
        self._false_positives = 0
        self._false_negatives = 0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment