Skip to content

Instantly share code, notes, and snippets.

@pafonta
Last active November 23, 2021 10:33
Show Gist options
  • Save pafonta/21f3db4d9c31f6a1c2f7ede8cbf3406b to your computer and use it in GitHub Desktop.
Save pafonta/21f3db4d9c31f6a1c2f7ede8cbf3406b to your computer and use it in GitHub Desktop.
# For an example of use, see https://gist.github.com/pafonta/21f3db4d9c31f6a1c2f7ede8cbf3406b#gistcomment-3970844.
"""Entity Linking - Link mentions from texts to terms in ontologies.
Use character-based embedding to handle plurals, misspellings, partial matches, ...
"""
import pickle
import faiss
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.neighbors import NearestNeighbors
class Candidate:
    """A possible ontology match for a text mention.

    Holds the search distance between the mention and the matched alias,
    the alias text itself, and the uid, concept name and definition of the
    ontology entry that alias belongs to.
    """

    def __init__(self, distance, alias, uid, concept, definition):
        # Targets are assigned left to right, so the attribute order (and
        # hence the repr below) follows the constructor signature.
        self.distance, self.alias, self.uid = distance, alias, uid
        self.concept, self.definition = concept, definition

    def __repr__(self):
        """Return ``Candidate(name=value, ...)`` listing every attribute."""
        body = ", ".join(f"{name}={value!r}" for name, value in vars(self).items())
        return "Candidate(" + body + ")"
class EntityLinker:
    """Perform entity linking on text mentions against an ontology.

    Mentions and ontology aliases are embedded with a ``TfidfVectorizer``
    and matched by nearest-neighbour search. The search back-end is chosen
    by ``bulk``:

    - ``bulk=True``: a ``faiss.IndexFlatL2`` index, which reports *squared*
      L2 distances.
    - ``bulk=False``: a scikit-learn ``NearestNeighbors`` index built from
      ``index_params``; with scikit-learn's default metric (Minkowski, p=2)
      it reports plain, non-squared Euclidean distances.

    NOTE: because the two back-ends can report distances on different
    scales, a ``threshold`` tuned for one is not necessarily valid for the
    other.
    """

    def __init__(self, bulk):
        """Create an empty linker.

        ``bulk`` selects the faiss back-end when true, scikit-learn otherwise.
        The instance is unusable until ``train`` or ``from_pretrained``
        fills in the attributes below.
        """
        self.bulk = bulk
        self.ontology = None  # uid -> (concept, definition)
        self.aliases = None  # [(alias, uid)], one entry per indexed vector
        self.model = None  # fitted TfidfVectorizer
        self.index = None  # faiss index or fitted NearestNeighbors

    def link(self, mentions, threshold=0.8, limit=3):
        """Link text mentions to the configured ontology.

        Parameters
        ----------
        mentions : iterable of str
            The text mentions to link.
        threshold : float
            Maximum distance accepted for a non-exact match, on the scale of
            the configured search back-end.
        limit : int
            Number of nearest aliases retrieved per mention before
            disambiguation.

        Returns
        -------
        list
            One entry per mention, in input order: the chosen ``Candidate``,
            or ``None`` when no candidate is close enough.
        """
        selections = self.candidates(mentions, limit)
        return [self.disambiguate(cs, m, None, threshold) for m, cs in selections]

    def disambiguate(self, candidates, mention, context, threshold, tolerance=1e-6):
        """Identify which candidate the text mention actually refers to.

        Candidates whose distance is within ``tolerance`` of zero are treated
        as exact matches. Among those, the one with the shortest concept name
        wins (the first one on ties). Otherwise the closest candidate is
        returned if its distance is at most ``threshold``.

        Parameters
        ----------
        candidates : list of Candidate
        mention, context
            Currently unused (see TODO).
        threshold : float
            Maximum distance accepted for a non-exact match.
        tolerance : float
            Absolute distance below which a match counts as exact.

        Returns
        -------
        Candidate or None
            ``None`` when there are no candidates or none is close enough.
        """
        # TODO Use the 'mention' and the 'context' to improve disambiguation.
        if not candidates:
            return None
        # Distances are computed in floating point (float32 for faiss), so an
        # identical vector can come back as a tiny non-zero or negative value.
        exact = [c for c in candidates if abs(c.distance) <= tolerance]
        if exact:
            # min() returns the first of equally short concepts, like a
            # stable sort followed by [0].
            return min(exact, key=lambda c: len(c.concept))
        closest = min(candidates, key=lambda c: c.distance)
        return closest if closest.distance <= threshold else None

    def candidates(self, mentions, limit):
        """Select candidates in the configured ontology for text mentions.

        Returns a list of ``(mention, [Candidate, ...])`` pairs in input
        order. Each candidate list holds at most ``limit`` entries, in the
        order returned by the search index.
        """
        if not isinstance(mentions, str):
            # Materialise generators: the mentions are read twice below
            # (once to embed, once to pair with the results).
            mentions = list(mentions)
        if not mentions:
            return []
        # Never ask for more neighbours than there are indexed aliases.
        # scikit-learn raises in that case, and faiss pads with -1 labels.
        k = min(limit, len(self.aliases))
        if k <= 0:
            return [(m, []) for m in mentions]
        embeddings = self.model.transform(mentions)
        if self.bulk:
            # faiss works on dense float32 arrays.
            dense = embeddings.toarray().astype(np.float32)
            distances, indexes = self.index.search(dense, k)
        else:
            distances, indexes = self.index.kneighbors(embeddings, k)
        results = []
        for mention, row_distances, row_indexes in zip(mentions, distances, indexes):
            found = []
            for distance, position in zip(row_distances, row_indexes):
                if position < 0:  # faiss marks a missing result with label -1
                    continue
                alias, uid = self.aliases[int(position)]
                found.append(Candidate(float(distance), alias, uid, *self.ontology[uid]))
            results.append((mention, found))
        return results

    def train(self, ontology, model_params=None, index_params=None):
        """Train an entity linking model on ``ontology``.

        Parameters
        ----------
        ontology : dict
            Maps each concept uid to a sequence whose items 0, 1 and 2 are the
            concept name, an iterable of alternative aliases, and the
            definition.
        model_params : dict, optional
            Keyword arguments for ``TfidfVectorizer``.
        index_params : dict, optional
            Keyword arguments for ``NearestNeighbors``; unused when ``bulk``.
        """
        model_params = {} if model_params is None else model_params
        index_params = {} if index_params is None else index_params
        self.ontology = {k: (v[0], v[2]) for k, v in ontology.items()}
        self.model = TfidfVectorizer(**model_params)
        # Index the concept name itself plus every alias, remembering the uid.
        aliases = [(x, k) for k, v in ontology.items() for x in [v[0], *v[1]]]
        embeddings = self.model.fit_transform(x for x, _ in aliases)
        # Drop aliases whose TF-IDF vector is all zeros (e.g. every feature was
        # filtered out by the vectorizer settings); they cannot be matched.
        keep = np.asarray(embeddings.sum(axis=1) != 0).reshape(-1)
        filtered_embeddings = embeddings[keep]
        self.aliases = [alias for alias, flag in zip(aliases, keep) if flag]
        if self.bulk:
            self.index = faiss.IndexFlatL2(filtered_embeddings.shape[1])
            self.index.add(filtered_embeddings.toarray().astype(np.float32))
        else:
            self.index = NearestNeighbors(**index_params)
            self.index.fit(filtered_embeddings)
        self._stats()

    def save_pretrained(self, dirpath):
        """Save a trained 'EntityLinker' instance into the existing ``dirpath``.

        ``dirpath/model`` receives, in order, the pickled ontology, aliases,
        vectorizer and (when not ``bulk``) the NearestNeighbors index. When
        ``bulk``, the faiss index is written to ``dirpath/index``. The ``bulk``
        flag itself is not stored, so ``from_pretrained`` must be given the
        same value.
        """
        with open(f"{dirpath}/model", "wb") as f:
            pickle.dump(self.ontology, f)
            pickle.dump(self.aliases, f)
            pickle.dump(self.model, f)
            if not self.bulk:
                pickle.dump(self.index, f)
        if self.bulk:
            faiss.write_index(self.index, f"{dirpath}/index")

    @classmethod
    def from_pretrained(cls, dirpath, bulk):
        """Load a trained 'EntityLinker' instance written by ``save_pretrained``.

        ``bulk`` must match the value used when saving, because it decides
        whether the index is unpickled or read from ``dirpath/index``.

        SECURITY: ``pickle.load`` can execute arbitrary code, so only load
        directories from a trusted source.
        """
        linker = cls(bulk)
        with open(f"{dirpath}/model", "rb") as f:
            linker.ontology = pickle.load(f)
            linker.aliases = pickle.load(f)
            linker.model = pickle.load(f)
            if not bulk:
                linker.index = pickle.load(f)
        if bulk:
            linker.index = faiss.read_index(f"{dirpath}/index")
        linker._stats()
        return linker

    def _stats(self):
        """Print the number of linked concepts and indexed aliases."""
        ccount = len(self.ontology)
        tcount = len(self.aliases)
        print(f"INFO EntityLinker Links to {ccount} concepts ({tcount} aliases).")
@pafonta
Copy link
Author

pafonta commented Nov 22, 2021

Use

Get some entity mentions from texts:

mentions = ["autophagy", "cholera toxin"]

Load a pretrained model:

linker = EntityLinker.from_pretrained("entity-linker/", bulk=False)

Use the model to link mentions to ontology concepts:

concepts = linker.link(mentions)
print(concepts)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment