Skip to content

Instantly share code, notes, and snippets.

@drussellmrichie
Last active February 6, 2021 12:04
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save drussellmrichie/47deb429350e2e99ffb3272ab6ab216a to your computer and use it in GitHub Desktop.
Save drussellmrichie/47deb429350e2e99ffb3272ab6ab216a to your computer and use it in GitHub Desktop.
Computes the average dependency parse tree height for a sequence of sentences
import spacy
import numpy as np
# Load the English pipeline once at import time. NER is disabled because the
# code in this file only reads the dependency parse (token children) and the
# sentence segmentation.
# The 'en' shortcut link was removed in spaCy v3, where loading it raises an
# OSError; keep trying it first so older installs behave exactly as before,
# then fall back to the full package name.
try:
    nlp = spacy.load('en', disable=['ner'])
except OSError:
    nlp = spacy.load('en_core_web_sm', disable=['ner'])
def tree_height(root):
    """
    Find the maximum depth (height) of the dependency parse of a spacy sentence by starting with its root.

    The tree is walked iteratively, one level at a time, so arbitrarily deep
    trees cannot hit Python's recursion limit, and each node's ``children``
    iterable is consumed exactly once.

    :param root: spacy.tokens.token.Token (any object exposing an iterable ``children`` attribute works)
    :return: int, maximum height of sentence's dependency parse tree (a root with no children has height 1)
    """
    height = 0
    level = [root]
    while level:
        # Every non-empty level adds one to the height.
        height += 1
        # Gather all nodes of the next level in a single pass over the current one.
        level = [child for node in level for child in node.children]
    return height
def get_average_heights(paragraph):
    """
    Compute the mean dependency-parse tree height over the sentences of a paragraph.

    :param paragraph: spacy doc object or str; a str is first parsed with the module-level ``nlp`` pipeline
    :return: float, mean of the per-sentence tree heights (nan when there are no sentences)
    """
    # isinstance (rather than an exact type comparison) also accepts str subclasses.
    doc = nlp(paragraph) if isinstance(paragraph, str) else paragraph
    heights = [tree_height(sent.root) for sent in doc.sents]
    if not heights:
        # np.mean([]) would also give nan, but only after emitting a RuntimeWarning.
        return np.nan
    return np.mean(heights)
def test_average_height_func():
    """Print three sample paragraphs, each followed by its average parse tree height."""
    samples = (
        (u"Autonomous cars shift insurance liability toward manufacturers. Consumers rejoice."
         " Manufacturers complain a lot. Suffer."),
        u"The cat on the hot tin roof of my parent's house meowed.",
        u"The cat on the hot tin roof meowed at my parent's house.",
    )
    for text in samples:
        # Echo the input first so each printed height can be matched to its paragraph.
        print(text)
        print(get_average_heights(text))
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    test_average_height_func()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment