Skip to content

Instantly share code, notes, and snippets.

@ssabii
Last active September 25, 2020 10:28
Show Gist options
  • Save ssabii/8d83aeb018d981e44cf43a134c42a856 to your computer and use it in GitHub Desktop.
Save ssabii/8d83aeb018d981e44cf43a134c42a856 to your computer and use it in GitHub Desktop.
프림 알고리즘
import sys
import heapq
def prim(start, graphs, weights):
    """Build a minimum spanning tree with Prim's algorithm (lazy-heap variant).

    Args:
        start: Index of the vertex the tree is grown from.
        graphs: Square adjacency matrix; ``graphs[u][v] == 1`` marks an edge
            between ``u`` and ``v``. Any other value is treated as "no edge".
        weights: Matrix of the same shape; ``weights[u][v]`` is the weight of
            edge ``u``-``v`` and is only read where ``graphs[u][v] == 1``.

    Returns:
        A symmetric 0/1 adjacency matrix (list of lists) that contains only
        the edges chosen for the tree. Vertices that are never reached from
        ``start`` get no edges. An empty ``graphs`` yields ``[]``.

    Side effects:
        Prints the total weight of the chosen edges, then each row of the
        resulting matrix.
    """
    size = len(graphs)
    parents = [-1] * size  # tree parent of each vertex; -1 means "none"
    # Cheapest known edge weight into each vertex. Infinity (rather than a
    # large integer) guarantees that any real weight compares as smaller.
    best = [float("inf")] * size
    visited = set()
    # (weight, vertex) pairs. An empty graph has nothing to explore, so the
    # heap starts empty and the function falls through to an empty result.
    heap = [(0, start)] if size else []

    while heap:
        vertex = heapq.heappop(heap)[1]  # vertex reachable by the cheapest edge
        if vertex in visited:
            # Stale entry: a cheaper edge to this vertex was popped earlier.
            continue
        visited.add(vertex)
        for target, value in enumerate(graphs[vertex]):  # scan adjacent edges
            if value != 1 or target in visited:
                continue
            weight = weights[vertex][target]
            if weight < best[target]:
                best[target] = weight
                parents[target] = vertex
                heapq.heappush(heap, (weight, target))

    # Materialise the tree as an adjacency matrix and total its weight.
    mst_graphs = [[0] * size for _ in range(size)]
    weight_sum = 0
    for vertex, parent in enumerate(parents):
        if parent >= 0:
            mst_graphs[vertex][parent] = 1
            mst_graphs[parent][vertex] = 1
            weight_sum += best[vertex]

    # Report the result (the label reads "minimum weight sum").
    print("최소 가중치 합: ", weight_sum)
    for row in mst_graphs:
        print(row)
    return mst_graphs
if __name__ == "__main__":
    # Demo: a 7-vertex undirected graph given as (u, v, weight) edges, expanded
    # into the adjacency and weight matrices that prim() expects.
    vertex_count = 7
    edges = [
        (0, 1, 8),
        (0, 5, 15),
        (1, 2, 17),
        (1, 6, 10),
        (2, 3, 27),
        (3, 4, 29),
        (3, 6, 25),
        (4, 5, 21),
        (4, 6, 23),
    ]

    graphs = [[0] * vertex_count for _ in range(vertex_count)]
    weights = [[0] * vertex_count for _ in range(vertex_count)]
    for u, v, w in edges:
        # Undirected: mirror every edge across the diagonal.
        graphs[u][v] = graphs[v][u] = 1
        weights[u][v] = weights[v][u] = w

    mst_graphs = prim(0, graphs, weights)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment