Skip to content

Instantly share code, notes, and snippets.

@viniru
Last active July 26, 2018 15:51
Show Gist options
  • Save viniru/7e0fc9b5fefd4925cf8438c9d88c4d96 to your computer and use it in GitHub Desktop.
Save viniru/7e0fc9b5fefd4925cf8438c9d88c4d96 to your computer and use it in GitHub Desktop.
Minimum spanning tree (Prim's algorithm) using a min-heap
import java.util.HashMap;
import java.util.ArrayList;
import java.util.Scanner;
/**
 * Computes the total weight of a minimum spanning tree with Prim's algorithm, using the
 * indexed binary min-heap {@code Heap} as the priority queue.
 *
 * <p>Input (stdin): the vertex count and edge count, followed by one {@code a b weight} triple
 * per undirected edge, with vertices numbered 1..nodes. Output (stdout): the sum of the MST
 * edge weights on a single line.
 *
 * <p>NOTE(review): the graph is assumed to be connected. If it is not, the first vertex
 * extracted from each additional component keeps key {@code Integer.MAX_VALUE}, and that value
 * is added to the printed sum.
 */
public class MST
{
    HashMap<Node,ArrayList<Pair>> al; // adjacency list: vertex -> (neighbour, weight) entries
    ArrayList<Node> map;              // map.get(i) is the Node for vertex i; index 0 is unused
    Heap heap;
    int nodes;
    int edges;

    public static void main(String[] args)
    {
        MST mst = new MST();
        try (Scanner sc = new Scanner(System.in))
        {
            mst.readGraph(sc);
        }
        mst.heap = new Heap(mst);
        mst.heap.buildHeap();
        mst.start();
        mst.compute();
    }

    /**
     * Reads the vertex and edge counts and the edge list, creating Nodes 0..nodes and one
     * (initially empty) adjacency list per vertex 1..nodes.
     *
     * @throws IllegalArgumentException if an edge endpoint lies outside 1..nodes
     */
    private void readGraph(Scanner sc)
    {
        nodes = sc.nextInt();
        edges = sc.nextInt();
        map = new ArrayList<>(nodes + 1);
        for (int i = 0; i <= nodes; i++)
            map.add(new Node(i));
        al = new HashMap<>();
        for (int i = 1; i <= nodes; i++)
            al.put(map.get(i), new ArrayList<Pair>());
        for (int e = 0; e < edges; e++)
        {
            int a = sc.nextInt();
            int b = sc.nextInt();
            int weight = sc.nextInt();
            // Vertex 0 has no adjacency list and indices > nodes do not exist; fail clearly
            // instead of with a NullPointerException / IndexOutOfBoundsException later.
            if (a < 1 || a > nodes || b < 1 || b > nodes)
                throw new IllegalArgumentException(
                        "edge endpoint out of range 1.." + nodes + ": " + a + " " + b);
            // Undirected edge: record it in both endpoints' lists.
            al.get(map.get(a)).add(new Pair(map.get(b), weight));
            al.get(map.get(b)).add(new Pair(map.get(a), weight));
        }
    }

    /** Prints the sum of the keys of vertices 1..nodes, i.e. the MST weight after start(). */
    void compute()
    {
        long sum = 0; // long: the sum of many int weights can exceed the int range
        for (int i = 1; i <= nodes; i++)
            sum += map.get(i).key;
        System.out.println(sum);
    }

    /**
     * Runs Prim's algorithm. It repeatedly extracts the vertex with the cheapest known
     * connecting edge and lowers the keys of its unvisited neighbours. On return, each
     * vertex's key holds the weight of the edge that joined it to the tree; the start
     * vertex's key is 0. An empty graph is a no-op.
     */
    void start()
    {
        if (heap.size() == 0)
            return; // no vertices: nothing to span (heap.get(0) below would throw)
        // Begin at whichever vertex sits at the heap root. Lowering the root's key can
        // never violate heap order, so no repositioning is needed.
        Node root = heap.heap.get(0);
        root.visited = true;
        root.key = 0;
        while (heap.size() > 0)
        {
            Node n = heap.extractMin();
            n.visited = true;
            for (Pair p : al.get(n))
            {
                if (p.vertex.visited)
                    continue;
                if (p.vertex.key > p.weight)
                {
                    p.vertex.key = p.weight;
                    heap.resetPos(p.vertex); // decrease-key: restore heap order
                }
            }
        }
    }
}
/**
 * Binary min-heap of {@code Node}s ordered by {@code Node.key}, stored in an ArrayList.
 *
 * <p>Each node's {@code loc} field is kept equal to its index in {@link #heap}, and is set to
 * -1 once the node has been removed. This lets {@link #resetPos(Node)} reposition a node after
 * its key changes without searching for it.
 */
class Heap
{
    MST mst;              // owner; buildHeap reads its vertex list
    ArrayList<Node> heap; // heap-ordered: each node's key <= the keys of its children

    Heap(MST m)
    {
        mst = m;
        heap = new ArrayList<>();
    }

    /** Returns the number of nodes currently in the heap. */
    int size()
    {
        return heap.size();
    }

    /** Prints the vertex ids in array order, each followed by a space, with no newline. */
    void print()
    {
        for (Node i : heap)
            System.out.print(i.vertex + " ");
    }

    /**
     * Appends vertices 1..mst.nodes from mst.map and then establishes heap order bottom-up in
     * O(n). When all keys are equal, as on a freshly read graph, no node moves.
     */
    void buildHeap()
    {
        for (int i = 1; i <= mst.nodes; i++)
        {
            Node n = mst.map.get(i);
            n.loc = heap.size();
            heap.add(n);
        }
        // Sift down every internal node, from the last one back to the root.
        for (int i = heap.size() / 2 - 1; i >= 0; i--)
            heapify(heap.get(i));
    }

    /**
     * Restores heap order around n after its key has changed. A decreased key moves the node
     * up toward the root; an increased key moves it down.
     */
    void resetPos(Node n)
    {
        siftUp(n);
        heapify(n);
    }

    /** Moves n toward the root while its parent's key is strictly greater. */
    private void siftUp(Node n)
    {
        while (n.loc > 0)
        {
            // Integer parent index; avoids float rounding for large positions.
            Node parent = heap.get((n.loc - 1) / 2);
            if (parent.key <= n.key)
                break;
            swap(n, parent);
        }
    }

    /** Sifts n down, swapping it with its smaller child until no child has a smaller key. */
    void heapify(Node n)
    {
        int size = heap.size();
        // n.loc < size / 2 exactly when n has at least a left child.
        while (n.loc < size / 2)
        {
            int pos = n.loc;
            int left = 2 * pos + 1;
            int right = left + 1;
            int smallest = pos;
            if (right < size && heap.get(right).key < n.key)
                smallest = right;
            if (heap.get(left).key < heap.get(smallest).key)
                smallest = left;
            if (smallest == pos)
                break;
            swap(heap.get(smallest), n);
        }
    }

    /** Exchanges the array slots of x and y and updates both nodes' loc fields. */
    void swap(Node x, Node y)
    {
        int xloc = x.loc;
        int yloc = y.loc;
        heap.set(xloc, y);
        heap.set(yloc, x);
        x.loc = yloc;
        y.loc = xloc;
    }

    /**
     * Removes and returns the node with the smallest key.
     * Throws IndexOutOfBoundsException (from ArrayList.get) if the heap is empty.
     */
    Node extractMin()
    {
        Node n = heap.get(0);
        delete(n);
        return n;
    }

    /**
     * Removes n from the heap. The last element is moved into n's slot and then repositioned
     * in whichever direction is needed, so any node can be deleted, not just the root.
     * Afterwards n.loc is -1.
     */
    void delete(Node n)
    {
        int last = heap.size() - 1;
        Node tail = heap.get(last);
        swap(n, tail);
        heap.remove(last);
        n.loc = -1;
        if (tail != n)
            resetPos(tail);
    }
}
/** One adjacency-list entry: the vertex at the other end of an edge, plus that edge's weight. */
class Pair
{
    Node vertex; // far endpoint of the edge
    int weight;  // edge weight

    Pair(Node neighbour, int edgeWeight)
    {
        this.vertex = neighbour;
        this.weight = edgeWeight;
    }
}
/** A graph vertex, plus the per-vertex state that Prim's algorithm and the heap maintain. */
class Node
{
    int vertex;      // vertex id
    int key;         // cheapest edge weight found so far linking this vertex to the tree
    int loc;         // current index in the heap array; -1 when not in the heap
    boolean visited; // true once the vertex has been taken into the tree

    Node(int id)
    {
        vertex = id;
        key = Integer.MAX_VALUE; // "no connecting edge found yet"
        loc = -1;
        visited = false;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment