@zachwhaley
Last active March 13, 2024 09:06
C++ quicksort with iterators
#include <algorithm>
#include <iostream>
#include <iterator>
#include <list>

// Prints every element of [beg, end) on one line.
template<typename Iter>
void print(const Iter& beg, const Iter& end)
{
    std::for_each(beg, end, [](const auto& i) { std::cout << i << " "; });
    std::cout << std::endl;
}

// Scans [beg, end) from both ends, swapping out-of-place elements,
// and returns the iterator where the two scans meet.
template<typename Iter>
Iter part(Iter beg, Iter end, const typename std::iterator_traits<Iter>::value_type& pivot)
{
    Iter head = beg;
    Iter tail = std::prev(end);
    while (head != tail) {
        // skip elements on the left that are less than the pivot
        while (*head < pivot) {
            if (++head == tail) {
                return head;
            }
        }
        // skip elements on the right that are greater than or equal to the pivot
        while (*tail >= pivot) {
            if (--tail == head) {
                return head;
            }
        }
        std::cout << "swap " << *head << " ↔ " << *tail << std::endl;
        std::iter_swap(head, tail);
        std::cout << " ⇒ "; print(beg, end);
        if (++head == tail--) {
            return head;
        }
    }
    return head;
}

template<typename Iter>
void quick_sort(Iter beg, Iter end)
{
    std::cout << "sort "; print(beg, end);
    if (beg == end) {
        return;
    }
    auto pivot = *beg;
    Iter split = part(beg, end, pivot);
    // sort left
    quick_sort(beg, split);
    // sort right
    Iter new_middle = beg;
    quick_sort(++new_middle, end);
}

int main()
{
    std::list<int> l = { 3, 2, 1, 5, 2, 12, 4356, 1, 5, 12 };
    quick_sort(l.begin(), l.end());
    std::cout << "done "; print(l.begin(), l.end());
    return 0;
}
@andrascii

andrascii commented Jun 17, 2019

First of all, your partition algorithm is wrong.
You take the elements on the left that are less than the middle, but the elements on the right that are strictly greater than the middle.
On the left you need to take those elements that are greater than or equal to it.
Then you'll get std::partition.
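
For reference, here is a minimal sketch of what std::partition does with a "less than the pivot" predicate. The helper name partition_below is hypothetical and only there for illustration; the predicate form is an assumption about how the split would be expressed, not the code from this gist.

```cpp
#include <algorithm>

// Hypothetical helper: moves every element less than `pivot` to the front
// of [beg, end) and returns the start of the remaining elements.
template<typename Iter, typename T>
Iter partition_below(Iter beg, Iter end, const T& pivot)
{
    // std::partition puts the elements that satisfy the predicate first
    // and returns an iterator to the first element that does not.
    return std::partition(beg, end,
                          [&pivot](const auto& v) { return v < pivot; });
}
```

After the call, everything in [beg, split) is less than the pivot and everything in [split, end) is greater than or equal to it. The relative order inside each group is not guaranteed.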

But that's not all. The quick sort is still wrong,
because you divide the range into two parts:

First: [beg, split)
Second: [split, end)

But when you recursively call quick_sort on the first range, that range can actually stay unchanged.

Why? Consider this range:

3 2 1 1 2

This range can't be reordered, because you only reorder the elements that are less than the middle.
In this case you'll get infinite recursion.
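
The situation above can be reproduced with a small check. This sketch assumes that "the middle" means the middle element is used as the pivot (the value 1 for 3 2 1 1 2) and that the split comes from a "less than the pivot" partition; the function name is hypothetical.

```cpp
#include <algorithm>
#include <iterator>

// Hypothetical check: returns true when partitioning [beg, end) by
// "less than the middle element" puts nothing in [beg, split), so that
// [split, end) is the whole input again. Returns false for an empty range.
template<typename Iter>
bool second_range_does_not_shrink(Iter beg, Iter end)
{
    if (beg == end) {
        return false;
    }
    // copy the pivot value, since partitioning moves elements around
    const auto pivot = *std::next(beg, std::distance(beg, end) / 2);
    Iter split = std::partition(beg, end,
                                [&pivot](const auto& v) { return v < pivot; });
    return split == beg;
}
```

For 3 2 1 1 2 no element is less than 1, so split equals beg and recursing on [split, end) would sort the same range again without making progress.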

To fix this, you should handle a new middle like this:

// sort left
quick_sort(beg, split);
// sort right
Iter new_middle = beg;
quick_sort(++new_middle, end);

This way the second range is always smaller than the current one, even if the first recursive call didn't reorder the range.
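
Another way to guarantee that both ranges shrink is to split the range into three parts with two std::partition calls. The sketch below is not the code behind the rextester links (that code is not shown here); quick_sort_sketch and its Compare parameter are hypothetical names, and it assumes C++14 for std::less<> and generic lambdas.

```cpp
#include <algorithm>
#include <functional>
#include <iterator>

// Hypothetical sketch: quicksort built from two std::partition calls.
// Works with forward iterators, so std::list is fine.
template<typename Iter, typename Compare = std::less<>>
void quick_sort_sketch(Iter beg, Iter end, Compare cmp = Compare{})
{
    // ranges with zero or one element are already sorted
    if (beg == end || std::next(beg) == end) {
        return;
    }
    // copy the pivot value, since partitioning moves elements around
    const auto pivot = *std::next(beg, std::distance(beg, end) / 2);
    // [beg, lower): elements ordered before the pivot
    Iter lower = std::partition(beg, end,
                                [&](const auto& v) { return cmp(v, pivot); });
    // [lower, upper): elements equivalent to the pivot (at least the pivot itself)
    Iter upper = std::partition(lower, end,
                                [&](const auto& v) { return !cmp(pivot, v); });
    quick_sort_sketch(beg, lower, cmp);
    quick_sort_sketch(upper, end, cmp);
}
```

Because the middle part always contains at least the pivot, neither recursive call can receive the whole input again. Passing std::greater<int>() as the third argument sorts in descending order.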

@andrascii

Take a look at this code: https://rextester.com/CGVRM70077

@zachwhaley
Author

Thanks for the code review!

Do you mind if I use the code from your link to update my own code?

@andrascii

Sure! Feel free to use it!

@andrascii

You can also optimize the recursion count like this: https://rextester.com/AEB21840

@pedroedr

pedroedr commented Jan 29, 2020

Hi, I tried your implementation with the following inputs and it did not work:
(A) 1 12 5 26 7 14 3 7 2; it produced: 1 2 3 7 5 7 12 14 26
(B) 3 42 1 5 32 37 4356 41 135 92; it produced: 1 5 3 32 37 41 42 92 135 4356
(C) 3 2 1 5 2 12 4356 1 5 12, with a mid-point pivot (auto pivot = beg; std::advance(pivot, std::distance(beg, end)/2);); it produced: 3 1 1 2 2 5 5 12 12 4356
(D) 3 42 1 5 32 37 4356 41 135 92, with a mid-point pivot (auto pivot = beg; std::advance(pivot, std::distance(beg, end)/2);); it produced: 3 1 5 32 37 41 42 92 135 4356
(E) Also, due to the weak/strong asymmetry in the comparisons done in part, it cannot handle std::greater ordering.
(F) Finally, this is inefficient, as it ends up making several redundant comparisons:

Iter new_middle = beg;
quick_sort(++new_middle, end);
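
Results like (A) through (D) can be checked mechanically by comparing against std::sort. A minimal sketch, where is_correct_sort is a hypothetical helper name:

```cpp
#include <algorithm>
#include <vector>

// Hypothetical helper: returns true when `output` is exactly the
// ascending sorted form of `input`. `input` is taken by value so the
// caller's data is left untouched.
bool is_correct_sort(std::vector<int> input, const std::vector<int>& output)
{
    std::sort(input.begin(), input.end());
    return input == output;
}
```

For case (A), is_correct_sort({1, 12, 5, 26, 7, 14, 3, 7, 2}, {1, 2, 3, 7, 5, 7, 12, 14, 26}) returns false, because 7 appears before 5 in the reported output.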

This is what I came up with and it seems to work:

#include <algorithm>
#include <functional>
#include <iostream>
#include <iterator>

// Requires random-access iterators: the loops below compare iterators with <= and <.
template<typename It, template<typename> class Cmp = std::less>
void mysort(It beg, It end) {
    if (beg == end) return; // std::prev(end) is not valid on an empty range
    It head = beg, bend = std::prev(end), tail = bend, pivot = beg;
    std::advance(pivot, std::distance(head, tail)/2); //Comment out if it is preferred to have the pivot at 'beg'
    using val_type = typename std::iterator_traits<It>::value_type;
    val_type pivot_val = *pivot;

    //std::cout << "sort with pivot [" << pivot_val << "]:         "; print(beg, end);
    while (head <= tail) {
        while (Cmp<val_type>()(*head, pivot_val)) ++head;
        while (Cmp<val_type>()(pivot_val, *tail)) --tail;
        if (head <= tail) {
            std::cout << "swap(h,t)  " << *head << " ↔ " << *tail << std::endl;
            std::iter_swap(head, tail);
            ++head;
            --tail;
        }
    }
    //std::cout << "post partition - head[" << *head << "], tail[" << *tail << "]:  "; print(beg, end);

    if (beg < tail) {
        mysort<It, Cmp>(beg, std::next(tail));
    }
    if (head < bend) { //Use 'before end' iterator to avoid unnecessary iteration with just the last element
        mysort<It, Cmp>(head, end);
    }
}
