Skip to content

Instantly share code, notes, and snippets.

@FrancisMurillo
Created August 8, 2020 06:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save FrancisMurillo/d09050ec122a77d9cc91a3a145381e97 to your computer and use it in GitHub Desktop.
Save FrancisMurillo/d09050ec122a77d9cc91a3a145381e97 to your computer and use it in GitHub Desktop.
Rust Median of Medians
#[cfg(test)]
#[macro_use(quickcheck)]
extern crate quickcheck_macros;
/// Reference selection: removes and returns the element that would be at position
/// `index` if `items` were sorted.
///
/// The whole vector is sorted in place first, so on return `items` is sorted and one
/// element shorter.
///
/// # Panics
///
/// Panics if `index >= items.len()` (in particular when `items` is empty).
pub fn median<I: Ord + std::fmt::Debug>(items: &mut Vec<I>, index: usize) -> I {
    assert!(index < items.len());
    <[I]>::sort_unstable(items);
    Vec::remove(items, index)
}
/// Removes and returns the element that would be at position `index` if `items` were
/// sorted, using median-of-medians selection.
///
/// The input is split into groups of at most five. Each group's median is taken, and
/// the median of those medians (found recursively) becomes the partition pivot. The
/// search then continues among the elements below the pivot, equal to it, or above it.
///
/// On return `items` holds every other element (its length shrinks by one), in an
/// unspecified order.
///
/// # Panics
///
/// Panics if `items` is empty or `index >= items.len()`.
pub fn median_of_medians<I: Ord + std::fmt::Debug>(items: &mut Vec<I>, mut index: usize) -> I {
    assert!(!items.is_empty());
    assert!(index < items.len());
    let items_len = items.len();
    if items_len <= 5 {
        items.sort_unstable();
        return items.remove(index);
    }

    // One group per five elements; the last group may be shorter.
    let group_count = (items_len + 4) / 5;
    let mut subgroups: Vec<Vec<I>> = Vec::with_capacity(group_count);
    let mut medians: Vec<I> = Vec::with_capacity(group_count);
    {
        // Move the elements out through a drain that lives only inside this block.
        // Once the block ends the drain is dropped, so `items` (now empty, capacity
        // kept) can be refilled below without holding two aliasing `&mut` references
        // to the same vector.
        let mut drained = items.drain(..);
        while let Some(first) = drained.next() {
            let mut subgroup = Vec::with_capacity(5);
            subgroup.push(first);
            subgroup.extend(drained.by_ref().take(4));
            subgroup.sort_unstable();
            // Pull the group's median out; the rest of the group stays in `subgroup`.
            medians.push(subgroup.remove(subgroup.len() >> 1));
            subgroups.push(subgroup);
        }
    }

    let medians_len = medians.len();
    // The pivot is removed from `medians`; it is accounted for separately below.
    let pivot: I = if medians_len <= 5 {
        medians.sort_unstable();
        medians.remove(medians_len >> 1)
    } else {
        median_of_medians(&mut medians, medians_len >> 1)
    };

    let mut lower = Vec::with_capacity(items_len);
    let mut pivots = Vec::with_capacity(items_len);
    let mut upper = Vec::with_capacity(items_len);
    for item in subgroups.into_iter().flatten().chain(medians) {
        match item.cmp(&pivot) {
            std::cmp::Ordering::Less => lower.push(item),
            std::cmp::Ordering::Greater => upper.push(item),
            std::cmp::Ordering::Equal => pivots.push(item),
        }
    }

    // `pivots` does not contain `pivot` itself, hence the `+ 1` below. Whenever the
    // answer comes from another bucket, `pivot` is put back so no element is lost.
    let selected = if index < lower.len() {
        pivots.push(pivot);
        median_of_medians(&mut lower, index)
    } else {
        index -= lower.len();
        if index < pivots.len() + 1 {
            pivot
        } else {
            index -= pivots.len() + 1;
            pivots.push(pivot);
            median_of_medians(&mut upper, index)
        }
    };

    items.append(&mut lower);
    items.append(&mut pivots);
    items.append(&mut upper);
    selected
}
/// Compares the elements referenced by `indices[left]` and `indices[right]` and swaps
/// the two positions if they are out of order. Afterwards
/// `items[indices[left]] <= items[indices[right]]`; `items` itself is only read.
fn swap_less<I: Ord>(items: &[I], indices: &mut [usize], left: usize, right: usize) {
    if items[indices[left]] > items[indices[right]] {
        indices.swap(left, right);
    }
}

/// Orders the two positions `indices[index..index + 2]` by the values they reference.
fn sort_two_at<I: Ord>(items: &[I], indices: &mut [usize], index: usize) {
    swap_less(items, indices, index, index + 1);
}

/// Orders `indices[index..index + 3]` by referenced value. The first comparisons bring
/// the smallest value to `index`, then the remaining two positions are ordered.
fn sort_three_at<I: Ord>(items: &[I], indices: &mut [usize], index: usize) {
    swap_less(items, indices, index, index + 1);
    swap_less(items, indices, index, index + 2);
    sort_two_at(items, indices, index + 1);
}

/// Orders `indices[index..index + 4]` by referenced value. The smallest value is
/// brought to `index`, then the remaining three positions are ordered.
fn sort_four_at<I: Ord>(items: &[I], indices: &mut [usize], index: usize) {
    swap_less(items, indices, index, index + 1);
    swap_less(items, indices, index, index + 2);
    swap_less(items, indices, index, index + 3);
    sort_three_at(items, indices, index + 1);
}

/// Orders `indices[index..index + 5]` by referenced value. The smallest value is
/// brought to `index`, then the remaining four positions are ordered.
fn sort_five_at<I: Ord>(items: &[I], indices: &mut [usize], index: usize) {
    swap_less(items, indices, index, index + 1);
    swap_less(items, indices, index, index + 2);
    swap_less(items, indices, index, index + 3);
    swap_less(items, indices, index, index + 4);
    sort_four_at(items, indices, index + 1);
}

/// Orders the window of at most five positions starting at `indices[index]` by the
/// values they reference in `items`. Returns the window length,
/// `min(5, indices.len() - index)`. Only `indices` is reordered.
///
/// # Panics
///
/// Panics if `index >= indices.len()`.
fn sort_at<I: Ord>(items: &[I], indices: &mut [usize], index: usize) -> usize {
    // Check against `indices`, the slice actually being windowed. `items` can be far
    // longer than `indices` (callers sort subsets of positions), so checking
    // `items.len()` would not stop `indices.len() - index` from underflowing.
    assert!(index < indices.len());
    let window = (indices.len() - index).min(5);
    match window {
        2 => sort_two_at(items, indices, index),
        3 => sort_three_at(items, indices, index),
        4 => sort_four_at(items, indices, index),
        5 => sort_five_at(items, indices, index),
        // A single position is already in order.
        _ => {}
    }
    window
}
/// Removes and returns the element that would be at position `index` if `items` were
/// sorted.
///
/// Selection runs over a vector of positions into `items` (see `indexed_median`), so
/// the elements are only read during the search. A single final `Vec::remove` takes
/// the chosen element out, and the other elements keep their relative order.
///
/// # Panics
///
/// Panics if `items` is empty or `index >= items.len()`.
pub fn median_of_medians_faster<I: Ord + std::fmt::Debug>(items: &mut Vec<I>, index: usize) -> I {
    assert!(!items.is_empty());
    assert!(index < items.len());
    let mut positions: Vec<usize> = (0..items.len()).collect();
    let chosen = indexed_median(items, &mut positions, index);
    items.remove(chosen)
}
/// Returns the position in `items` of the element that would be `index`-th (0-based)
/// if only the elements referenced by `indices` were sorted.
///
/// `items` is only read. `indices` is reordered in place, because each run of five
/// positions is sorted by the values it references. The positions are grouped in
/// fives and each group's median is taken. The median of those medians is found
/// recursively and used as the pivot. The search then continues among the positions
/// whose values are below, equal to, or above the pivot's value.
///
/// Expects `index < indices.len()`; otherwise an index operation panics.
fn indexed_median<I: Ord + std::fmt::Debug>(
    items: &mut Vec<I>,
    indices: &mut Vec<usize>,
    index: usize,
) -> usize {
    let count = indices.len();
    if count <= 5 {
        sort_at(items, indices, 0);
        return indices[index];
    }

    // Sort each run of five positions and remember the position holding its median.
    let mut group_medians = Vec::with_capacity(count >> 2);
    let mut start = 0;
    while start < count {
        let width = sort_at(items, indices, start);
        group_medians.push(indices[start + (width >> 1)]);
        start += 5;
    }

    let medians_count = group_medians.len();
    let pivot_index = if medians_count > 5 {
        indexed_median(items, &mut group_medians, medians_count >> 1)
    } else {
        sort_at(items, &mut group_medians, 0);
        group_medians[medians_count >> 1]
    };

    // Bucket every candidate position by how its value compares with the pivot's.
    let mut below = Vec::new();
    let mut equal = Vec::new();
    let mut above = Vec::new();
    for &position in indices.iter() {
        if items[position] < items[pivot_index] {
            below.push(position);
        } else if items[position] > items[pivot_index] {
            above.push(position);
        } else {
            equal.push(position);
        }
    }

    if index < below.len() {
        return indexed_median(items, &mut below, index);
    }
    let index = index - below.len();
    if index < equal.len() {
        // Every position in `equal` references a value equal to the pivot's.
        return pivot_index;
    }
    indexed_median(items, &mut above, index - equal.len())
}
#[cfg(test)]
mod tests {
    use quickcheck::TestResult;
    use super::*;

    /// Hand-picked cases for the sort-based reference implementation.
    #[test]
    fn median_works() {
        let mut xs = vec![1, 2, 3, 4, 5, 1000, 8, 9, 99];
        let mut ys = vec![1, 2, 3, 4, 5, 6];
        assert_eq!(1, median(&mut xs.clone(), 0));
        assert_eq!(99, median(&mut xs, 7));
        assert_eq!(5, median(&mut ys, 4));
    }

    /// The same hand-picked cases for the element-moving median-of-medians.
    #[test]
    fn median_of_medians_works() {
        let mut xs = vec![1, 2, 3, 4, 5, 1000, 8, 9, 99];
        let mut ys = vec![1, 2, 3, 4, 5, 6];
        assert_eq!(1, median_of_medians(&mut xs.clone(), 0));
        assert_eq!(99, median_of_medians(&mut xs, 7));
        assert_eq!(5, median_of_medians(&mut ys, 4));
    }

    /// The same hand-picked cases for the index-based median-of-medians.
    #[test]
    fn median_of_medians_faster_works() {
        let mut xs = vec![1, 2, 3, 4, 5, 1000, 8, 9, 99];
        let mut ys = vec![1, 2, 3, 4, 5, 6];
        assert_eq!(1, median_of_medians_faster(&mut xs.clone(), 0));
        assert_eq!(99, median_of_medians_faster(&mut xs, 7));
        assert_eq!(5, median_of_medians_faster(&mut ys, 4));
    }

    /// A deterministic input with more than 25 elements, so the median of the group
    /// medians is itself found recursively. The values contain many duplicates.
    #[test]
    fn implementations_agree_on_larger_input() {
        let data: Vec<usize> = (0..200).map(|i| (i * 7919 + 13) % 53).collect();
        for &index in &[0, 1, 57, 100, 150, 199] {
            let expected = median(&mut data.clone(), index);
            assert_eq!(expected, median_of_medians(&mut data.clone(), index));
            assert_eq!(expected, median_of_medians_faster(&mut data.clone(), index));
        }
    }

    /// Checks `select` against the sort-based `median` on one input. Both must return
    /// the same element and leave the same multiset of elements behind.
    fn agrees_with_median(
        mut xs: Vec<usize>,
        index: usize,
        select: fn(&mut Vec<usize>, usize) -> usize,
    ) -> TestResult {
        if xs.is_empty() {
            return TestResult::discard();
        }
        // Fold the random index into range instead of discarding out-of-range cases,
        // so every non-empty vector that quickcheck generates is exercised.
        let index = index % xs.len();
        let mut ys = xs.clone();
        if median(&mut ys, index) != select(&mut xs, index) {
            return TestResult::failed();
        }
        xs.sort();
        ys.sort();
        TestResult::from_bool(xs == ys)
    }

    #[quickcheck]
    fn should_work(xs: Vec<usize>, index: usize) -> TestResult {
        agrees_with_median(xs, index, median_of_medians)
    }

    #[quickcheck]
    fn faster_should_work(xs: Vec<usize>, index: usize) -> TestResult {
        agrees_with_median(xs, index, median_of_medians_faster)
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment