Last active
November 23, 2021 10:33
-
-
Save pafonta/21f3db4d9c31f6a1c2f7ede8cbf3406b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# For an example of use, see https://gist.github.com/pafonta/21f3db4d9c31f6a1c2f7ede8cbf3406b#gistcomment-3970844.
"""Entity Linking - Link mentions from texts to terms in ontologies.
Use character-based embedding to handle plurals, misspellings, partial matches, ...
"""
import pickle
import faiss
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.neighbors import NearestNeighbors
class Candidate:
    """A candidate entry from the configured ontology for a text mention."""

    def __init__(self, distance, alias, uid, concept, definition):
        # Distance between the mention embedding and this alias embedding.
        self.distance = distance
        # Alias string that matched, and the ontology identifier it belongs to.
        self.alias = alias
        self.uid = uid
        # Canonical concept name and its textual definition.
        self.concept = concept
        self.definition = definition

    def __repr__(self):
        """Return a readable representation listing every attribute."""
        inner = ", ".join(f"{name}={value!r}" for name, value in vars(self).items())
        return f"Candidate({inner})"
class EntityLinker:
    """Perform entity linking on text mentions against an ontology.

    Mentions and ontology aliases are embedded with a character-level
    TF-IDF model; nearest-neighbour search retrieves candidate concepts,
    using a faiss dense index when ``bulk`` is true and scikit-learn's
    NearestNeighbors otherwise.
    """

    def __init__(self, bulk):
        # When True, use faiss (dense, bulk-friendly); otherwise sklearn.
        self.bulk = bulk
        self.ontology = None  # uid -> (concept, definition), set by train/load
        self.aliases = None  # list of (alias, uid) kept after filtering
        self.model = None  # fitted TfidfVectorizer
        self.index = None  # faiss index or fitted NearestNeighbors

    def link(self, mentions, threshold=0.8):
        """Link text mentions to the configured ontology.

        Returns one Candidate per mention (or None when no candidate is
        close enough), in the same order as `mentions`.
        """
        selections = self.candidates(mentions, 3)
        return [self.disambiguate(cs, m, None, threshold) for m, cs in selections]

    def disambiguate(self, candidates, mention, context, threshold):
        """Identify which candidate the text mention actually refers to.

        Exact matches (distance == 0) win, ties broken by the shortest
        concept name; otherwise the nearest candidate is returned only if
        its distance is within `threshold`, else None.
        """
        # TODO Use the 'mention' and the 'context' to improve disambiguation.
        if not candidates:
            # Guard: the original crashed with IndexError on an empty list.
            return None
        exact = [c for c in candidates if c.distance == 0]
        if exact:
            # min() replaces sorted(...)[0]: same result, O(n) instead of O(n log n).
            return min(exact, key=lambda c: len(c.concept))
        best = min(candidates, key=lambda c: c.distance)
        # NOTE(review): faiss IndexFlatL2 reports *squared* L2 distances while
        # sklearn reports plain distances, so `threshold` is not directly
        # comparable across the two back-ends — confirm intended semantics.
        return best if best.distance <= threshold else None

    def candidates(self, mentions, limit):
        """Select up to `limit` candidates in the ontology for each mention."""

        def _make(distance, idx):
            # np.stack below casts indexes to float, hence the int() here.
            alias, uid = self.aliases[int(idx)]
            return Candidate(distance, alias, uid, *self.ontology[uid])

        embeddings = self.model.transform(mentions)
        if self.bulk:
            # faiss requires a dense array; sparse output must be expanded.
            distances, indexes = self.index.search(embeddings.toarray(), limit)
        else:
            distances, indexes = self.index.kneighbors(embeddings, limit)
        # Pair each distance with its alias index: shape (mentions, limit, 2).
        results = np.stack((distances, indexes), axis=2)
        return [(m, [_make(d, i) for d, i in rs]) for m, rs in zip(mentions, results)]

    def train(self, ontology, model_params, index_params):
        """Train an entity linking model. Build an EntityLinker instance.

        `ontology` maps uid -> (concept, aliases, definition);
        `model_params` / `index_params` are forwarded to TfidfVectorizer
        and NearestNeighbors respectively (index_params is unused in bulk mode).
        """
        self.ontology = {uid: (entry[0], entry[2]) for uid, entry in ontology.items()}
        self.model = TfidfVectorizer(**model_params)
        # Each concept name and each of its aliases becomes a searchable entry.
        aliases = [(text, uid) for uid, entry in ontology.items() for text in [entry[0], *entry[1]]]
        embeddings = self.model.fit_transform(text for text, _ in aliases)
        # Drop aliases whose embedding is all-zero (out-of-vocabulary strings),
        # as they would match everything equally badly.
        flags = np.array(embeddings.sum(axis=1) != 0).reshape(-1)
        filtered_embeddings = embeddings[flags]
        self.aliases = [pair for pair, keep in zip(aliases, flags) if keep]
        if self.bulk:
            self.index = faiss.IndexFlatL2(filtered_embeddings.shape[1])
            self.index.add(filtered_embeddings.toarray())
        else:
            self.index = NearestNeighbors(**index_params)
            self.index.fit(filtered_embeddings)
        self._stats()

    def save_pretrained(self, dirpath):
        """Save a trained 'EntityLinker' instance under `dirpath`."""
        with open(f"{dirpath}/model", "wb") as f:
            pickle.dump(self.ontology, f)
            pickle.dump(self.aliases, f)
            pickle.dump(self.model, f)
            if self.bulk:
                # faiss indexes are not picklable; store them in a side file.
                faiss.write_index(self.index, f"{dirpath}/index")
            else:
                pickle.dump(self.index, f)

    @classmethod
    def from_pretrained(cls, dirpath, bulk):
        """Load a trained 'EntityLinker' instance from `dirpath`.

        SECURITY: this unpickles files from disk — only load model
        directories from trusted sources.
        """
        # classmethod (not staticmethod hard-coding EntityLinker) so
        # subclasses load as their own type; call sites are unchanged.
        linker = cls(bulk)
        with open(f"{dirpath}/model", "rb") as f:
            linker.ontology = pickle.load(f)
            linker.aliases = pickle.load(f)
            linker.model = pickle.load(f)
            if bulk:
                linker.index = faiss.read_index(f"{dirpath}/index")
            else:
                linker.index = pickle.load(f)
        linker._stats()
        return linker

    def _stats(self):
        """Print statistics on the 'EntityLinker' instance."""
        ccount = len(self.ontology)
        tcount = len(self.aliases)
        print(f"INFO EntityLinker Links to {ccount} concepts ({tcount} aliases).")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Use
Get some entity mentions from texts:
Load a pretrained model:
Use the model to link mentions to ontology concepts: