Skip to content

Instantly share code, notes, and snippets.

@cire3791
Created December 8, 2012 10:28
Show Gist options
  • Save cire3791/4239718 to your computer and use it in GitHub Desktop.
Save cire3791/4239718 to your computer and use it in GitHub Desktop.
Merge sort on a (doubly) linked list
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <queue>
#include <chrono>
// merge sort example

// Forward declarations. Both types are defined with the `struct` keyword
// further down, so the same class-key is used here to avoid mismatched-tag
// warnings.
struct list_info ;
struct node ;

// Common signature of the two sort entry points: takes the head of a list
// and returns the head of the sorted list.
typedef node* (*sort_func)(node*) ;

node* mergesort(node*) ; // iterative, bottom up
node* mergesort_recursive(node*) ; // recursive, top down.
node* merge_sorted(node*, node*) ; // merges two sorted lists. Used by both versions of the sort.
// Recursively splits the list; the second argument is the number of nodes in
// the list. (The prototype now matches the two-parameter definition; the
// previous one-parameter prototype declared an overload that was never defined.)
node* split( list_info, unsigned ) ;

// utility functions
void print(node*) ;
node* make_list( unsigned) ;
void destroy(node*) ;
void do_sort(sort_func) ;
void check_sorted(node*) ;
// Seeds the C random generator from the current time, echoes the seed so a
// run can be reproduced, then times both merge sort variants.
int main()
{
    const std::time_t seed = std::time(nullptr) ;
    std::cout << seed << '\n' ;
    std::srand(static_cast<unsigned>(seed)) ;

    do_sort(mergesort) ;
    do_sort(mergesort_recursive) ;
}
// One element of a doubly linked list of ints.
struct node
{
    int value ;
    node* next ;
    node* prev ;

    // A freshly constructed node is not linked to anything.
    node(int v) : value(v), next(nullptr), prev(nullptr) {}

    // Cuts the list right after this node.
    // Returns the node that used to follow this one (now the head of the
    // detached remainder), or nullptr if this node was already last.
    node* truncate()
    {
        node* const detached = next ;
        if ( detached == nullptr )
            return nullptr ;

        next = nullptr ;
        detached->prev = nullptr ;
        return detached ;
    }

    // Links this node into the list next to `n`:
    // directly after `n` by default, directly before `n` when `before` is true.
    // `n` must not be null. This node's previous links are overwritten.
    void insert(node* n, bool before=false)
    {
        if ( !before )
        {
            node* const successor = n->next ;
            if ( successor != nullptr )
                successor->prev = this ;
            prev = n ;
            next = successor ;
            n->next = this ;
            return ;
        }

        node* const predecessor = n->prev ;
        if ( predecessor != nullptr )
            predecessor->next = this ;
        next = n ;
        prev = predecessor ;
        n->prev = this ;
    }
};
struct list_info
{
node * head, *tail ;
void append(node*n)
{
n->insert(tail) ;
tail = n ;
}
void append_list(node*n)
{
tail->next = n ;
n->prev = tail ;
while ( tail->next )
tail = tail->next ;
}
};
// Combines two sorted lists into one sorted list by relinking their nodes.
//
// a, b: heads of two lists, each sorted in non-decreasing order. Either may
//       be null (empty list), in which case the other list is returned as is.
// Returns the head of the merged list. Its `prev` is null and every
// next/prev link in the result is consistent.
//
// When values compare equal, the node from `a` is taken first, so equal
// elements of `a` precede equal elements of `b` in the result.
node* merge_sorted( node* a, node* b)
{
    if ( !a )
        return b ;
    if ( !b )
        return a ;

    // Pick the first node of the result.
    node* head ;
    if ( b->value < a->value )
    {
        head = b ;
        b = b->next ;
    }
    else
    {
        head = a ;
        a = a->next ;
    }
    head->prev = nullptr ;

    // Repeatedly move the smaller front node to the end of the result.
    node* tail = head ;
    while ( a && b )
    {
        node* picked ;
        if ( b->value < a->value )
        {
            picked = b ;
            b = b->next ;
        }
        else
        {
            picked = a ;
            a = a->next ;
        }
        tail->next = picked ;
        picked->prev = tail ;
        tail = picked ;
    }

    // One input is exhausted; the other is already sorted and internally
    // linked, so it is attached in one step without walking it.
    node* rest = a ? a : b ;
    tail->next = rest ;
    if ( rest )
        rest->prev = tail ;

    return head ;
}
// the sort driver
//
// Iterative, bottom-up merge sort.
// list: head of the list to sort; may be null (empty list).
// Returns the head of the sorted list (null for an empty list).
//
// The list is first cut into maximal non-decreasing runs, which takes
// advantage of already ordered sequences to reduce the number of merges.
// The runs are then merged pairwise through a FIFO queue until one remains.
node* mergesort(node* list)
{
    // Without this guard the queue would be empty and front() would be
    // called on it, which is undefined behaviour.
    if ( !list )
        return nullptr ;

    std::queue<node*> runs ;
    node* run_head = list ;
    while ( run_head )
    {
        runs.push(run_head) ;

        // Advance to the last node of the current non-decreasing run.
        node* last = run_head ;
        while ( last->next && last->value <= last->next->value )
            last = last->next ;

        // Cut after the run; truncate() hands back the start of the next
        // run, or nullptr when the end of the list was reached.
        run_head = last->truncate() ;
    }

    // A single run means the list was already entirely sorted.
    while ( runs.size() > 1 )
    {
        node* a = runs.front() ; runs.pop() ;
        node* b = runs.front() ; runs.pop() ;
        runs.push( merge_sorted(a, b) ) ;
    }
    return runs.front() ;
}
// Recursive step of the top-down merge sort.
//
// l:     head and tail of the (sub)list to sort.
// nodes: number of nodes in that list.
// Returns the head of the sorted (sub)list; its `prev` is null.
node* split( list_info l, unsigned nodes )
{
    // Zero or one node: nothing to do.
    if ( nodes <= 1 )
        return l.head ;

    // don't split lists all the way down to one element if it can be avoided
    if ( nodes == 2 )
    {
        // Already in order (equal values count as ordered, so they are not
        // swapped needlessly).
        if ( !(l.tail->value < l.head->value) )
            return l.head ;

        // Swap the two nodes, rewriting all four links so that the new head
        // does not keep a stale `prev` pointing at the old head.
        node* first = l.tail ;
        node* second = l.head ;
        first->prev = nullptr ;
        first->next = second ;
        second->prev = first ;
        second->next = nullptr ;
        return first ;
    }

    // set up the lists to be split.
    const unsigned left_count = nodes/2 ;
    node* middle = l.head ;
    for ( unsigned i=0; i<left_count; ++i )
        middle = middle->next ;

    list_info right ;
    right.tail = l.tail ;
    l.tail = middle->prev ;
    right.head = l.tail->truncate() ; // cut between the halves; returns `middle`

    return merge_sorted( split(l, left_count), split(right, nodes-left_count) ) ;
}
// Recursive, top-down merge sort.
// list: head of the list to sort; may be null (empty list).
// Returns the head of the sorted list (null for an empty list).
node* mergesort_recursive( node * list )
{
    // An empty list is already sorted; the counting loop below would
    // otherwise dereference a null pointer.
    if ( !list )
        return nullptr ;

    // Find the tail and count the nodes; split() needs both.
    list_info l ;
    l.head = list ;
    l.tail = list ;
    unsigned nodes = 1 ;
    while ( l.tail->next )
    {
        l.tail = l.tail->next ;
        ++nodes ;
    }
    return split(l, nodes) ;
}
// Builds a random list, sorts it with `sort`, reports whether the result is
// sorted and how long the sort call took, then frees the list.
void do_sort( sort_func sort )
{
    const unsigned list_size = 1000000 ;
    node * list = make_list(list_size) ;

    // steady_clock is monotonic, which is what an elapsed-time measurement
    // needs; now() is a static member, so no clock object is required.
    const auto start = std::chrono::steady_clock::now() ;
    list = sort(list) ;
    const auto end = std::chrono::steady_clock::now() ;

    check_sorted(list) ;
    std::cout << std::chrono::duration_cast<std::chrono::milliseconds>(end-start).count() << "ms sorting the list.\n" ;
    destroy(list) ;
}
// Allocates a doubly linked list of `size` nodes whose values are
// std::rand()%5000, i.e. in the range [0, 4999].
// Returns the head of the list, or nullptr when `size` is 0.
// The nodes are allocated with `new`; release them with destroy().
node* make_list(unsigned size)
{
    // Previously size == 0 still allocated one node and then `size-1`
    // wrapped around, making the loop run about four billion times.
    if ( size == 0 )
        return nullptr ;

    node * head = new node(std::rand()%5000) ;
    node * tail = head ;
    for ( unsigned i=1; i<size; ++i )
    {
        node * fresh = new node(std::rand()%5000) ;
        fresh->insert(tail) ; // link after the current tail
        tail = fresh ;
    }
    return head ;
}
// Deletes every node from `list` to the end of the list by following `next`.
// A null `list` is a no-op.
void destroy(node* list)
{
    while ( list )
    {
        node * doomed = list ;
        list = list->next ; // step forward before the node is freed
        delete doomed ;
    }
}
// Prints whether the list starting at `list` is in non-decreasing order.
// Each node after the first is compared with the node its `prev` link points
// to. An empty (null) list is reported as sorted.
void check_sorted(node* list)
{
    if ( list )
    {
        for ( node* cur = list->next; cur; cur = cur->next )
        {
            if ( cur->value < cur->prev->value )
            {
                std::cout << "List is not sorted.\n" ;
                return ;
            }
        }
    }
    std::cout << "List is sorted.\n" ;
}
// Writes the values from `n` to the end of the list to std::cout, each
// followed by a space, and finishes with a newline.
void print(node* n)
{
    for ( node* cur = n; cur != nullptr; cur = cur->next )
        std::cout << cur->value << ' ' ;
    std::cout << '\n' ;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment