Skip to content

Instantly share code, notes, and snippets.

@grisuthedragon
Last active March 27, 2024 18:41
Show Gist options
  • Save grisuthedragon/0fa99935086a5945171ef63f185bbcee to your computer and use it in GitHub Desktop.
Save grisuthedragon/0fa99935086a5945171ef63f185bbcee to your computer and use it in GitHub Desktop.
Demonstrator for starpu-runtime #37
#include <stdio.h>
#include <stdlib.h>
#include <starpu.h>
#include <limits.h>
#include <math.h>
#ifdef STARPU_USE_CUDA
#include <cublas_v2.h>
#include <cuda.h>
#include <starpu_cublas_v2.h>
#endif
/*
 * Fortran LAPACK routine DLARNV: fills X with n random numbers.
 * idist selects the distribution and iseed is the 4-integer seed state.
 * All arguments are passed by reference, as required by the Fortran ABI.
 */
void dlarnv_(int *idist, int *iseed, int *n, double *X);
/*
 * Fortran BLAS routine DGEMM: C = alpha * op(A) * op(B) + beta * C on
 * column-major matrices. ta/tb select op() for A and B; every scalar is
 * passed by reference.
 */
void dgemm_(char *ta, char *tb, int *m, int *n, int *k, double *alpha, double *A, int *lda, double *B, int *ldb, double *beta, double *C, int *ldc);
/*
 * By-value convenience wrapper around the Fortran dgemm_ symbol.
 *
 * Takes the transposition flags and all scalars by value and forwards their
 * addresses to dgemm_, so callers do not need temporaries of their own.
 * A, B and C are handed through unchanged.
 */
void dgemm(char ta, char tb, int m, int n, int k, double alpha, double *A, int lda, double *B, int ldb, double beta, double *C, int ldc)
{
    /* The flags are forwarded as NUL-terminated one-character strings. */
    char transa[2] = { ta, '\0' };
    char transb[2] = { tb, '\0' };

    dgemm_(transa, transb,
           &m, &n, &k,
           &alpha, A, &lda,
           B, &ldb,
           &beta, C, &ldc);
}
/*
 * CPU implementation of the GEMM codelets: C = alpha * A * B + beta * C.
 *
 * buffers[0], buffers[1] and buffers[2] are the StarPU matrix interfaces of
 * A, B and C. args is the packed codelet argument buffer holding alpha and
 * beta (two doubles, in that order).
 */
void gemm_kernel_func(void *buffers[], void *args)
{
    void *ifaceA = buffers[0];
    void *ifaceB = buffers[1];
    void *ifaceC = buffers[2];

    /* Recover the two scalars from the packed argument buffer. */
    double alpha, beta;
    starpu_codelet_unpack_args(args, &alpha, &beta);

    double *tileA = (double *) STARPU_MATRIX_GET_PTR(ifaceA);
    double *tileB = (double *) STARPU_MATRIX_GET_PTR(ifaceB);
    double *tileC = (double *) STARPU_MATRIX_GET_PTR(ifaceC);

    /* The output tile fixes rows/cols; the inner dimension comes from A. */
    int rows  = STARPU_MATRIX_GET_NX(ifaceC);
    int cols  = STARPU_MATRIX_GET_NY(ifaceC);
    int inner = STARPU_MATRIX_GET_NY(ifaceA);

    int strideA = STARPU_MATRIX_GET_LD(ifaceA);
    int strideB = STARPU_MATRIX_GET_LD(ifaceB);
    int strideC = STARPU_MATRIX_GET_LD(ifaceC);

    dgemm('N', 'N', rows, cols, inner,
          alpha, tileA, strideA,
          tileB, strideB,
          beta, tileC, strideC);
}
#ifdef STARPU_USE_CUDA
/*
 * CUDA implementation of the GEMM codelets: C = alpha * A * B + beta * C,
 * computed with cublasDgemm on the cuBLAS handle of the calling worker.
 *
 * descr[0], descr[1] and descr[2] are the StarPU matrix interfaces of A, B
 * and C. arg is the packed codelet argument buffer holding alpha and beta
 * (two doubles, in that order). A cuBLAS failure is reported through
 * STARPU_CUBLAS_REPORT_ERROR.
 */
static void cublas_mult(void *descr[], void *arg)
{
    double alpha, beta;
    starpu_codelet_unpack_args(arg, &alpha, &beta);

    const double *subA = (const double *)STARPU_MATRIX_GET_PTR(descr[0]);
    const double *subB = (const double *)STARPU_MATRIX_GET_PTR(descr[1]);
    double *subC = (double *)STARPU_MATRIX_GET_PTR(descr[2]);

    /* cuBLAS takes dimensions as int; make the conversion explicit. */
    int nxC = (int)STARPU_MATRIX_GET_NX(descr[2]);
    int nyC = (int)STARPU_MATRIX_GET_NY(descr[2]);
    int nyA = (int)STARPU_MATRIX_GET_NY(descr[0]);
    int ldA = (int)STARPU_MATRIX_GET_LD(descr[0]);
    int ldB = (int)STARPU_MATRIX_GET_LD(descr[1]);
    int ldC = (int)STARPU_MATRIX_GET_LD(descr[2]);

    cublasStatus_t status = cublasDgemm(starpu_cublas_get_local_handle(),
                                        CUBLAS_OP_N, CUBLAS_OP_N,
                                        nxC, nyC, nyA,
                                        &alpha, subA, ldA, subB, ldB,
                                        &beta, subC, ldC);
    if (status != CUBLAS_STATUS_SUCCESS) {
        STARPU_CUBLAS_REPORT_ERROR(status);
    }
}
#endif
/*
 * Performance model attached to both GEMM codelets in this file.
 * It is stored by StarPU under the symbol "gemm_perf_model" and uses the
 * non-linear regression model type.
 */
static struct starpu_perfmodel gemm_perf_model = {
    .symbol = "gemm_perf_model",
    .type   = STARPU_NL_REGRESSION_BASED,
};
/*
 * Accumulating GEMM codelet: reads tiles A and B and updates tile C in place
 * (C is accessed STARPU_RW). The scalars alpha and beta travel in the packed
 * codelet arguments.
 */
struct starpu_codelet gemm_kernel_cl = {
    .max_parallelism = INT_MAX,
    .cpu_funcs = { gemm_kernel_func },
    /* Must be the real symbol name of the function listed in cpu_funcs so
     * that backends which look the implementation up by name can find it. */
    .cpu_funcs_name = { "gemm_kernel_func" },
#ifdef STARPU_USE_CUDA
    .cuda_funcs = { cublas_mult },
    .cuda_flags = { STARPU_CUDA_ASYNC },
#endif
    .nbuffers = 3,
    .modes = { STARPU_R, STARPU_R, STARPU_RW },
    .name = "gemm_kernel",
    .model = &gemm_perf_model
};
/*
 * Initializing GEMM codelet: same kernels as gemm_kernel_cl, but tile C is
 * accessed write-only (STARPU_W). main() submits it with beta = 0 for the
 * first k-step of every C tile, so the previous contents of C are not needed.
 */
struct starpu_codelet gemm0_kernel_cl = {
    .max_parallelism = INT_MAX,
    .cpu_funcs = { gemm_kernel_func },
    /* Must be the real symbol name of the function listed in cpu_funcs so
     * that backends which look the implementation up by name can find it. */
    .cpu_funcs_name = { "gemm_kernel_func" },
#ifdef STARPU_USE_CUDA
    .cuda_funcs = { cublas_mult },
    .cuda_flags = { STARPU_CUDA_ASYNC },
#endif
    .nbuffers = 3,
    .modes = { STARPU_R, STARPU_R, STARPU_W },
    .name = "gemm_kernel",
    .model = &gemm_perf_model
};
int main(int argc, char *argv[])
{
starpu_init(NULL);
starpu_cublas_init();
double start, end;
int m, n, k;
int i, j, l;
int iseed[4] = {1,1,1,1};
int idist = 2;
int MB = 32;
int NB = 32;
int KB = 32;
/* int xparts = 16; */
/* int yparts = 16; */
/* int kparts = 8; */
int s,z;
double *A, *B, *C1, *C2;
double alpha, beta;
m = 1000;
n = 1000;
k = 1000;
int xparts = (m + MB - 1) / MB;
int yparts = (n + NB - 1) / NB;
int kparts = (k + KB - 1) / KB;
alpha = 1;
beta = 0;
A = malloc(sizeof(double) * m * k);
B = malloc(sizeof(double) * k * n);
C1= malloc(sizeof(double) * m * n);
C2= malloc(sizeof(double) * m * n);
starpu_memory_pin( A, sizeof(double)*m*k);
starpu_memory_pin( B, sizeof(double)*n*k);
starpu_memory_pin( C1, sizeof(double)*m*n);
starpu_memory_pin( C2, sizeof(double)*m*n);
i = m * k;
dlarnv_(&idist, iseed, &i, A);
i = k * n;
dlarnv_(&idist, iseed, &i, B);
printf("Start... \n");
dgemm('N', 'N', m, n, k, alpha, A, m, B, k, beta, C1, m);
/* Star PU */
starpu_data_handle_t handleA;
starpu_data_handle_t handleB;
starpu_data_handle_t handleC;
starpu_matrix_data_register(&handleA, STARPU_MAIN_RAM, (uintptr_t)A, m, k, m, sizeof(double)) ;
starpu_matrix_data_register(&handleB, STARPU_MAIN_RAM, (uintptr_t)B, k, n, k, sizeof(double)) ;
starpu_matrix_data_register(&handleC, STARPU_MAIN_RAM, (uintptr_t)C2, m, n, m, sizeof(double)) ;
/* Partition A */
starpu_data_handle_t handleA_vert[xparts];
starpu_data_handle_t handleA_part[xparts][kparts];
struct starpu_data_filter filterA_rows = {
.filter_func = starpu_matrix_filter_block,
.nchildren = xparts
};
struct starpu_data_filter filterA_columns = {
.filter_func = starpu_matrix_filter_vertical_block,
.nchildren = kparts
};
starpu_data_partition_plan(handleA, &filterA_rows, handleA_vert);
for (int i = 0; i < xparts; i++)
starpu_data_partition_plan(handleA_vert[i], &filterA_columns, handleA_part[i]);
/* Partition B */
starpu_data_handle_t handleB_vert[kparts];
starpu_data_handle_t handleB_part[kparts][yparts];
struct starpu_data_filter filterB_rows = {
.filter_func = starpu_matrix_filter_block,
.nchildren = kparts
};
struct starpu_data_filter filterB_columns = {
.filter_func = starpu_matrix_filter_vertical_block,
.nchildren = yparts
};
starpu_data_partition_plan(handleB, &filterB_rows, handleB_vert);
for (int i = 0; i < kparts; i++)
starpu_data_partition_plan(handleB_vert[i], &filterB_columns, handleB_part[i]);
/* Partition C */
starpu_data_handle_t handleC_vert[xparts];
starpu_data_handle_t handleC_part[xparts][yparts];
struct starpu_data_filter filterC_rows = {
.filter_func = starpu_matrix_filter_block,
.nchildren = xparts
};
struct starpu_data_filter filterC_columns = {
.filter_func = starpu_matrix_filter_vertical_block,
.nchildren = yparts
};
starpu_data_partition_plan(handleC, &filterC_rows, handleC_vert);
for (int i = 0; i <xparts; i++)
starpu_data_partition_plan(handleC_vert[i], &filterC_columns, handleC_part[i]);
start = starpu_timing_now();
alpha = 1;
for ( i = 0; i < xparts; i ++) {
for (j = 0; j < yparts; j++) {
for ( l = 0; l < kparts; l++) {
if ( l == 0 ) {
beta = 0.0;
starpu_task_insert(&gemm0_kernel_cl, STARPU_VALUE, &alpha, sizeof(double),
STARPU_VALUE, &beta, sizeof(double),
STARPU_R, handleA_part[i][l],
STARPU_R, handleB_part[l][j],
STARPU_W, handleC_part[i][j],
STARPU_NAME, "gemm_kernel",
0);
} else {
beta = 1;
starpu_task_insert(&gemm_kernel_cl, STARPU_VALUE, &alpha, sizeof(double),
STARPU_VALUE, &beta, sizeof(double),
STARPU_R, handleA_part[i][l],
STARPU_R, handleB_part[l][j],
STARPU_RW, handleC_part[i][j],
STARPU_NAME, "gemm_kernel",
0);
}
}
}
}
starpu_task_wait_for_all();
end = (starpu_timing_now() - start) /1000000.0;
double flops = (m/1000.0) * (n / 1000.0 ) * (k/1000.0) * 2.0;
printf("Time: %lg\n", end);
printf("GFlops: %lg\n", flops/end);
for (int i = 0; i < xparts; i++)
starpu_data_partition_clean(handleA_vert[i], kparts, handleA_part[i]);
starpu_data_partition_clean(handleA, xparts, handleA_vert);
for (int i = 0; i < kparts; i++)
starpu_data_partition_clean(handleB_vert[i], yparts, handleB_part[i]);
starpu_data_partition_clean(handleB, kparts, handleB_vert);
for (int i = 0; i < xparts; i++)
starpu_data_partition_clean(handleC_vert[i], yparts, handleC_part[i]);
starpu_data_partition_clean(handleC, xparts, handleC_vert);
starpu_data_unregister(handleA);
starpu_data_unregister(handleB);
starpu_data_unregister(handleC);
#if 1
/* Check */
for (j = 0; j < n; j++) {
for ( i = 0; i < m; i++) {
if ( fabs(C1[i+j*m] - C2[i+j*m]) > 1e-10) {
printf("Fail in (%d,%d) %lg - %lg \n", i, j, C1[i+j*m], C2[i+j*m] );
}
}
}
#endif
starpu_memory_unpin( A, sizeof(double)*m*k);
starpu_memory_unpin( B, sizeof(double)*n*k);
starpu_memory_unpin( C1, sizeof(double)*m*n);
starpu_memory_unpin( C2, sizeof(double)*m*n);
starpu_shutdown();
free(A);
free(B);
free(C1);
free(C2);
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment