Skip to content

Instantly share code, notes, and snippets.

@tgfrerer
Last active December 4, 2023 14:57
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tgfrerer/b11221e91793a4361d41a8550916722f to your computer and use it in GitHub Desktop.
Save tgfrerer/b11221e91793a4361d41a8550916722f to your computer and use it in GitHub Desktop.
Generate bitonic merge sort network based on number of sortable items n, and workgroup size.

This command-line utility generates a bitonic merge sort network based on the number of sortable items n and the workgroup size.

Choose the workgroup size to be == n for an ideal sorting network. The number of sortable elements n must be a power of two.

In-depth discussion, GPU implementation, and context for this code: https://poniesandlight.co.uk/reflect/bitonic_merge_sort/

#include <assert.h>
#include <inttypes.h>
#include <malloc.h>
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdio.h>
#include <stdlib.h>
// Minimal growable byte buffer, used to build one row of the plot.
struct vec_t {
	size_t size_allocated; // capacity of `mem`, in bytes
	size_t size_occupied;  // number of bytes of `mem` currently in use
	char * mem;            // storage for the bytes
};
typedef struct vec_t vec_t;
// Create an empty vector with an initial capacity of 2 bytes.
//
// Returns a heap-allocated vec_t that should be released with vec_delete(),
// or NULL if either the struct or its storage could not be allocated.
vec_t *vec_new( void ) {
	vec_t *self = calloc( 1, sizeof *self );
	if ( self == NULL ) {
		return NULL;
	}
	self->size_allocated = 2;
	self->mem            = calloc( self->size_allocated, sizeof *self->mem );
	if ( self->mem == NULL ) {
		// Don't hand out a half-constructed vector.
		free( self );
		return NULL;
	}
	return self;
}
// Release a vector together with its storage. Passing NULL does nothing.
void vec_delete( vec_t *vec ) {
	if ( vec == NULL ) {
		return;
	}
	free( vec->mem );
	free( vec );
}
// Give access to the vector's storage.
// If `size` is non-NULL, the number of occupied bytes is written to `*size`.
char *vec_data( vec_t *vec, size_t *size ) {
	char *const storage = vec->mem;
	if ( NULL == size ) {
		return storage;
	}
	*size = vec->size_occupied;
	return storage;
}
// Number of bytes currently stored in `vec`.
size_t vec_size( vec_t *vec ) {
	const size_t occupied = vec->size_occupied;
	return occupied;
}
// Double the capacity of `vec`, keeping its contents.
//
// A vector whose capacity is 0 grows to 2 bytes. Running out of memory (or out
// of representable capacity) is fatal: a message is written to stderr and the
// process aborts, in release builds as well as debug builds. On return
// `vec->mem` therefore always points to `vec->size_allocated` usable bytes.
void vec_realloc( vec_t *vec ) {
	size_t new_size = 2;
	if ( vec->size_allocated != 0 ) {
		if ( vec->size_allocated > SIZE_MAX / 2 ) {
			fputs( "error: vector capacity overflow\n", stderr );
			abort();
		}
		new_size = vec->size_allocated * 2;
	}
	// Keep the old pointer until realloc has succeeded; assigning the result
	// straight to vec->mem would lose the buffer on failure.
	char *grown = realloc( vec->mem, new_size );
	if ( grown == NULL ) {
		fputs( "error: out of memory while growing vector\n", stderr );
		abort();
	}
	vec->mem            = grown;
	vec->size_allocated = new_size;
}
// Append one byte to `vec`, growing the storage first when it is full.
void vec_push_back( vec_t *vec, char data ) {
	const int is_full = ( vec->size_occupied == vec->size_allocated );
	if ( is_full ) {
		vec_realloc( vec );
	}
	const size_t slot = vec->size_occupied;
	vec->mem[ slot ]  = data;
	vec->size_occupied = slot + 1;
}
// Add one plot column for the comparison between elements `x` and `y`.
//
// One character is appended to every row of the h-aligned window that contains
// both elements: 'w' on the lower element's row, 'v' on the upper element's
// row, 'x' on the rows in between and ' ' on the remaining rows of the window.
// (main() prints the rows between "\x1b(0" and "\x1b(B"; on terminals that
// honour these as the DEC special graphics switch the letters show up as
// line-drawing glyphs.)
//
// x, y        - indices of the two compared elements; if x > y a diagnostic is
//               printed and the two are swapped.
// chars_count - number of rows in `vecs`; rows at or beyond this index are
//               never touched, even if the window would extend past it.
// vecs        - array of `chars_count` row buffers.
// h           - height of the window the comparison belongs to; must be > 0.
void update_plot( uint32_t x, uint32_t y, size_t chars_count, vec_t **vecs, size_t h ) {
	if ( h == 0 ) {
		// The window arithmetic below divides by h.
		puts( "error: h must be greater than zero" );
		return;
	}

	uint32_t lower = x;
	uint32_t upper = y;
	if ( x > y ) {
		puts( "error x > y" );
		lower = y;
		upper = x;
	}

	size_t min_idx = ( lower / h ) * h;
	// End (exclusive) of the h-block that contains `upper`. Rounding `upper` up
	// to a multiple of h would exclude `upper` itself whenever it already is one.
	size_t max_idx = ( upper / h + 1 ) * h;
	if ( max_idx > chars_count ) {
		max_idx = chars_count; // never write outside of vecs
	}
	if ( min_idx >= max_idx ) {
		puts( "grr" );
		return;
	}

	for ( size_t i = min_idx; i < max_idx; i++ ) {
		char glyph = ' ';
		if ( i == lower ) {
			glyph = 'w';
		} else if ( i == upper ) {
			glyph = 'v';
		} else if ( i > lower && i < upper ) {
			glyph = 'x';
		}
		vec_push_back( vecs[ i ], glyph );
	}
}
// Append a three-space gap to each of the first `chars_count` rows in `vecs`,
// separating the columns of one network step from those of the next.
void space_plot( size_t chars_count, vec_t **vecs ) {
	enum { GAP_WIDTH = 3 };
	for ( size_t row = 0; row < chars_count; ++row ) {
		for ( int k = 0; k < GAP_WIDTH; ++k ) {
			vec_push_back( vecs[ row ], ' ' );
		}
	}
}
// Print and plot one "flip" step over blocks of height `h`, laid out the way
// it would be split across workgroups.
//
// n                - total number of sortable elements (e.g. 64)
// h                - flip height (e.g. 8)
// workgroup_size_x - number of threads per workgroup (e.g. 4); each thread
//                    deals with two sortable elements
// vecs             - `n` plot rows, extended via update_plot()/space_plot()
//
// The thread with global index t' compares element q + (t' % (h/2)) with its
// mirror q + h - 1 - (t' % (h/2)), where q is the start of the h-sized block
// that contains element 2 * t'.
//
// Returns 0 on success, or 1 (after printing an error) if the workgroup size
// is zero or one workgroup would cover more than `h` elements.
int big_flip( uint32_t n, uint32_t h, uint32_t workgroup_size_x, vec_t **vecs ) {
	if ( workgroup_size_x == 0 ) {
		// Would otherwise divide by zero when computing the workgroup count.
		puts( "error: workgroup size must be greater than zero." );
		return 1;
	}
	// Same test as `workgroup_size_x * 2 > h`, written so it cannot wrap around.
	if ( workgroup_size_x > h / 2 ) {
		puts( "error: number of sortable elements processed by one workgroup must be smaller or equal to flip height." );
		return 1;
	}

	uint32_t workgroup_count = n / ( workgroup_size_x * 2 ); // number of workgroups needed
	uint32_t half_h          = h / 2;                        // >= 1, guaranteed by the checks above

	printf( "big_flip      : % 3d elements. FLP over height % 3d, using % 3d workgroups\t", n, h, workgroup_count );

	for ( uint32_t workgroup_id = 0; workgroup_id != workgroup_count; workgroup_id++ ) {
		// We use `t` for local_invocation_id.x which represents the local thread id.
		for ( uint32_t t = 0; t != workgroup_size_x; t++ ) {
			uint32_t t_prime = workgroup_id * workgroup_size_x + t;
			uint32_t q       = ( ( 2 * t_prime ) / h ) * h;
			uint32_t x       = q + ( t_prime % half_h );
			uint32_t y       = q + h - ( t_prime % half_h ) - 1;
			printf( "[% 3i,% 3i], ", x, y );
			update_plot( x, y, n, vecs, h );
		}
	}

	space_plot( n, vecs );
	puts( "" );
	return 0;
}
// Print and plot one "flip" step of height `h` in which every workgroup works
// on its own run of workgroup_size_x * 2 consecutive elements.
//
// n                - total number of sortable elements
// h                - flip height; must be at least 2
// workgroup_size_x - threads per workgroup, two elements per thread; must be > 0
// vecs             - `n` plot rows, extended via update_plot()/space_plot()
//
// Within a block of height h, local thread t compares element (t % (h/2)) with
// its mirror h - 1 - (t % (h/2)).
void local_flip( uint32_t n, uint32_t h, uint32_t workgroup_size_x, vec_t **vecs ) {
	// Same precondition local_disperse() checks; guards the division below.
	assert( workgroup_size_x > 0 );
	// half_h is used as a modulus and therefore must not be zero.
	assert( h >= 2 );

	uint32_t workgroup_count = n / ( workgroup_size_x * 2 ); // number of workgroups needed
	uint32_t half_h          = h / 2;

	printf( "local_flip    : % 3d elements. FLP over height % 3d, using % 3d workgroups\t", n, h, workgroup_count );

	for ( uint32_t workgroup_id = 0; workgroup_id != workgroup_count; workgroup_id++ ) {
		// First element handled by this workgroup, rounded down to a multiple of h.
		uint32_t h_offset = h * ( ( workgroup_size_x * workgroup_id * 2 ) / h );
		// We use `t` for local_invocation_id.x which represents the local thread id.
		for ( uint32_t t = 0; t != workgroup_size_x; t++ ) {
			uint32_t block = h_offset + h * ( ( 2 * t ) / h );
			uint32_t x     = block + t % half_h;
			uint32_t y     = block + h - 1 - ( t % half_h );
			printf( "[% 3i,% 3i], ", x, y );
			update_plot( x, y, n, vecs, h );
		}
	}

	space_plot( n, vecs );
	puts( "" );
}
// Print and plot one "disperse" step over blocks of height `h`, laid out the
// way it would be split across workgroups.
//
// n                - total number of sortable elements (e.g. 16)
// h                - disperse height (e.g. 16)
// workgroup_size_x - number of threads per workgroup (e.g. 2); each thread
//                    deals with two sortable elements
// vecs             - `n` plot rows, extended via update_plot()/space_plot()
//
// The thread with global index t' compares element q + (t' % (h/2)) with the
// element h/2 positions further on, where q is the start of the h-sized block
// that contains element 2 * t'.
//
// Returns 0 on success, or 1 (after printing an error) if the workgroup size
// is zero or one workgroup would cover more than `h` elements.
int big_disperse( uint32_t n, uint32_t h, uint32_t workgroup_size_x, vec_t **vecs ) {
	if ( workgroup_size_x == 0 ) {
		// Would otherwise divide by zero when computing the workgroup count.
		puts( "error: workgroup size must be greater than zero." );
		return 1;
	}
	// Same test as `workgroup_size_x * 2 > h`, written so it cannot wrap around.
	if ( workgroup_size_x > h / 2 ) {
		puts( "error: number of sortable elements processed by one workgroup must be smaller or equal to flip height." );
		return 1;
	}

	uint32_t workgroup_count = n / ( workgroup_size_x * 2 ); // number of workgroups needed
	uint32_t half_h          = h / 2;                        // >= 1, guaranteed by the checks above

	printf( "big_disperse  : % 3d elements. DSP over height % 3d, using % 3d workgroups\t", n, h, workgroup_count );

	for ( uint32_t workgroup_id = 0; workgroup_id != workgroup_count; workgroup_id++ ) {
		// We use `t` for local_invocation_id.x which represents the local thread id.
		for ( uint32_t t = 0; t != workgroup_size_x; t++ ) {
			uint32_t t_prime = workgroup_id * workgroup_size_x + t;
			uint32_t q       = ( ( 2 * t_prime ) / h ) * h;
			uint32_t x       = q + ( t_prime % half_h );
			uint32_t y       = q + ( t_prime % half_h ) + half_h;
			printf( "[% 3i,% 3i], ", x, y );
			update_plot( x, y, n, vecs, h );
		}
	}

	space_plot( n, vecs );
	puts( "" );
	return 0;
}
// Print and plot a cascade of "disperse" steps, starting at height `h` and
// halving the height after every step until it reaches 1. Every workgroup
// works on its own run of workgroup_size_x * 2 consecutive elements.
//
// n                - total number of sortable elements
// h                - height of the first disperse step
// workgroup_size_x - threads per workgroup, two elements per thread; must be > 0
// vecs             - `n` plot rows, extended via update_plot()/space_plot()
void local_disperse( uint32_t n, uint32_t h, uint32_t workgroup_size_x, vec_t **vecs ) {
	assert( workgroup_size_x > 0 );

	const uint32_t group_total = n / ( workgroup_size_x * 2 ); // number of workgroups needed
	uint32_t       height      = h;

	while ( height > 1 ) {
		const uint32_t half = height / 2;

		printf( "local_disperse: % 3d elements. DSP over height % 3d, using % 3d workgroups\t", n, height, group_total );

		for ( uint32_t group = 0; group < group_total; ++group ) {
			// Index of the first element owned by this workgroup.
			const uint32_t base = workgroup_size_x * group * 2;

			// `t` stands for local_invocation_id.x, the local thread id.
			for ( uint32_t t = 0; t < workgroup_size_x; ++t ) {
				const uint32_t block_start = base + height * ( ( 2 * t ) / height );
				const uint32_t x           = block_start + t % half;
				const uint32_t y           = x + half;
				printf( "[% 3i,% 3i], ", x, y );
				update_plot( x, y, n, vecs, height );
			}
		}

		space_plot( n, vecs );
		puts( "" );

		height /= 2;
	}
}
// Print and plot a complete local bitonic merge sort: for every height
// hh = 2, 4, ... up to and including `h`, one local flip of height hh followed
// by a local disperse cascade starting at hh / 2.
//
// n                - total number of sortable elements
// h                - largest flip height to run
// workgroup_size_x - threads per workgroup, two elements per thread; must be > 0
// vecs             - `n` plot rows
void local_bms( uint32_t n, uint32_t h, uint32_t workgroup_size_x, vec_t **vecs ) {
	// Same precondition local_disperse() checks; guards the division below.
	assert( workgroup_size_x > 0 );

	const uint32_t WORKGROUP_COUNT = n / ( workgroup_size_x * 2 );
	printf( "local_bms     : % 3d elements. BMS over height % 3d, using % 3d workgroups\n", n, h, WORKGROUP_COUNT );

	// The counter is deliberately wider than `h`: with a 32-bit counter,
	// doubling past 2^31 would wrap to 0 and the loop would never terminate.
	for ( uint64_t hh = 2; hh <= h; hh *= 2 ) {
		local_flip( n, ( uint32_t )hh, workgroup_size_x, vecs );
		local_disperse( n, ( uint32_t )( hh / 2 ), workgroup_size_x, vecs );
	}
}
// Return non-zero when exactly one bit of `value` is set.
static int is_power_of_two( uint32_t value ) {
	return value != 0 && ( value & ( value - 1 ) ) == 0;
}

// Read `n` and `max_workgroup_size` from stdin, print every step of the
// bitonic merge sort network for n elements, then print a plot of the network.
//
// Algorithm for bitonic merge sort over n elements:
//
//   > let n be the number of sortable elements (a power of two)
//
//   if ( n < max_workgroup_size * 2 ) {
//       workgroup_size = n / 2;
//   } else {
//       workgroup_size = max_workgroup_size;
//   }
//
//   > let workgroup_count = n / ( workgroup_size * 2 )
//   > local bms over workgroup_size * 2 elements, issued workgroup_count times
//
//   h = workgroup_size * 2
//   while ( h < n ) {
//       h *= 2            // double h
//       big_flip( h )     // big flip over h
//       for ( H = h / 2; H > 1; H /= 2 ) {
//           if ( H <= workgroup_size * 2 ) {
//               local_disperse( H );   // finishes the rest of the cascade
//               break;
//           } else {
//               big_disperse( H );
//           }
//       }
//   }
//
// Returns 0 on success and 1 on invalid input or allocation failure.
int main( void ) {
	uint32_t n                  = 16;   // total number of sortable elements
	uint32_t max_workgroup_size = 4;    // this must be calculated based on how much data is available in shader local memory to store sortable elements.
	vec_t ** vecs               = NULL; // one plot row per sortable element
	int      rc                 = 1;

	printf( "Enter n, max_workgroup_size: \n" );
	// SCNu32 matches uint32_t exactly. Values that could not be read keep the
	// defaults assigned above.
	if ( scanf( "%" SCNu32 ", %" SCNu32, &n, &max_workgroup_size ) != 2 ) {
		fputs( "warning: could not read both values; keeping the default for anything not entered.\n", stderr );
	}

	// The index arithmetic of the network only stays within [0, n) when n is a
	// power of two.
	if ( n <= 2 || !is_power_of_two( n ) ) {
		fputs( "error: n must be a power of two greater than 2.\n", stderr );
		return 1;
	}
	if ( max_workgroup_size == 0 ) {
		fputs( "error: max_workgroup_size must be greater than zero.\n", stderr );
		return 1;
	}

	// number of threads in block/workgroup: each thread deals with two sortable elements.
	// Same decision as `n < max_workgroup_size * 2` (n is even), without the
	// risk of the multiplication wrapping around.
	uint32_t workgroup_size_x = ( n / 2 < max_workgroup_size ) ? n / 2 : max_workgroup_size;
	if ( !is_power_of_two( workgroup_size_x ) ) {
		// Otherwise workgroup_size_x * 2 would not divide n and elements would be skipped.
		fputs( "error: max_workgroup_size must be a power of two.\n", stderr );
		return 1;
	}

	vecs = calloc( n, sizeof *vecs );
	if ( vecs == NULL ) {
		fputs( "error: out of memory.\n", stderr );
		return 1;
	}
	for ( size_t i = 0; i != n; i++ ) {
		vecs[ i ] = vec_new();
		if ( vecs[ i ] == NULL ) {
			fputs( "error: out of memory.\n", stderr );
			goto cleanup;
		}
	}

	uint32_t h = workgroup_size_x * 2;
	assert( h <= n );
	assert( h % 2 == 0 );

	local_bms( n, h, workgroup_size_x, vecs );
	fflush( stdout );

	// ----------| invariant: h == workgroup_size_x * 2
	// h is doubled before every flip. The counter is 64 bits wide so that
	// doubling it past 2^31 cannot wrap around to 0.
	for ( uint64_t flip_h = ( uint64_t )h * 2; flip_h <= n; flip_h *= 2 ) {
		if ( big_flip( n, ( uint32_t )flip_h, workgroup_size_x, vecs ) != 0 ) {
			goto cleanup;
		}
		fflush( stdout );

		for ( uint32_t hh = ( uint32_t )( flip_h / 2 ); hh > 1; hh /= 2 ) {
			if ( hh <= workgroup_size_x * 2 ) {
				// We can fit all elements for a disperse operation into continuous shader
				// workgroup local memory, which means we can complete the rest of the cascade
				// using a single shader invocation.
				// todo: calculate offset for each workgroup: based on hh, and workgroup_id
				local_disperse( n, hh, workgroup_size_x, vecs );
				fflush( stdout );
				break;
			}
			if ( big_disperse( n, hh, workgroup_size_x, vecs ) != 0 ) {
				goto cleanup;
			}
			fflush( stdout );
		}
	}

	// ----------------------------------------------------------------------
	// print plot
	// ----------------------------------------------------------------------
	puts( "" );
	for ( size_t i = 0; i != n; i++ ) {
		size_t line_length = 0;
		char * line        = vec_data( vecs[ i ], &line_length );
		// Each row is wrapped in ESC ( 0 ... ESC ( B, which on VT100-compatible
		// terminals selects the DEC special graphics set and then switches back.
		printf( "\x1b(0%.*s\x1b(B\n", ( int )line_length, line );
	}
	fflush( stdout );
	rc = 0;

cleanup:
	// cleanup vecs used for debug_printout; rows that were never created are
	// still NULL from calloc, which vec_delete() accepts.
	for ( size_t i = 0; i != n; i++ ) {
		vec_delete( vecs[ i ] );
	}
	free( vecs );
	return rc;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment