Skip to content

Instantly share code, notes, and snippets.

@MrNocTV
Created July 7, 2017 13:52
Show Gist options
  • Save MrNocTV/1cc91e105310bbb18129b7f2e62fe333 to your computer and use it in GitHub Desktop.
Save MrNocTV/1cc91e105310bbb18129b7f2e62fe333 to your computer and use it in GitHub Desktop.
Prim algorithm
'''
Prim's algorithm: weight of a minimum spanning tree of an undirected,
weighted graph.

Tutorial:
https://www.youtube.com/watch?v=oP2-8ysT3QQ
'''
import bisect
class Graph:
    """Undirected weighted graph whose minimum spanning tree is found with Prim's algorithm.

    Args:
        start: vertex the tree is grown from.
        vertexes: mapping ``vertex -> iterable of neighbouring vertices``.
        edges: mapping ``(u, v) -> weight``; ``edges[(u, v)]`` must exist for
            every ``v`` listed in ``vertexes[u]``.
    """

    def __init__(self, start, vertexes, edges):
        self._vertexes = vertexes
        self._start = start
        self._edges = edges
        # Best connecting edge per vertex, as (parent, vertex, weight);
        # filled in by PrimMST(), None where no connecting edge was used.
        self._min_edges = dict.fromkeys(vertexes, None)
        # Tentative key (cheapest known connecting weight) of each vertex
        # not yet in the tree.
        self._prim_table = dict.fromkeys(vertexes, float('Inf'))

    def PrimMST(self):
        """Compute, print and return the total weight of a minimum spanning tree.

        The tree is grown from ``self._start``. Each round relaxes the edges
        of the vertex most recently added, then adds the remaining vertex with
        the smallest key. Ties go to the first such vertex in
        ``self._vertexes`` iteration order. The cheapest-vertex search is a
        linear scan, so one run costs O(V^2 + E).

        If some remaining vertices are unreachable, all of their keys are
        still infinite. One of them is then added without a connecting edge
        and growth continues from it. The result is therefore the weight of a
        minimum spanning forest.

        A start vertex that is not a key of ``vertexes`` is treated as an
        isolated vertex.

        After the call, ``self._min_edges`` maps each vertex to its chosen
        ``(parent, vertex, weight)`` edge. The entry is ``None`` for the start
        vertex and for vertices added without an edge. Every call rebuilds its
        working state, so the method can be called repeatedly.

        Returns:
            The total weight of the chosen edges (also printed to stdout).
        """
        inf = float('Inf')
        # Fresh working tables on every call; consuming the instance tables
        # in place would make a second call fail.
        prim_table = dict.fromkeys(self._vertexes, inf)
        min_edges = dict.fromkeys(self._vertexes, None)

        current = self._start
        prim_table.pop(current, None)
        while prim_table:
            # Relax the edges leaving the vertex just added to the tree.
            for child in self._vertexes.get(current, ()):
                if child in prim_table:
                    weight = self._edges[(current, child)]
                    if weight < prim_table[child]:
                        prim_table[child] = weight
                        min_edges[child] = (current, child, weight)
            # Cheapest vertex still outside the tree; min() keeps the first on ties.
            current = min(prim_table, key=prim_table.get)
            del prim_table[current]

        self._prim_table = prim_table
        self._min_edges = min_edges
        min_length = sum(edge[2] for edge in min_edges.values() if edge is not None)
        print(min_length)
        return min_length
def read_graph(text):
    """Parse the problem input into ``(start, vertexes, edges)`` for ``Graph``.

    The input is whitespace-separated integers: ``n m``, then ``m`` triples
    ``u v weight``, then the start vertex. Line breaks do not matter.

    Vertices ``1..n`` are registered up front, so a vertex with no edges still
    appears in ``vertexes``. This assumes the input labels vertices 1..n.
    For parallel edges only the lightest weight is kept, and the adjacency
    lists list each neighbour once per distinct pair. Both ``(u, v)`` and
    ``(v, u)`` are stored in ``edges``.

    Returns:
        ``(start, vertexes, edges)``, where ``vertexes`` is a
        ``defaultdict(list)`` of adjacency lists and ``edges`` maps
        ``(u, v) -> weight``.
    """
    from collections import defaultdict

    tokens = [int(tok) for tok in text.split()]
    n, m = tokens[0], tokens[1]
    vertexes = defaultdict(list, {vertex: [] for vertex in range(1, n + 1)})
    edges = {}
    pos = 2
    for _ in range(m):
        u, v, length = tokens[pos:pos + 3]
        pos += 3
        if (u, v) in edges:
            # Parallel edge: keep only the lightest one.
            if length < edges[(u, v)]:
                edges[(u, v)] = edges[(v, u)] = length
        else:
            vertexes[u].append(v)
            vertexes[v].append(u)
            edges[(u, v)] = edges[(v, u)] = length
    start = tokens[pos]
    return start, vertexes, edges


if __name__ == '__main__':
    from sys import stdin

    start, vertexes, edges = read_graph(stdin.read())
    Graph(start, vertexes, edges).PrimMST()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment