RE model - Reimplementation from G. Bekoulis
# coding: utf-8
import logging
from typing import Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.parameter import Parameter

from overrides import overrides

from allennlp.common.checks import ConfigurationError
from allennlp.data import Vocabulary
from allennlp.models.model import Model
from allennlp.modules import Seq2SeqEncoder, TextFieldEmbedder
from allennlp.nn import util
from jointmodelmd.training.metrics import RelationF1Measure

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
rel_type_2_idx = {
    "ORG-AFF": 0,
    "PHYS": 1,
    "ART": 2,
    "PER-SOC": 3,
    "PART-WHOLE": 4,
    "GEN-AFF": 5,
}
idx_2_rel_type = {value: key for key, value in rel_type_2_idx.items()}
@Model.register("relation_extraction")
class RelationExtractor(Model):
    """
    A class containing the scoring model for relation extraction.

    It is derived from the model proposed by G. Bekoulis et al. in
    "Joint entity recognition and relation extraction as a multi-head selection problem"
    https://bekou.github.io/papers/eswa2018b/bekoulis_eswa_2018b.pdf
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 context_layer: Seq2SeqEncoder,
                 d: int,
                 l: int,
                 n_classes: int,
                 activation: str = "relu",
                 label_namespace: str = "relation_ace_labels") -> None:
        super(RelationExtractor, self).__init__(vocab)
        # d: hidden size of one direction of the encoder (its output is 2*d).
        # l: size of the intermediate layer of the scorer.
        self._U = Parameter(torch.Tensor(2 * d, l))
        self._W = Parameter(torch.Tensor(2 * d, l))
        self._V = Parameter(torch.Tensor(l, n_classes))
        self._b = Parameter(torch.Tensor(l))
        self._init_weights()
        self._n_classes = n_classes
        self._activation = activation
        self._text_field_embedder = text_field_embedder
        self._context_layer = context_layer
        self._label_namespace = label_namespace
        self._relation_metric = RelationF1Measure()
        self._loss_fn = nn.BCEWithLogitsLoss()
    def _init_weights(self):
        """
        Initialization for the weights of the model.
        """
        nn.init.xavier_uniform(self._U)
        nn.init.xavier_uniform(self._W)
        nn.init.xavier_uniform(self._V)
        # nn.init.normal(self._U)
        # nn.init.normal(self._W)
        # nn.init.normal(self._V)
        nn.init.normal(self._b)
    def _multi_class_cross_entropy_loss(self, scores, labels, mask):
        """
        Compute the binary cross-entropy loss between the relation scores and
        the gold relation labels, ignoring padded positions.
        """
        labels = Variable(labels, requires_grad=False)
        # Compute the mask before computing the loss.
        # Transform the mask that is at the sentence level
        # (Size: n_batches x padded_document_length)
        # to a suitable format for the relation labels level.
        padded_document_length = mask.size(1)
        mask = mask.float()  # Size: n_batches x padded_document_length
        # squared_mask[b, i, j] == 1 iff both tokens i and j are real (non-padding) tokens.
        squared_mask = torch.stack([e.view(padded_document_length, 1) * e for e in mask], dim=0)
        squared_mask = squared_mask.unsqueeze(-1).repeat(1, 1, 1, self._n_classes)  # Size: n_batches x padded_document_length x padded_document_length x n_classes
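        # Worked example: for a 3-token sentence where the last position is
        # padding, mask row [1, 1, 0] becomes the outer product
        #     [[1, 1, 0],
        #      [1, 1, 0],
        #      [0, 0, 0]]
        # so any relation score involving the padded token is zeroed out.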
        # The scores (and gold labels) are flattened before applying
        # the binary cross-entropy loss.
        flat_size = scores.size()
        scores = scores * squared_mask  # Size: n_batches x padded_document_length x padded_document_length x n_classes
        scores_flat = scores.view(flat_size[0], flat_size[1], flat_size[2] * self._n_classes)  # Size: n_batches x padded_document_length x (padded_document_length x n_classes)
        labels = labels * squared_mask  # Size: n_batches x padded_document_length x padded_document_length x n_classes
        labels_flat = labels.view(flat_size[0], flat_size[1], flat_size[2] * self._n_classes)  # Size: n_batches x padded_document_length x (padded_document_length x n_classes)
        loss = self._loss_fn(scores_flat, labels_flat)
        # Amplify the loss to actually see the figures move...
        return 100 * loss
    @overrides
    def forward(self,  # type: ignore
                text: Dict[str, torch.LongTensor],
                relations: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Forward pass of the model.
        Compute the predictions and the loss (if labels are available).

        Parameters
        ----------
        text: Dict[str, torch.LongTensor]
            The input sentences, transformed into indexes (integers)
            according to a mapping token:str -> token:int.
        relations: torch.IntTensor
            The gold relations to predict.
        """
        # The text field embedder maps each token:int to its word embedding
        # representation token:embedding (whatever these embeddings are).
        text_embeddings = self._text_field_embedder(text)
        # Compute the mask from the text: 1 if there is actually a word at the
        # corresponding position, 0 if it has been padded.
        mask = util.get_text_field_mask(text)  # Size: batch_size x padded_document_length
        # Compute the contextualized representation from the word embeddings.
        # Usually, _context_layer is a Seq2Seq model such as an LSTM.
        encoded_text = self._context_layer(text_embeddings, mask)  # Size: batch_size x padded_document_length x lstm_output_size
        ###### Relation Scorer ##############
        # Compute the relation scores.
        left = torch.matmul(encoded_text, self._U)  # Size: batch_size x padded_document_length x l
        right = torch.matmul(encoded_text, self._W)  # Size: batch_size x padded_document_length x l
        left = left.permute(1, 0, 2)
        left = left.unsqueeze(3)
        right = right.permute(0, 2, 1)
        right = right.unsqueeze(0)
        B = left + right
        B = B.permute(1, 0, 3, 2)  # Size: batch_size x padded_document_length x padded_document_length x l
        outer_sum_bias = B + self._b  # Size: batch_size x padded_document_length x padded_document_length x l
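        # Shape walkthrough of the broadcast above (batch_size=B_, length=N):
        #   left:  (B_, N, l) -> permute -> (N, B_, l) -> unsqueeze -> (N, B_, l, 1)
        #   right: (B_, N, l) -> permute -> (B_, l, N) -> unsqueeze -> (1, B_, l, N)
        #   left + right broadcasts to (N, B_, l, N), where entry [i, b, :, j]
        #   holds U.h_i + W.h_j; the final permute lays it out as (B_, N, N, l).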
        if self._activation == "relu":
            activated_outer_sum_bias = F.relu(outer_sum_bias)
        elif self._activation == "tanh":
            activated_outer_sum_bias = F.tanh(outer_sum_bias)
        else:
            raise ConfigurationError("Unknown activation '%s' (expected 'relu' or 'tanh')." % self._activation)
        relation_scores = torch.matmul(activated_outer_sum_bias, self._V)  # Size: batch_size x padded_document_length x padded_document_length x n_classes
        #################################################################
        batch_size, padded_document_length = mask.size()
        relation_sigmoid_scores = F.sigmoid(relation_scores)  # Size: batch_size x padded_document_length x padded_document_length x n_classes
        # predicted_relations[l, i, j, k] == 1 iff we predict a relation of type k
        # with ARG1 == i and ARG2 == j in the l-th sentence of the batch.
        predicted_relations = torch.round(relation_sigmoid_scores)  # Size: batch_size x padded_document_length x padded_document_length x n_classes
        output_dict = {
            "relation_sigmoid_scores": relation_sigmoid_scores,
            "predicted_relations": predicted_relations,
            "mask": mask,
        }
        if relations is not None:
            # Reformat the gold relations before computing the loss.
            # Size: batch_size x padded_document_length x padded_document_length x n_classes
            # gold_relations[l, i, j, k] == 1 iff the gold annotations contain a relation
            # of type k with ARG1 == i and ARG2 == j in the l-th sentence of the batch.
            gold_relations = torch.zeros(batch_size, padded_document_length, padded_document_length, self._n_classes)
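            # Worked example (hypothetical label indices): a PHYS relation
            # between token 0 (ARG1) and token 3 (ARG2) of a 5-token sentence
            # arrives as a padded row such as
            #     rel = [idx("ARG1_PHYS"), 0, 0, idx("ARG2_PHYS"), 0]
            # and ends up as gold_relations[l, 0, 3, rel_type_2_idx["PHYS"]] = 1.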
            for exple_idx, exple_tags in enumerate(relations):  # going through the batch
                # exple_tags encodes the relations of the current sentence in
                # the batch as a list of sublists. Each sublist (rel) is of
                # size padded_document_length and encodes one relation: its
                # two non-zero elements indicate the two argument words AND
                # the relation type between these two words.
                for rel in exple_tags:
                    # Relations have been padded, so each sentence in the batch
                    # has max_nb_of_relations_in_batch_for_one_sentence relations
                    # (i.e. that many sublists such as rel). The padded relations
                    # are simply lists of size padded_document_length filled with 0.
                    if rel.sum().data[0] == 0:
                        continue
                    for idx in rel.nonzero():
                        label_srt = self.vocab.get_token_from_index(rel[idx].data[0], self._label_namespace)
                        arg, rel_type = label_srt.split("_")
                        if arg == "ARG1":
                            x = idx.data[0]
                        else:
                            y = idx.data[0]
                    gold_relations[exple_idx, x, y, rel_type_2_idx[rel_type]] = 1

            # GPU support.
            if text_embeddings.is_cuda:
                gold_relations = gold_relations.cuda()

            # Compute the loss.
            output_dict["loss"] = self._multi_class_cross_entropy_loss(
                scores=relation_scores, labels=gold_relations, mask=mask)

            # Compute the metrics with the predictions.
            self._relation_metric(predictions=predicted_relations, gold_labels=gold_relations, mask=mask)

        return output_dict
    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]):
        """
        Decode the predictions into the padded-row format used by the gold
        relations: one list per predicted relation, with "ARG1_<type>" and
        "ARG2_<type>" at the argument positions and "*" everywhere else.
        """
        decoded_predictions = []
        for instance_tags in output_dict["predicted_relations"]:
            sentence_length = instance_tags.size(0)
            decoded_relations = []
            for arg1, arg2, rel_type_idx in instance_tags.nonzero().data:
                relation = ["*"] * sentence_length
                rel_type = idx_2_rel_type[rel_type_idx]
                relation[arg1] = "ARG1_" + rel_type
                relation[arg2] = "ARG2_" + rel_type
                decoded_relations.append(relation)
            decoded_predictions.append(decoded_relations)
        output_dict["decoded_predictions"] = decoded_predictions
        return output_dict
    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Compute the metrics for relation extraction: precision, recall and F1.
        A relation is considered correct if we correctly predict the last word
        of ARG1, the last word of ARG2 and the relation type.
        """
        metric_dict = self._relation_metric.get_metric(reset=reset)
        return {x: y for x, y in metric_dict.items() if "overall" in x}
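
# ---------------------------------------------------------------------------
# Minimal self-contained sketch (not part of the original model): the pairwise
# scoring trick used in forward(), replayed in plain PyTorch with made-up sizes.
if __name__ == "__main__":
    _B, _N, _H, _l, _C = 2, 4, 6, 5, 3                           # batch, length, 2*d, l, n_classes
    _h = torch.randn(_B, _N, _H)                                 # stand-in for the encoder output
    _U, _W = torch.randn(_H, _l), torch.randn(_H, _l)
    _V, _bias = torch.randn(_l, _C), torch.randn(_l)
    _left = torch.matmul(_h, _U).permute(1, 0, 2).unsqueeze(3)   # (N, B, l, 1)
    _right = torch.matmul(_h, _W).permute(0, 2, 1).unsqueeze(0)  # (1, B, l, N)
    _scores = torch.matmul(F.relu((_left + _right).permute(1, 0, 3, 2) + _bias), _V)
    # _scores[b, i, j] holds V . relu(U.h_i + W.h_j + bias) for every pair (i, j).
    print(_scores.size())  # torch.Size([2, 4, 4, 3])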
from typing import Optional

import torch

from allennlp.common.checks import ConfigurationError
from allennlp.nn.util import ones_like
from allennlp.training.metrics.metric import Metric
@Metric.register("relation_f1")
class RelationF1Measure(Metric):
    """
    A class for computing the metrics specific to relation extraction.
    We consider a relation correct if we correctly predict the last word of
    the head of each of its two arguments and the relation type.
    """
    def __init__(self) -> None:
        # self._label_vocabulary = vocabulary.get_index_to_token_vocabulary(tag_namespace)
        self._true_positives: int = 0
        self._false_positives: int = 0
        self._false_negatives: int = 0
    def __call__(self,
                 predictions: torch.Tensor,
                 gold_labels: torch.Tensor,
                 mask: Optional[torch.Tensor] = None):
        """
        Update the TP, FP and FN counters.

        Parameters
        ----------
        predictions : ``torch.Tensor``, required.
            A tensor of 0/1 predictions of shape
            (batch_size, sequence_length, sequence_length, num_classes).
        gold_labels : ``torch.Tensor``, required.
            A tensor of 0/1 gold labels. It must have the same shape as the
            ``predictions`` tensor.
        mask: ``torch.Tensor``, optional (default = None).
            A sentence-level masking tensor of shape (batch_size, sequence_length).
        """
        if mask is None:
            # Default to an all-ones sentence-level mask (no padding).
            mask = ones_like(gold_labels[:, :, 0, 0])
        # Get the data from the Variables.
        predictions, gold_labels, mask = self.unwrap_to_tensors(predictions,
                                                                gold_labels,
                                                                mask)
        if gold_labels.size() != predictions.size():
            raise ConfigurationError("Predictions and gold labels don't have the same size.")
        # Apply the mask.
        # Compute the mask before computing the metric:
        # transform the mask that is at the sentence level
        # (Size: n_batches x padded_document_length)
        # to a suitable format for the relation labels level.
        _, padded_document_length, _, n_classes = predictions.size()
        mask = mask.float()
        # squared_mask[b, i, j] == 1 iff both tokens i and j are real (non-padding) tokens.
        squared_mask = torch.stack([e.view(padded_document_length, 1) * e for e in mask], dim=0)
        squared_mask = squared_mask.unsqueeze(-1).repeat(1, 1, 1, n_classes)  # Size: n_batches x padded_document_length x padded_document_length x n_classes
        gold_labels = gold_labels.cpu()
        predictions = predictions * squared_mask  # Size: n_batches x padded_document_length x padded_document_length x n_classes
        gold_labels = gold_labels * squared_mask  # Size: n_batches x padded_document_length x padded_document_length x n_classes
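        # Every relation instance is identified by its flat index in the
        # (padded_document_length x padded_document_length x n_classes) view:
        # with N = padded_document_length and C = n_classes, the triple
        # (i, j, k) maps to i*N*C + j*C + k, so comparing nonzero flat indices
        # compares (ARG1, ARG2, type) triples directly.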
        # Iterate over the sentences in the batch.
        batch_size = gold_labels.size(0)
        for i in range(batch_size):
            flattened_predictions = predictions[i].view(-1).nonzero().cpu().numpy()
            flattened_gold_labels = gold_labels[i].view(-1).nonzero().cpu().numpy()
            for prediction in flattened_predictions:
                if prediction in flattened_gold_labels:
                    self._true_positives += 1
                else:
                    self._false_positives += 1
            for gold in flattened_gold_labels:
                if gold not in flattened_predictions:
                    self._false_negatives += 1
    def get_metric(self, reset: bool = False):
        """
        Get the metrics and reset the counters if necessary.
        """
        all_metrics = {}
        # Compute the precision, recall and F1 for all relations jointly.
        precision, recall, f1_measure = self._compute_metrics(self._true_positives,
                                                              self._false_positives,
                                                              self._false_negatives)
        all_metrics["precision-overall"] = precision
        all_metrics["recall-overall"] = recall
        all_metrics["f1-measure-overall"] = f1_measure
        if reset:
            self.reset()
        return all_metrics
    @staticmethod
    def _compute_metrics(true_positives: int, false_positives: int, false_negatives: int):
        # The 1e-13 terms avoid divisions by zero when a counter is empty.
        precision = float(true_positives) / float(true_positives + false_positives + 1e-13)
        recall = float(true_positives) / float(true_positives + false_negatives + 1e-13)
        f1_measure = 2. * ((precision * recall) / (precision + recall + 1e-13))
        return precision, recall, f1_measure

    def reset(self):
        self._true_positives = 0
        self._false_positives = 0
        self._false_negatives = 0
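
# ---------------------------------------------------------------------------
# Hypothetical smoke test (not part of the original code): one 3-token
# sentence, 6 relation classes, one gold relation and one spurious
# prediction -> precision 0.5, recall 1.0, F1 ~ 0.667.
if __name__ == "__main__":
    metric = RelationF1Measure()
    gold = torch.zeros(1, 3, 3, 6)
    gold[0, 0, 2, 1] = 1     # gold: relation of class 1 with ARG1=0, ARG2=2
    pred = gold.clone()
    pred[0, 1, 2, 4] = 1     # spurious: relation of class 4 with ARG1=1, ARG2=2
    mask = torch.ones(1, 3)  # no padding in this toy sentence
    metric(predictions=pred, gold_labels=gold, mask=mask)
    print(metric.get_metric())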