Skip to content

Instantly share code, notes, and snippets.

@oliver-batey
Last active November 28, 2021 19:55
Show Gist options
  • Save oliver-batey/ba9073ac9ebb2fcb391d4213b89efe40 to your computer and use it in GitHub Desktop.
Save oliver-batey/ba9073ac9ebb2fcb391d4213b89efe40 to your computer and use it in GitHub Desktop.
Calculate the distance between nodes of a dependency network
from itertools import count
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import textacy
# Load the article text. The encoding is given explicitly so decoding does not
# depend on the platform's default locale encoding.
with open("news_article.txt", "r", encoding="utf-8") as file:
    # Line breaks are removed outright (no separator is inserted in their place)
    data = file.read().replace("\n", "")
# Swap non-breaking spaces for ordinary spaces
article = data.replace("\xa0", " ")
# Parse the example headline into a Doc with the `en_core_web_sm` pipeline
sent = "Stocks Hit Record as Biden Calls for More Stimulus"
doc = textacy.make_spacy_doc(sent, lang="en_core_web_sm")

# Print every token's text next to its `pos_` tag for a quick visual check
for tok in doc:
    print(tok.text, tok.pos_)
# One (node_id, attributes) pair per token; the token's index in the doc is the node id
nodes = [
    (
        tok.i,
        {
            "text": tok.text,
            "idx": tok.i,
            "pos": tok.pos_,
            "tag": tok.tag_,
            "dep": tok.dep_,
        },
    )
    for tok in doc
]
# Link each token to every one of its children (parent index, child index)
edges = [(parent.i, kid.i) for parent in doc for kid in parent.children]

# Assemble the undirected graph: nodes (with their attribute dicts) first, then edges
G = nx.Graph()
G.add_nodes_from(nodes)
G.add_edges_from(edges)
# Nodes whose dependency label is "nsubj" (nominal subjects)
subjects = [
    (node_id, attrs)
    for node_id, attrs in G.nodes(data=True)
    if attrs["dep"] == "nsubj"
]

# Nodes tagged NOUN or PROPN; the [1:] slice discards the first match
# NOTE(review): presumably the slice is meant to skip the subject itself — confirm
noun_nodes = [
    (node_id, attrs)
    for node_id, attrs in G.nodes(data=True)
    if attrs["pos"] in ("NOUN", "PROPN")
][1:]
# For every (nominal subject, noun) pair, record the shortest path through the
# dependency graph and its length
dependencies = []
for subj, node1 in subjects:
    for noun, node2 in noun_nodes:
        # List of node ids from subject to noun, both endpoints included
        d = nx.shortest_path(G, source=subj, target=noun)
        dependencies.append(
            {
                "subject": node1["text"],
                "noun": node2["text"],
                # Distance is the number of edges traversed, which is one less
                # than the number of nodes on the path
                "distance": len(d) - 1,
                "path": d,
            }
        )

# Collect the records in a dataframe and write them out as CSV
data = pd.DataFrame(dependencies)
data.to_csv("dependency_data_full_doc.csv")
# Betweenness centrality per node, scaled and offset into marker sizes
# NOTE(review): node_sizes is not passed to the draw call in this file, which uses a fixed node_size
bc = nx.betweenness_centrality(G)
centrality = np.array(list(bc.values()))
node_sizes = centrality * 1e5 + 5

# Give each distinct POS tag an integer (in sorted tag order) so a colormap can be applied
pos_by_node = nx.get_node_attributes(G, "pos")
groups = set(pos_by_node.values())
mapping = {group: number for number, group in enumerate(sorted(groups))}

# Colour value for every node, in the graph's node order
nodes = G.nodes()
colors = [mapping[pos_by_node[n]] for n in nodes]
print(mapping)
# Draw the graph and display it (nothing is written to disk here)
fig = plt.figure(figsize=(12, 6))
labels_txt = nx.get_node_attributes(G, "text")
labels_pos = nx.get_node_attributes(G, "pos")

# Token text for the NOUN nodes only
# NOTE(review): this dict is rebound by the draw_networkx_labels call below
# before anything reads it — confirm whether noun-only labels were intended
keys = [key for key, tag in labels_pos.items() if tag == "NOUN"]
labels = {key: labels_txt[key] for key in keys}

# Force-directed layout; no seed is given
pos = nx.spring_layout(G)
ec = nx.draw_networkx_edges(G, pos, edge_color="lightgrey")
nc = nx.draw_networkx_nodes(
    G,
    pos,
    nodelist=nodes,
    node_color=colors,
    node_size=250,
    cmap=plt.cm.Pastel2,
    alpha=0.8,
)
# Every node is labelled with its token text
labels = nx.draw_networkx_labels(
    G, pos, labels=labels_txt, font_size=12, font_color="dimgrey"
)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment