Skip to content

Instantly share code, notes, and snippets.

@elbeno
Last active November 25, 2021 03:04
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save elbeno/8dbfc60722601ee761aa to your computer and use it in GitHub Desktop.
Save elbeno/8dbfc60722601ee761aa to your computer and use it in GitHub Desktop.
A sketch of inplace_merge
#include <cassert>
#include <cstddef>
#include <iterator>
#include <memory>
#include <new>
#include <utility>
// Merges the consecutive sorted ranges [first, middle) and [middle, last)
// into one sorted range [first, last), comparing elements with operator<.
//
// Equal elements keep their relative order (values from the left range come
// first). Elements are only moved or swapped, never copied, and T does not
// need to be default constructible. A temporary buffer able to hold
// min(distance(first, middle), distance(middle, last)) values is allocated.
//
// Precondition: both input ranges are sorted with respect to operator<.
template <typename ForwardIt>
void my_inplace_merge(ForwardIt first, ForwardIt middle, ForwardIt last)
{
  if (first == middle || middle == last) return;

  using T = typename std::iterator_traits<ForwardIt>::value_type;

  // Owns uninitialised, correctly aligned storage for `capacity` objects of
  // type T. Slots [begin, live_end) contain constructed (possibly moved-from)
  // objects; the destructor destroys exactly those and releases the storage,
  // so nothing leaks if a move or a comparison throws.
  struct RingStorage
  {
    std::allocator<T> alloc;
    std::size_t capacity;
    T* begin;
    T* live_end;

    explicit RingStorage(std::size_t cap)
      : capacity(cap), begin(alloc.allocate(cap)), live_end(begin)
    {
    }
    RingStorage(const RingStorage&) = delete;
    RingStorage& operator=(const RingStorage&) = delete;
    ~RingStorage()
    {
      while (live_end != begin) (--live_end)->~T();
      alloc.deallocate(begin, capacity);
    }
  };

  // call the range [first, middle) x, and the range [middle, last) y
  const auto d1 = std::distance(first, middle);
  const auto d2 = std::distance(middle, last);

  // tmp only ever needs to hold as many values as the smaller range
  const auto n = static_cast<std::size_t>(d1 < d2 ? d1 : d2);
  RingStorage storage(n);

  // tmp is used as a ring buffer (a FIFO of displaced x values)
  T* const begint = storage.begin;
  T* const endt = begint + n;
  T* readt = begint;
  T* writet = begint;
  // readt == writet means "empty" unless this flag says the ring is full
  bool tmp_full = false;

  // we write to the range in place, and read from each range; out and readx
  // advance in lock step, so they always refer to the same position
  ForwardIt out = first;
  ForwardIt readx = first;
  ForwardIt ready = middle;

  // Parks *readx at the back of the ring. A slot is constructed the first
  // time the write cursor reaches it and move-assigned on every later lap.
  auto stash_x = [&]
  {
    if (writet == storage.live_end)
    {
      ::new (static_cast<void*>(writet)) T(std::move(*readx));
      ++storage.live_end;
    }
    else
    {
      *writet = std::move(*readx);
    }
    if (++writet == endt) writet = begint;
  };

  // loop until we reach the end of either x or y
  while (readx != middle && ready != last)
  {
    const bool tmp_empty = (readt == writet);
    // the smallest pending x value is at the front of tmp when tmp is not
    // empty, otherwise it is *readx itself; y only wins when strictly smaller
    const bool take_y = tmp_empty ? (*ready < *readx) : (*ready < *readt);

    if (take_y)
    {
      // x -> tmp, y -> out
      stash_x();
      *out = std::move(*ready);
      ++ready;
      tmp_full = (writet == readt);
    }
    else if (!tmp_empty)
    {
      // x -> tmp, tmp -> out (x must be parked before its slot is overwritten)
      stash_x();
      *out = std::move(*readt);
      if (++readt == endt) readt = begint;
    }
    // otherwise *readx is already where it belongs

    ++readx;
    ++out;

    // if the temporary buffer is full, we must have exhausted either x or y
    assert(!tmp_full || readx == middle || ready == last);
  }

  if (out == middle)
  {
    // any remaining values of x are now in tmp, so do a "normal" merge from
    // tmp and y to out; once tmp is drained the rest of y is already in place
    while (readt != writet || tmp_full)
    {
      if (ready != last && *ready < *readt)
      {
        *out = std::move(*ready);
        ++ready;
      }
      else
      {
        *out = std::move(*readt);
        if (++readt == endt) readt = begint;
        tmp_full = false;
      }
      ++out;
    }
  }
  else if (ready == last)
  {
    // all the values of y are accounted for, and we know that all the values
    // of x in tmp are less than the remaining ones, so swap tmp with x
    // repeatedly until we reach the middle, then output the remaining tmp
    using std::swap;
    T* p = readt;
    while (out != middle)
    {
      swap(*p, *out);
      ++p;
      ++out;
      if (p == endt) p = begint;
      if (p == writet) p = readt;
    }
    while (out != last)
    {
      *out = std::move(*p);
      ++p;
      ++out;
      if (p == endt) p = begint;
      if (p == writet) p = readt;
    }
  }
}
#include <algorithm>
// Merges the consecutive sorted ranges [first, middle) and [middle, last)
// in place without any temporary buffer.
//
// Whenever the head of the right range is strictly smaller than the current
// left element, it is swapped into position and the displaced left element is
// rotated back to the front of what remains of the left range.
template <typename ForwardIt>
void naive_inplace_merge(ForwardIt first, ForwardIt middle, ForwardIt last)
{
  for (; first != middle && middle != last; ++first)
  {
    // left element is not greater: it is already in its final position
    if (!(*middle < *first)) continue;

    // put the smaller right-hand value at `first`; the old *first now sits
    // in the slot the right range just gave up
    const ForwardIt displaced = middle;
    std::iter_swap(displaced, first);

    ForwardIt rest_of_left = first;
    ++rest_of_left;
    ++middle;

    // bring the displaced value back in front of the remaining left values
    std::rotate(rest_of_left, displaced, middle);
  }
}
// Merges the consecutive sorted ranges [first, middle) and [middle, last)
// into one sorted range [first, last), comparing elements with operator<.
//
// The smaller of the two ranges is moved into a temporary buffer and then
// merged back: forwards when the left range was buffered, backwards when the
// right range was buffered. Equal elements keep their relative order (values
// from the left range come first). Elements are moved, never copied, and T
// does not need to be default constructible.
//
// Precondition: both input ranges are sorted with respect to operator<.
template <typename BidirIt>
void naive_inplace_merge2(BidirIt first, BidirIt middle, BidirIt last)
{
  // nothing to merge if either side is empty
  if (first == middle || middle == last) return;

  using T = typename std::iterator_traits<BidirIt>::value_type;

  // Owns uninitialised, correctly aligned storage for `capacity` objects of
  // type T. Slots [begin, live_end) contain constructed objects; the
  // destructor destroys exactly those and releases the storage.
  struct Scratch
  {
    std::allocator<T> alloc;
    std::size_t capacity;
    T* begin;
    T* live_end;

    explicit Scratch(std::size_t cap)
      : capacity(cap), begin(alloc.allocate(cap)), live_end(begin)
    {
    }
    Scratch(const Scratch&) = delete;
    Scratch& operator=(const Scratch&) = delete;
    ~Scratch()
    {
      while (live_end != begin) (--live_end)->~T();
      alloc.deallocate(begin, capacity);
    }
  };

  const auto d1 = std::distance(first, middle);
  const auto d2 = std::distance(middle, last);
  const auto n = static_cast<std::size_t>(d1 < d2 ? d1 : d2);

  Scratch tmp(n);
  T* const begint = tmp.begin;
  T* const endt = begint + n;

  if (d1 <= d2)
  {
    // buffer the left range (move-constructing into the raw storage) ...
    std::uninitialized_copy(std::make_move_iterator(first),
                            std::make_move_iterator(middle), begint);
    tmp.live_end = endt;

    // ... and merge forwards. The write position can never overtake the read
    // position in the right range, and once the buffer is drained the rest
    // of the right range is already in place.
    T* x = begint;
    BidirIt y = middle;
    BidirIt out = first;
    while (x != endt)
    {
      if (y != last && *y < *x)
      {
        *out = std::move(*y);
        ++y;
      }
      else
      {
        *out = std::move(*x);
        ++x;
      }
      ++out;
    }
  }
  else
  {
    // buffer the right range ...
    std::uninitialized_copy(std::make_move_iterator(middle),
                            std::make_move_iterator(last), begint);
    tmp.live_end = endt;

    // ... and merge backwards from the end. The left value is taken only
    // when it is strictly greater, so equal right values land after it.
    // Once the buffer is drained the rest of the left range is in place.
    BidirIt x_end = middle;  // one past the last unplaced left value
    T* y_end = endt;         // one past the last unplaced buffered value
    BidirIt out = last;
    while (y_end != begint)
    {
      --out;
      if (x_end != first)
      {
        BidirIt x_back = x_end;
        --x_back;
        if (*(y_end - 1) < *x_back)
        {
          *out = std::move(*x_back);
          x_end = x_back;
          continue;
        }
      }
      --y_end;
      *out = std::move(*y_end);
    }
  }
}
#include <forward_list>
#include <iostream>
#include <vector>
using namespace std;
int main(void)
{
  // Demo driver: builds a list made of two sorted runs, merges them in place
  // with my_inplace_merge and prints the result.
  //
  // Other inputs that are useful to try by hand (vector<int> or
  // forward_list<int>):
  //   split at 5: {0,2,4,6,8,1,3,5,7,9}  {0,1,2,3,4,5,6,7,8,9}  {5,6,7,8,9,0,1,2,3,4}
  //   split at 3: {0,1,2,3,4,5,6,7,8,9}  {7,8,9,0,1,2,3,4,5,6}
  //   split at 7: {0,1,2,3,4,5,6,7,8,9}  {3,4,5,6,7,8,9,0,1,2}
  std::forward_list<int> values{3,4,5,6,7,8,9,0,1,2};
  const std::forward_list<int>::difference_type split = 7;

  // the second sorted run starts `split` elements in
  const auto run1_begin = values.begin();
  const auto run2_begin = std::next(run1_begin, split);
  const auto run2_end = values.end();

  my_inplace_merge(run1_begin, run2_begin, run2_end);

  for (const int value : values)
  {
    std::cout << value << ", ";
  }
  std::cout << std::endl;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment