Skip to content

Instantly share code, notes, and snippets.

@timini
Created June 26, 2018 14:58
Show Gist options
  • Save timini/965a5917b741b4c4fba93dbc4ea55f58 to your computer and use it in GitHub Desktop.
Save timini/965a5917b741b4c4fba93dbc4ea55f58 to your computer and use it in GitHub Desktop.
import numpy as np
from pprint import pprint
class Graph:
    """Graph stored as a dense NumPy adjacency matrix.

    Nodes are kept in insertion order in ``_nodes``; a node's position in
    that list is its row/column index in ``_adj_matrix``. An entry of 1 at
    ``[i, j]`` means there is an edge from node ``i`` to node ``j``.
    Nodes must be hashable, because traversal results are returned as sets.
    """

    def __init__(self, directed=False):
        """Create an empty graph.

        Args:
            directed: When false (the default) every edge is stored in both
                directions; when true only ``A -> B`` is stored.
        """
        self._adj_matrix = np.zeros([0, 0])
        self._nodes = []
        self._directed = directed

    def debug(self):
        """Pretty-print the node list followed by the adjacency matrix."""
        pprint(self._nodes)
        pprint(self._adj_matrix)

    def add_node(self, N):
        """Add node ``N``; do nothing if it is already in the graph."""
        if self.has_node(N):
            return
        self._nodes.append(N)
        # (0, 1) pads nothing before and one zero after on *both* axes, i.e.
        # it appends one empty row and one empty column for the new node.
        self._adj_matrix = np.pad(self._adj_matrix, (0, 1), mode='constant')

    def has_node(self, N):
        """Return True if ``N`` has been added to the graph."""
        return N in self._nodes

    def _edge_indices(self, A, B):
        """Return the matrix indices of ``A`` and ``B``.

        Raises:
            ValueError: If either node is not in the graph.
        """
        return self._nodes.index(A), self._nodes.index(B)

    def has_edge(self, A, B):
        """Return True if there is a direct edge from ``A`` to ``B``.

        Raises:
            ValueError: If either node is not in the graph.
        """
        a_index, b_index = self._edge_indices(A, B)
        # bool() so callers get a plain Python bool rather than numpy.bool_.
        return bool(self._adj_matrix[a_index, b_index] == 1)

    def add_edge(self, A, B):
        """Add an edge from ``A`` to ``B`` (and back, if undirected).

        Raises:
            ValueError: If either node is not in the graph.
        """
        a_index, b_index = self._edge_indices(A, B)
        self._adj_matrix[a_index, b_index] = 1
        if not self._directed:
            self._adj_matrix[b_index, a_index] = 1

    def remove_edge(self, A, B):
        """Remove the edge from ``A`` to ``B`` (and back, if undirected).

        Raises:
            ValueError: If either node is not in the graph.
        """
        a_index, b_index = self._edge_indices(A, B)
        self._adj_matrix[a_index, b_index] = 0
        if not self._directed:
            # add_edge stored the mirrored entry, so it has to go as well;
            # otherwise B would still report an edge to A.
            self._adj_matrix[b_index, a_index] = 0

    def get_connected_nodes(self, N, visited=None):
        """Return the set of nodes reachable from ``N`` via one or more edges.

        ``N`` itself is part of the result only when some path leads back to
        it (always the case in an undirected graph once ``N`` has an edge).

        Args:
            N: Start node. A node that is not in the graph has no
                connections, so the result is an empty set.
            visited: Optional list of nodes to treat as already explored.
                Nodes explored by this call are appended to it.

        Returns:
            A set of reachable nodes.
        """
        if visited is None:
            visited = []
        reachable = set()
        pending = [N]
        while pending:
            node = pending.pop()
            if node in visited:
                continue
            # append, not extend: the node is one item, not an iterable.
            visited.append(node)
            try:
                row = self._adj_matrix[self._nodes.index(node)]
            except ValueError:
                # Only the start node can be unknown to the graph.
                continue
            for neighbour, flag in zip(self._nodes, row):
                if flag == 1:
                    reachable.add(neighbour)
                    pending.append(neighbour)
        return reachable

    def is_connected(self, A, B, visited=None):
        """Return True if ``B`` can be reached from ``A`` via one or more edges.

        Args:
            A: Start node.
            B: Target node.
            visited: Optional list of nodes to treat as already explored.
                Nodes explored by this call are appended to it.

        Returns:
            True when a path exists; False when none exists or when either
            node is not in the graph.
        """
        if visited is None:
            visited = []
        try:
            start, target = self._edge_indices(A, B)
        except ValueError:
            # A node that is not in the graph is not connected to anything.
            return False
        pending = [start]
        while pending:
            current = pending.pop()
            row = self._adj_matrix[current]
            # The direct-edge check comes before the visited check so that a
            # start node pre-listed in ``visited`` still reports its edge.
            if row[target] == 1:
                return True
            node = self._nodes[current]
            if node in visited:
                continue
            visited.append(node)
            pending.extend(i for i, flag in enumerate(row) if flag == 1)
        return False
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment