Skip to content

Instantly share code, notes, and snippets.

@mwestwood
Created May 4, 2023 15:00
Show Gist options
  • Save mwestwood/a2bf4cdf7f1a99716f7b640034d9f2e9 to your computer and use it in GitHub Desktop.
Save mwestwood/a2bf4cdf7f1a99716f7b640034d9f2e9 to your computer and use it in GitHub Desktop.
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
# Build a small example frame: each row carries a label ('Col1') and a list
# of labels it refers to ('Col2').
df = pd.DataFrame({
    'ID': [1, 2, 3, 4],
    'Col1': ['apple', 'orange', 'banana', 'pear'],
    'Col2': [['apple', 'pear'], ['banana'], ['apple', 'orange'], ['banana', 'pear']],
})

# Index the 'Col1' labels once (label -> list of row index labels holding it)
# so each 'Col2' entry is resolved with a dict lookup instead of rescanning
# the whole column for every value. Assumes 'Col1' values are hashable.
rows_by_label = {}
for row_idx, label in df['Col1'].items():
    rows_by_label.setdefault(label, []).append(row_idx)

# connections maps a row's index label to the set of index labels of the rows
# whose 'Col1' value appears in that row's 'Col2' list. Rows with no match
# get no entry at all.
connections = {}
for row_idx, referenced in df['Col2'].items():
    for label in referenced:
        targets = rows_by_label.get(label)
        if targets:
            connections.setdefault(row_idx, set()).update(targets)

# Create a directed graph with one node per dataframe row. The nodes come
# from df.index (not range(len(df))) so they always agree with the edge
# endpoints, which are index labels, even when the index is not 0..n-1.
G = nx.DiGraph()
G.add_nodes_from(df.index)

# One edge per (referring row -> referenced row) pair; a row that lists its
# own 'Col1' value produces a self-loop.
for source, targets in connections.items():
    G.add_edges_from((source, target) for target in targets)

# Plot the graph
pos = nx.spring_layout(G)
nx.draw_networkx(G, pos, with_labels=True)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment