Skip to content

Instantly share code, notes, and snippets.

@pyricau
Created May 8, 2012 20:09
Show Gist options
  • Save pyricau/2638920 to your computer and use it in GitHub Desktop.
Save pyricau/2638920 to your computer and use it in GitHub Desktop.
Implementations of various sorting algorithms in Java, verified by a parameterized JUnit test
import static java.util.Arrays.asList;
import static org.fest.assertions.Assertions.assertThat;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Random;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameters;
/**
 * Parameterized JUnit test that runs every {@link Sort} implementation declared in this class
 * against the same set of fixtures.
 *
 * <p>Fixtures are generated from an unseeded {@link Random}, so their exact values differ between
 * runs.
 */
@RunWith(Parameterized.class)
public class SortAlgoTest {

    /** A sorting algorithm over {@code int} arrays. */
    public interface Sort {
        /**
         * Sorts {@code input} in ascending order.
         *
         * @param input the values to sort
         * @return the sorted values (in-place implementations return {@code input} itself)
         */
        int[] sort(int[] input);
    }

    /** Base class for algorithms that rearrange the given array and return that same array. */
    public static abstract class InPlaceSort implements Sort {
        @Override
        public int[] sort(int[] input) {
            sortInPlace(input);
            return input;
        }

        /** Rearranges {@code input} into ascending order. */
        protected abstract void sortInPlace(int[] input);
    }

    /**
     * Bubble sort: swaps adjacent out-of-order pairs until a full pass makes no swap. Each pass
     * moves the largest remaining element to the end of the scanned range, so the range shrinks by
     * one per pass.
     */
    public static class BubbleSort extends InPlaceSort {
        @Override
        protected void sortInPlace(int[] input) {
            boolean swapped = true;
            for (int end = input.length; swapped && end > 1; end--) {
                swapped = false;
                for (int i = 1; i < end; i++) {
                    if (input[i] < input[i - 1]) {
                        swap(input, i, i - 1);
                        swapped = true;
                    }
                }
            }
        }
    }

    /** Selection sort: repeatedly swaps the smallest remaining element to the front. */
    public static class SelectionSort extends InPlaceSort {
        @Override
        protected void sortInPlace(int[] input) {
            selectionSort(input, 0, input.length - 1);
        }
    }

    /**
     * Insertion sort: grows a sorted prefix one element at a time, placing each new element at the
     * index chosen by {@link #findInsertionIndex}.
     */
    public static abstract class InsertionSort extends InPlaceSort {
        @Override
        protected void sortInPlace(int[] input) {
            for (int i = 1; i < input.length; i++) {
                int insertedItem = input[i];
                int insertionIndex = findInsertionIndex(input, insertedItem, 0, i - 1);
                insertItem(input, i, insertedItem, insertionIndex);
            }
        }

        /**
         * Shifts {@code input[insertionIndex..end-1]} one slot to the right (overwriting
         * {@code input[end]}) and writes {@code insertedItem} at {@code insertionIndex}.
         */
        private void insertItem(int[] input, int end, int insertedItem, int insertionIndex) {
            // System.arraycopy copies correctly when source and destination ranges overlap.
            System.arraycopy(input, insertionIndex, input, insertionIndex + 1, end - insertionIndex);
            input[insertionIndex] = insertedItem;
        }

        /**
         * Finds where {@code item} belongs in the sorted inclusive range {@code input[start..end]}.
         *
         * @return an index in {@code [start, end + 1]} at which inserting {@code item} keeps the
         *     range sorted
         */
        protected abstract int findInsertionIndex(int[] input, int item, int start, int end);
    }

    /** Insertion sort whose insertion point is found by a linear scan of the sorted prefix. */
    public static class SimpleInsertionSort extends InsertionSort {
        /** Returns the first index in {@code [start, end]} holding a value greater than {@code item}, else {@code end + 1}. */
        @Override
        protected int findInsertionIndex(int[] input, int item, int start, int end) {
            for (int j = start; j <= end; j++) {
                if (item < input[j]) {
                    return j;
                }
            }
            return end + 1;
        }
    }

    /** Insertion sort whose insertion point is found by a recursive binary search. */
    public static class InsertionSortWithRecursiveBinarySearch extends InsertionSort {
        @Override
        protected int findInsertionIndex(int[] input, int item, int start, int end) {
            if (end < start) {
                return start;
            }
            // Unsigned shift keeps the midpoint correct even if start + end overflows.
            int middle = (start + end) >>> 1;
            int middleItem = input[middle];
            if (item < middleItem) {
                return findInsertionIndex(input, item, start, middle - 1);
            } else if (item > middleItem) {
                return findInsertionIndex(input, item, middle + 1, end);
            } else {
                return middle;
            }
        }
    }

    /** Insertion sort whose insertion point is found by an iterative binary search. */
    public static class InsertionSortWithIterativeBinarySearch extends InsertionSort {
        @Override
        protected int findInsertionIndex(int[] input, int item, int start, int end) {
            while (start <= end) {
                int middle = (start + end) >>> 1;
                int middleItem = input[middle];
                if (item < middleItem) {
                    end = middle - 1;
                } else if (item > middleItem) {
                    start = middle + 1;
                } else {
                    return middle;
                }
            }
            return start;
        }
    }

    /**
     * Quicksort with the middle element as pivot. It recurses into the smaller partition and loops
     * on the larger one, so the recursion depth stays logarithmic even when partitions are badly
     * unbalanced (e.g. many equal values).
     */
    public static class RecursiveQuickSort extends InPlaceSort {
        @Override
        protected void sortInPlace(int[] input) {
            quickSort(input, 0, input.length - 1);
        }

        /** Sorts the inclusive range {@code input[start..end]}. */
        private void quickSort(int[] input, int start, int end) {
            while (start < end) {
                int pivot = partition(input, start, end, (start + end) >>> 1);
                if (pivot - start < end - pivot) {
                    quickSort(input, start, pivot - 1);
                    start = pivot + 1;
                } else {
                    quickSort(input, pivot + 1, end);
                    end = pivot - 1;
                }
            }
        }
    }

    /**
     * Quicksort (middle element as pivot) that switches to selection sort for small ranges. A range
     * is considered small when {@code end - start} is at most the configured span (15 by default).
     */
    public static class RecursiveQuickSortWithSelectionSortOnSmallArrays extends InPlaceSort {
        private static final int DEFAULT_MAX_SELECTION_SORT_SPAN = 15;

        /** Ranges with {@code end - start <= maxSelectionSortSpan} are selection sorted. */
        private final int maxSelectionSortSpan;

        /** Uses the default span of 15, i.e. ranges of up to 16 elements are selection sorted. */
        public RecursiveQuickSortWithSelectionSortOnSmallArrays() {
            this(DEFAULT_MAX_SELECTION_SORT_SPAN);
        }

        /**
         * @param maxSelectionSortSpan largest {@code end - start} (inclusive bounds) handled by
         *     selection sort instead of partitioning
         * @throws IllegalArgumentException if {@code maxSelectionSortSpan} is negative
         */
        public RecursiveQuickSortWithSelectionSortOnSmallArrays(int maxSelectionSortSpan) {
            if (maxSelectionSortSpan < 0) {
                throw new IllegalArgumentException(
                        "maxSelectionSortSpan must be >= 0: " + maxSelectionSortSpan);
            }
            this.maxSelectionSortSpan = maxSelectionSortSpan;
        }

        @Override
        protected void sortInPlace(int[] input) {
            quickSort(input, 0, input.length - 1);
        }

        /** Sorts the inclusive range {@code input[start..end]}. */
        private void quickSort(int[] input, int start, int end) {
            while (end - start > maxSelectionSortSpan) {
                int pivot = partition(input, start, end, (start + end) >>> 1);
                // Recurse on the smaller side, loop on the larger, to bound stack depth.
                if (pivot - start < end - pivot) {
                    quickSort(input, start, pivot - 1);
                    start = pivot + 1;
                } else {
                    quickSort(input, pivot + 1, end);
                    end = pivot - 1;
                }
            }
            selectionSort(input, start, end);
        }
    }

    /** Top-down merge sort; each merge copies both halves into temporary arrays. */
    public static class RecursiveMergeSort extends InPlaceSort {
        @Override
        protected void sortInPlace(int[] input) {
            mergeSort(input, 0, input.length);
        }

        private void mergeSort(int[] input, int startInclusive, int endExclusive) {
            if (endExclusive - startInclusive > 1) {
                int middle = (startInclusive + endExclusive) >>> 1;
                mergeSort(input, startInclusive, middle);
                mergeSort(input, middle, endExclusive);
                merge(input, startInclusive, middle, endExclusive);
            }
        }

        /** Merges the sorted runs {@code [startInclusive, middle)} and {@code [middle, endExclusive)}. */
        private void merge(int[] input, int startInclusive, int middle, int endExclusive) {
            int[] left = Arrays.copyOfRange(input, startInclusive, middle);
            int[] right = Arrays.copyOfRange(input, middle, endExclusive);
            int leftIndex = 0;
            int rightIndex = 0;
            for (int index = startInclusive; index < endExclusive; index++) {
                // Prefer the left run on ties (<=) so equal elements keep their relative order.
                boolean takeLeft = rightIndex >= right.length
                        || (leftIndex < left.length && left[leftIndex] <= right[rightIndex]);
                if (takeLeft) {
                    input[index] = left[leftIndex++];
                } else {
                    input[index] = right[rightIndex++];
                }
            }
        }
    }

    /**
     * Heap sort on a 0-based binary max-heap, where node {@code i} has children {@code 2i + 1} and
     * {@code 2i + 2}.
     */
    public static class HeapSort extends InPlaceSort {
        @Override
        protected void sortInPlace(int[] input) {
            buildMaxHeap(input);
            for (int heapSize = input.length - 1; heapSize > 0; heapSize--) {
                // Move the current maximum just past the shrinking heap, then repair the root.
                swap(input, 0, heapSize);
                siftDown(input, 0, heapSize);
            }
        }

        private void buildMaxHeap(int[] input) {
            // Nodes from length / 2 onwards are leaves and already satisfy the heap property.
            for (int i = input.length / 2 - 1; i >= 0; i--) {
                siftDown(input, i, input.length);
            }
        }

        /** Restores the max-heap property below {@code root} within {@code input[0..heapSize-1]}. */
        private void siftDown(int[] input, int root, int heapSize) {
            while (true) {
                int left = 2 * root + 1;
                if (left >= heapSize) {
                    return;
                }
                int largest = input[left] > input[root] ? left : root;
                int right = left + 1;
                if (right < heapSize && input[right] > input[largest]) {
                    largest = right;
                }
                if (largest == root) {
                    return;
                }
                swap(input, root, largest);
                root = largest;
            }
        }
    }

    // RANDOM must be declared before the fixtures below, which use it in their initializers.
    private static final Random RANDOM = new Random();
    private static final int ARRAY_SIZE = 100;
    private static final int[] POSITIVE_SORTED_ARRAY = buildRandomSortedArray(ARRAY_SIZE, 0, Integer.MAX_VALUE);
    private static final int[] NEGATIVE_SORTED_ARRAY = buildRandomSortedArray(ARRAY_SIZE, Integer.MIN_VALUE + 1, 0);
    private static final int[] SORTED_ARRAY = buildRandomSortedArray(ARRAY_SIZE, (Integer.MIN_VALUE + 1) / 2, Integer.MAX_VALUE / 2);
    private static final int[] POSITIVE_SHUFFLED_ARRAY = copyShuffledArray(POSITIVE_SORTED_ARRAY);
    private static final int[] NEGATIVE_SHUFFLED_ARRAY = copyShuffledArray(NEGATIVE_SORTED_ARRAY);
    private static final int[] SHUFFLED_ARRAY = copyShuffledArray(SORTED_ARRAY);
    private static final int[] ONE_ELEMENT_ARRAY = Arrays.copyOf(POSITIVE_SHUFFLED_ARRAY, 1);

    /** Supplies one algorithm instance per parameterized run. */
    @Parameters
    public static Collection<Object[]> generateTestCases() {
        Object[][] testCases = { //
        //
                { new BubbleSort() }, //
                { new SelectionSort() }, //
                { new SimpleInsertionSort() }, //
                { new InsertionSortWithRecursiveBinarySearch() }, //
                { new InsertionSortWithIterativeBinarySearch() }, //
                { new RecursiveQuickSort() }, //
                { new RecursiveQuickSortWithSelectionSortOnSmallArrays() }, //
                { new RecursiveMergeSort() }, //
                { new HeapSort() }, //
        };
        return asList(testCases);
    }

    /** Swaps {@code input[i]} and {@code input[j]}; does nothing when {@code i == j}. */
    private static void swap(int[] input, int i, int j) {
        if (i != j) {
            int tmp = input[i];
            input[i] = input[j];
            input[j] = tmp;
        }
    }

    /**
     * Lomuto partition of the inclusive range {@code input[start..end]} around the value at
     * {@code pivotIndex}.
     *
     * @return the pivot's final index; every element before it is strictly smaller than the pivot
     *     and every element after it is greater than or equal to it
     */
    private static int partition(int[] input, int start, int end, int pivotIndex) {
        // Park the pivot at the end so the scan below never moves it.
        swap(input, pivotIndex, end);
        int pivotValue = input[end];
        int boundary = start;
        for (int i = start; i < end; i++) {
            if (input[i] < pivotValue) {
                swap(input, i, boundary);
                boundary++;
            }
        }
        swap(input, boundary, end);
        return boundary;
    }

    /** Selection sorts the inclusive range {@code input[start..end]}; a no-op when {@code start >= end}. */
    private static void selectionSort(int[] input, int start, int end) {
        for (int i = start; i < end; i++) {
            int min = i;
            for (int j = i + 1; j <= end; j++) {
                if (input[j] < input[min]) {
                    min = j;
                }
            }
            swap(input, i, min);
        }
    }

    /**
     * Returns a randomly shuffled copy of {@code sortedArray}, leaving the argument untouched.
     *
     * <p>Uses the same Fisher-Yates traversal as {@link Collections#shuffle(java.util.List)}: walk
     * the array backwards from the last element to the second, swapping each position with a
     * uniformly chosen index between 0 and that position, inclusive.
     */
    private static int[] copyShuffledArray(int[] sortedArray) {
        int[] shuffledArray = copyOf(sortedArray);
        for (int i = sortedArray.length; i > 1; i--) {
            swap(shuffledArray, i - 1, RANDOM.nextInt(i));
        }
        return shuffledArray;
    }

    /**
     * Builds a non-decreasing array of {@code size} random values in {@code [minValue, maxValue)}.
     * Each element adds a random step in {@code [0, (maxValue - minValue) / size)} to the previous
     * one, so duplicates are possible.
     *
     * @throws IllegalArgumentException if {@code size} is negative, or if the range is too small to
     *     give each element a step bound of at least 1
     */
    private static int[] buildRandomSortedArray(int size, int minValue, int maxValue) {
        if (size < 0) {
            throw new IllegalArgumentException("size must be >= 0: " + size);
        }
        int[] array = new int[size];
        if (size == 0) {
            return array;
        }
        // long arithmetic: maxValue - minValue overflows int for ranges wider than Integer.MAX_VALUE.
        long stepBound = ((long) maxValue - minValue) / size;
        if (stepBound < 1) {
            throw new IllegalArgumentException("range [" + minValue + ", " + maxValue
                    + ") is too small for " + size + " elements");
        }
        int bound = (int) Math.min(stepBound, Integer.MAX_VALUE);
        int previous = minValue;
        for (int i = 0; i < size; i++) {
            previous += RANDOM.nextInt(bound);
            array[i] = previous;
        }
        return array;
    }

    private static int[] copyOf(int[] original) {
        return Arrays.copyOf(original, original.length);
    }

    private final Sort algorithm;

    public SortAlgoTest(Sort algorithm) {
        this.algorithm = algorithm;
    }

    @Test
    public void can_sort_sorted_array() {
        canSort(SORTED_ARRAY, SORTED_ARRAY);
    }

    @Test
    public void can_sort_positive_sorted_array() {
        canSort(POSITIVE_SORTED_ARRAY, POSITIVE_SORTED_ARRAY);
    }

    @Test
    public void can_sort_negative_sorted_array() {
        canSort(NEGATIVE_SORTED_ARRAY, NEGATIVE_SORTED_ARRAY);
    }

    @Test
    public void can_sort_shuffled_array() {
        canSort(SHUFFLED_ARRAY, SORTED_ARRAY);
    }

    @Test
    public void can_sort_positive_shuffled_array() {
        canSort(POSITIVE_SHUFFLED_ARRAY, POSITIVE_SORTED_ARRAY);
    }

    @Test
    public void can_sort_negative_shuffled_array() {
        canSort(NEGATIVE_SHUFFLED_ARRAY, NEGATIVE_SORTED_ARRAY);
    }

    @Test
    public void can_sort_empty_array() {
        int[] emptyArray = {};
        canSort(emptyArray, emptyArray);
    }

    @Test
    public void can_sort_array_with_one_element() {
        canSort(ONE_ELEMENT_ARRAY, ONE_ELEMENT_ARRAY);
    }

    /** Exercises the equality branches of the searches and partitions with many repeated values. */
    @Test
    public void can_sort_array_with_duplicates() {
        int[] sortedWithDuplicates = new int[ARRAY_SIZE];
        for (int i = 0; i < sortedWithDuplicates.length; i++) {
            sortedWithDuplicates[i] = i % 5 - 2;
        }
        Arrays.sort(sortedWithDuplicates);
        canSort(copyShuffledArray(sortedWithDuplicates), sortedWithDuplicates);
    }

    /**
     * Sorts a copy of {@code initialArray} and asserts the result equals {@code expectedArray}.
     * Sorting a copy keeps the shared static fixtures intact for in-place algorithms.
     */
    private void canSort(int[] initialArray, int[] expectedArray) {
        int[] input = copyOf(initialArray);
        int[] output = algorithm.sort(input);
        assertThat(output).isEqualTo(expectedArray);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment