Last active
January 4, 2023 02:33
-
-
Save Christopher-Thornton/87a1d40b65e075a5b53b52deb5d50722 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import re

import pandas as pd
import spacy
import neuralcoref

# Shared spaCy pipeline used by every function in this module. It is built
# once at import time so callers do not re-load the model on each call.
nlp = spacy.load('en_core_web_lg')

# Attach neuralcoref to the pipeline; get_entity_pairs() reads the
# `doc._.coref_resolved` extension when called with coref=True.
neuralcoref.add_to_pipe(nlp)
def get_entity_pairs(text, coref=True):
    """Extract (subject, relation, object) triples from raw text.

    The text is cleaned, optionally rewritten with coreference clusters
    resolved, and split into sentences. Only sentences that contain exactly
    one subject (``subj``/``nsubj``) and exactly one object (``obj``/``dobj``)
    are used; the relation is the sentence's ROOT token, extended by a
    directly following adposition or particle.

    Args:
        text: Raw input text.
        coref: When true, re-parse the text with ``doc._.coref_resolved``
            before extraction (requires the coreference component to be
            present in the module-level ``nlp`` pipeline).

    Returns:
        A ``pandas.DataFrame`` with the columns ``subject``, ``relation``,
        ``object``, ``subject_type`` and ``object_type``. Rows in which any
        field is an empty string are dropped.

    Side effects:
        Prints the number of extracted pairs to stdout.
    """
    # Part-of-speech tags stripped from noun chunks during refinement.
    unwanted_tokens = (
        'PRON',  # pronouns
        'PART',  # particle
        'DET',  # determiner
        'SCONJ',  # subordinating conjunction
        'PUNCT',  # punctuation
        'SYM',  # symbol
        'X',  # other
    )

    def refine_ent(ent, sent):
        """Return ``(cleaned_entity, entity_type)`` for a merged token.

        ``ent`` is a token of ``sent`` (after entity/noun-chunk merging).
        Tokens without an entity type are labelled ``'NOUN_CHUNK'`` and have
        stop words and ``unwanted_tokens`` removed. Single-word numeric
        entities are extended with the tokens that follow them, up to the
        next verb or punctuation mark.
        """
        ent_type = ent.ent_type_  # get entity type
        if ent_type == '':
            ent_type = 'NOUN_CHUNK'
            ent = ' '.join(
                str(t.text) for t in nlp(str(ent))
                if t.pos_ not in unwanted_tokens and not t.is_stop
            )
        elif (ent_type in ('NOMINAL', 'CARDINAL', 'ORDINAL')
                and ' ' not in str(ent)):
            words = []
            for offset in range(len(sent) - ent.i):
                neighbour = ent.nbor(offset)
                if neighbour.pos_ in ('VERB', 'PUNCT'):
                    break
                words.append(str(neighbour))
            # Assigned unconditionally so text collected up to the end of
            # the sentence is kept, not only text that is terminated by a
            # verb or punctuation token.
            ent = ' '.join(words).strip()

        return ent, ent_type

    # preprocess text
    text = re.sub(r'\n+', '.', text)  # replace multiple newlines with period
    text = re.sub(r'\[\d+\]', ' ', text)  # remove reference numbers
    doc = nlp(text)
    if coref:
        doc = nlp(doc._.coref_resolved)  # resolve coreference clusters

    # `.text` instead of the `.string` attribute, which newer spaCy
    # releases no longer provide; after strip() the result is the same.
    sentences = [span.text.strip() for span in doc.sents]

    ent_pairs = []
    for sentence in sentences:
        # Each sentence is re-parsed on its own so token indices and the
        # ROOT dependency are relative to that sentence only.
        sent = nlp(sentence)

        # collect nodes: named entities and noun chunks, without overlaps
        spans = spacy.util.filter_spans(
            list(sent.ents) + list(sent.noun_chunks))
        with sent.retokenize() as retokenizer:
            for span in spans:
                retokenizer.merge(span, attrs={'tag': span.root.tag,
                                               'dep': span.root.dep})

        deps = [tok.dep_ for tok in sent]

        # limit our example to simple sentences with one subject and object
        if (deps.count('obj') + deps.count('dobj') != 1
                or deps.count('subj') + deps.count('nsubj') != 1):
            continue

        for tok in sent:
            if tok.dep_ not in ('obj', 'dobj'):  # identify object nodes
                continue

            # identify subject nodes attached to the object's head
            subjects = [w for w in tok.head.lefts
                        if w.dep_ in ('subj', 'nsubj')]
            if not subjects:
                continue

            # identify relationship by root dependency
            roots = [w for w in tok.ancestors if w.dep_ == 'ROOT']
            if roots:
                root = roots[0]
                relation = str(root)
                # add adposition or particle to relationship; the bounds
                # check avoids nbor() raising IndexError when the root is
                # the last token of the sentence
                if (root.i + 1 < len(sent)
                        and root.nbor(1).pos_ in ('ADP', 'PART')):
                    relation = ' '.join((relation, str(root.nbor(1))))
            else:
                relation = 'unknown'

            subject, subject_type = refine_ent(subjects[0], sent)
            obj, object_type = refine_ent(tok, sent)

            ent_pairs.append([str(subject), str(relation), str(obj),
                              str(subject_type), str(object_type)])

    # drop rows where refinement left any field empty
    ent_pairs = [row for row in ent_pairs
                 if not any(str(field) == '' for field in row)]
    pairs = pd.DataFrame(ent_pairs, columns=['subject', 'relation', 'object',
                                             'subject_type', 'object_type'])
    print('Entity pairs extracted:', str(len(ent_pairs)))

    return pairs
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment