Skip to content

Instantly share code, notes, and snippets.

@so298
Last active April 24, 2024 13:24
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save so298/b3bee85d17695f4aad2c14eaf5ddfc0e to your computer and use it in GitHub Desktop.
Save so298/b3bee85d17695f4aad2c14eaf5ddfc0e to your computer and use it in GitHub Desktop.
OpenMP cilksort (efficient parallel sort for fork-join model)
/*
* Run `cc -Wall -Wextra -fopenmp -O3 cilksort.c -o cilksort` to compile.
* Clang is recommended over GCC for better performance.
*/
/*
* Original code from the Cilk project
*
* Copyright (c) 2000 Massachusetts Institute of Technology
* Copyright (c) 2000 Matteo Frigo
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
*
*/
/*
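* This program uses an algorithm that we call `cilksort'.
* The algorithm is essentially mergesort (shown here with a two-way split;
* the implementation below splits into four quarters and merges them pairwise):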
*
* cilksort(in[1..n]) =
* spawn cilksort(in[1..n/2], tmp[1..n/2])
* spawn cilksort(in[n/2..n], tmp[n/2..n])
* sync
* spawn cilkmerge(tmp[1..n/2], tmp[n/2..n], in[1..n])
*
*
* The procedure cilkmerge does the following:
*
* cilkmerge(A[1..n], B[1..m], C[1..(n+m)]) =
* find the median of A \union B using binary
* search. The binary search gives a pair
* (ma, mb) such that ma + mb = (n + m)/2
* and all elements in A[1..ma] are smaller than
* B[mb..m], and all the B[1..mb] are smaller
* than all elements in A[ma..n].
*
* spawn cilkmerge(A[1..ma], B[1..mb], C[1..(n+m)/2])
* spawn cilkmerge(A[ma..n], B[mb..m], C[(n+m)/2 .. (n+m)])
* sync
*
* The algorithm appears for the first time (AFAIK) in S. G. Akl and
* N. Santoro, "Optimal Parallel Merging and Sorting Without Memory
* Conflicts", IEEE Trans. Comp., Vol. C-36 No. 11, Nov. 1987 . The
* paper does not express the algorithm using recursion, but the
* idea of finding the median is there.
*
* For cilksort of n elements, T_1 = O(n log n) and
* T_\infty = O(log^3 n). There is a way to shave a
* log factor in the critical path (left as homework).
*/
#include <errno.h>
#include <limits.h>
#include <omp.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
/* Subproblems with fewer than CUTOFF elements are handled sequentially
 * (qsort in cilksort, merge in cilkmerge) instead of spawning tasks. */
#define CUTOFF (4096)

/* Exchange two lvalues of the given type.
 * Wrapped in do { } while (0) so that the macro behaves as a single statement,
 * which makes it safe in an unbraced if/else. The arguments are
 * parenthesized. Each of x and y is evaluated twice, so do not pass
 * expressions with side effects. */
#define SWAP(type, x, y)                                                       \
  do {                                                                         \
    type swap_tmp_ = (x);                                                      \
    (x) = (y);                                                                 \
    (y) = swap_tmp_;                                                           \
  } while (0)

// Type definitions
typedef int elem_t;     /* Element type */
typedef elem_t *elem_p; /* Element pointer type */
typedef int idx_t;      /* Index type */

/* qsort comparator for elem_t.
 * Returns a negative value, zero, or a positive value when *a is less than,
 * equal to, or greater than *b, respectively.
 * It uses (x > y) - (x < y) rather than x - y, because the subtraction
 * overflows for inputs such as INT_MIN and 1, and signed overflow is
 * undefined behavior. The return type is int, as qsort's comparator type
 * requires. */
int compare_elem_t(const void *a, const void *b) {
  const elem_t x = *(const elem_t *)a;
  const elem_t y = *(const elem_t *)b;
  return (x > y) - (x < y);
}
/*
 * Check whether arr[0..n) is in non-decreasing order.
 * Returns 1 when every adjacent pair satisfies arr[k-1] <= arr[k], and 0
 * otherwise. An array with fewer than two elements counts as sorted.
 * The scan is divided among OpenMP threads. Each thread keeps its own copy
 * of the flag, and the `&&` reduction combines the copies at the end.
 */
int is_sorted(elem_t *arr, idx_t n) {
  int in_order = 1;
#pragma omp parallel for reduction(&& : in_order)
  for (idx_t k = 1; k < n; ++k) {
    in_order = in_order && (arr[k - 1] <= arr[k]);
  }
  return in_order;
}
/*
 * Sequential two-way merge.
 * Combines the sorted runs a[0..a_len) and b[0..b_len) into
 * dst[0..a_len + b_len). The comparison is strict, so on ties the element
 * from b is emitted first. dst is written while a and b are still being
 * read, so it is expected not to overlap either input.
 */
void merge(elem_p a, idx_t a_len, elem_p b, idx_t b_len, elem_p dst) {
  idx_t ia = 0;  /* next unread position in a */
  idx_t ib = 0;  /* next unread position in b */
  idx_t out = 0; /* next write position in dst */

  for (; ia < a_len && ib < b_len; ++out) {
    dst[out] = (a[ia] < b[ib]) ? a[ia++] : b[ib++];
  }

  /* At most one of the runs still has elements; copy its tail. */
  for (; ia < a_len; ++ia) {
    dst[out++] = a[ia];
  }
  for (; ib < b_len; ++ib) {
    dst[out++] = b[ib];
  }
}
/*
 * Return the largest index i in [0, len) such that x[i] <= val.
 * Returns -1 if there is no such index, i.e. when len <= 0 or every element
 * is greater than val.
 * x[0..len) must be sorted in non-decreasing order.
 *
 * cilkmerge uses binary_search(...) + 1 as the number of b elements that go
 * to the lower half of the split. Returning -1 in the "none" case is what
 * makes that count 0. Returning 0 instead would misplace b[0] and could
 * produce unsorted output.
 */
idx_t binary_search(elem_p x, idx_t len, elem_t val) {
  // The loop below only probes indices > low, so handle x[0] explicitly.
  if (len <= 0 || x[0] > val)
    return -1;

  // Invariant: x[low] <= val, and either high == len or x[high] > val.
  idx_t low = 0;
  idx_t high = len;
  while (high - low > 1) {
    // low + (high - low) / 2 avoids the int overflow of (low + high) / 2.
    idx_t mid = low + (high - low) / 2;
    if (x[mid] <= val)
      low = mid;
    else
      high = mid;
  }
  return low;
}
// Parallel merge function
/*
 * Merge the sorted runs a[0..a_len) and b[0..b_len) into
 * dst[0..a_len + b_len) using OpenMP tasks. This is the cilkmerge scheme
 * described in the header comment.
 *
 * - After the optional swap, a is the longer run.
 * - Inputs below CUTOFF, and inputs where b is empty, are handled
 *   sequentially.
 * - Otherwise a is split at a_split. The pivot a[a_split - 1] is used to
 *   split b, and the two halves are merged by two tasks into adjacent ranges
 *   of dst.
 *
 * dst is filled via memcpy/merge while a and b are read, so it is expected
 * not to overlap either input. cilksort always merges between arr and tmp.
 */
void cilkmerge(elem_p a, idx_t a_len, elem_p b, idx_t b_len, elem_p dst) {
if (a_len < b_len) {
// Make sure a_len >= b_len
SWAP(elem_p, a, b);
SWAP(idx_t, a_len, b_len);
}
// Only a has elements left: copy it through unchanged.
if (b_len == 0) {
memcpy(dst, a, a_len * sizeof(elem_t));
return;
}
// Single-threaded merge for small arrays
if (a_len + b_len < CUTOFF) {
merge(a, a_len, b, b_len, dst);
return;
}
// a's lower half is a[0..a_split); its last element is the pivot.
idx_t a_split = (a_len + 1) / 2;
// b_split is meant to be the number of b elements <= pivot. That way, every
// element sent to the lower merge is <= every element sent to the upper one.
// NOTE(review): this holds only if binary_search returns -1 when b[0] is
// greater than the pivot. A return value of 0 in that case would put b[0]
// into the lower half. Confirm against binary_search.
idx_t b_split = binary_search(b, b_len, a[a_split - 1]) + 1;
// Lower halves -> dst[0 .. a_split + b_split).
#pragma omp task shared(a, a_split, b, b_split, dst)
cilkmerge(a, a_split, b, b_split, dst);
// Upper halves -> dst[a_split + b_split .. a_len + b_len).
// a_len and b_len are not listed in shared(). Under OpenMP's default task
// data-sharing rules, this task therefore gets its own copies of them.
#pragma omp task shared(a, a_split, b, b_split, dst)
{
elem_p a2 = a + a_split;
elem_p b2 = b + b_split;
idx_t a2_len = a_len - a_split;
idx_t b2_len = b_len - b_split;
elem_p dst2 = dst + a_split + b_split;
cilkmerge(a2, a2_len, b2, b2_len, dst2);
}
// The shared locals above live in this stack frame, and both halves must
// be complete before the caller uses dst. Wait for both tasks.
#pragma omp taskwait
}
/*
 * Sort arr[0..n) in non-decreasing order, using tmp[0..n) as scratch space.
 * tmp must not overlap arr.
 *
 * - Inputs shorter than CUTOFF are sorted with qsort.
 * - Larger inputs are split into four quarters, which are sorted recursively
 *   as parallel tasks.
 * - Quarters 1+2 and quarters 3+4 are then merged into tmp by two
 *   concurrent cilkmerge tasks.
 * - Finally, the two sorted halves in tmp are merged back into arr.
 *
 * The tasks only run concurrently when this is called inside an OpenMP
 * parallel region; main calls it from a `master` region.
 */
void cilksort(elem_p arr, idx_t n, elem_p tmp) {
  if (n < 2)
    return;
  // Single-threaded sort for small arrays
  if (n < CUTOFF) {
    qsort(arr, n, sizeof(elem_t), compare_elem_t);
    return;
  }

  // First half [a1 | a2] has len12 elements; second half [a3 | a4] has
  // len34. b1..b4 are the matching positions in tmp.
  idx_t len12 = n / 2;
  idx_t len1 = len12 / 2;
  idx_t len2 = len12 - len1;
  elem_p a1 = arr, b1 = tmp;
  elem_p a2 = a1 + len1, b2 = b1 + len1;
  idx_t len34 = n - len12;
  idx_t len3 = len34 / 2;
  idx_t len4 = len34 - len3;
  elem_p a3 = arr + len12, b3 = tmp + len12;
  elem_p a4 = a3 + len3, b4 = b3 + len3;

  // Sort each quarter in place, using the matching slice of tmp as scratch.
#pragma omp task shared(a1, len1, b1)
  cilksort(a1, len1, b1);
#pragma omp task shared(a2, len2, b2)
  cilksort(a2, len2, b2);
#pragma omp task shared(a3, len3, b3)
  cilksort(a3, len3, b3);
#pragma omp task shared(a4, len4, b4)
  cilksort(a4, len4, b4);
#pragma omp taskwait

  // Merge quarters 1+2 into tmp[0..len12) and quarters 3+4 into
  // tmp[len12..n). The two merges read and write disjoint ranges, so they run
  // as concurrent tasks. Running them back to back would serialize this
  // level of the recursion.
#pragma omp task shared(a1, len1, a2, len2, b1)
  cilkmerge(a1, len1, a2, len2, b1);
#pragma omp task shared(a3, len3, a4, len4, b3)
  cilkmerge(a3, len3, a4, len4, b3);
#pragma omp taskwait

  // Merge the two sorted halves from tmp back into arr.
  cilkmerge(b1, len12, b3, len34, arr);
}
/*
 * Debug helper: print array[0..size) to stdout on a single line.
 * Each element is followed by one space, and the line ends with a newline.
 * Nothing but the newline is printed when size <= 0.
 */
void print_array(elem_p array, idx_t size) {
  for (const elem_t *p = array; p - array < size; p++)
    printf("%d ", *p);
  printf("\n");
}
int main(int argc, char *argv[]) {
if (argc != 2) {
fprintf(stderr, "Usage: %s <array size>\n", argv[0]);
return 1;
}
idx_t array_size = atoi(argv[1]);
elem_p data = (elem_p)malloc(array_size * sizeof(elem_t));
elem_p tmp = (elem_p)malloc(array_size * sizeof(elem_t));
if (!data || !tmp) {
fprintf(stderr, "Memory allocation failed.\n");
return 1;
}
// Initialize array with random data
srand(0);
for (idx_t i = 0; i < array_size; i++) {
data[i] = (elem_t)(rand() % array_size);
}
// Print the array before sorting
// printf("Array before sorting:\n");
// print_array(data, array_size);
double start_time = omp_get_wtime();
#pragma omp parallel
{
#pragma omp master
cilksort(data, array_size, tmp);
}
double end_time = omp_get_wtime();
// Print the array after sorting
// printf("Array after sorting:\n");
// print_array(data, array_size);
// Check and print the result
if (!is_sorted(data, array_size)) {
fprintf(stderr, "Error: Array is not sorted\n");
}
printf("Sorting completed in %.6f seconds.\n", end_time - start_time);
free(data);
free(tmp);
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment