Created
November 6, 2015 20:41
-
-
Save LukaHorvat/ed766e1fc7f82678627a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <errno.h>
#include <limits.h>
#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
/*
 * Debug tracing switch.
 * While DEBUG is defined, PRINT(...) forwards its arguments to printf;
 * otherwise it expands to an empty statement body.
 * Neither expansion carries a trailing semicolon, so `PRINT(...);` is exactly
 * one statement and is safe inside an unbraced if/else.
 */
#define DEBUG
#ifdef DEBUG
#define PRINT(...) printf(__VA_ARGS__)
#else
#define PRINT(...) do {} while (0)
#endif
/*
 * External routine resolved at link time; it is not defined in this file.
 * Every argument is passed by address.
 * NOTE(review): presumably the Fortran BLAS routine DGER, i.e. the rank-1
 * update A := alpha * x * y^T + A on a column-major m-by-n matrix A with
 * leading dimension lda -- confirm against the BLAS library being linked.
 */
extern void dger_(int *m, int *n, double *alpha,
                  double *x, int *incX,
                  double *y, int *incY,
                  double *a, int *lda);
/*
 * Per-worker argument bundle handed to thr_mult through pthread_create.
 * One instance exists per worker thread; all instances share the same
 * barrier, thread-handle array, input arrays and scratch buffer.
 */
typedef struct {
    int idx;                 /* zero-based index of this worker */
    int m;                   /* first dimension passed to dger_ */
    int n;                   /* number of outer-product steps shared among workers */
    int p;                   /* second dimension passed to dger_ */
    int t;                   /* total number of workers */
    pthread_barrier_t *bar;  /* barrier each worker waits on after multiplying */
    pthread_t *thrs;         /* thread handles, indexed by worker index */
    double *a;               /* first input array */
    double *b;               /* second input array */
    double *buff;            /* base of the scratch area; worker idx uses the
                                m * p doubles starting at m * p * idx */
} thr_mult_struct;
/*
 * Adds the first `size` elements of `b` onto the matching elements of `a`,
 * then resets those elements of `b` to zero.
 *
 * a    - accumulator, at least `size` elements
 * b    - addend, at least `size` elements; zeroed on return
 * size - number of elements to process; zero or negative is a no-op
 */
void sum(double *a, double *b, int size) {
    /* Signed index: a negative size must not wrap into a huge unsigned bound. */
    for (int i = 0; i < size; i++) {
        a[i] += b[i];
        b[i] = 0;
    }
}
void *thr_mult(void *voidStr) { | |
thr_mult_struct *str = (thr_mult_struct*)voidStr; | |
double *buff = str->buff + str->m * str->p * str->idx; | |
PRINT("Created thread %d\n", str->idx); | |
for (size_t i = 0; i < str->n; i++) { | |
// Thread i should do the multiplication for | |
// every t-th row and column vector after the i-th one | |
if (i % str->t == str->idx) { | |
PRINT("Doing multiplication at %d on thread %d\n", i, str->idx); | |
double alpha = 1; | |
int incX = 1; | |
dger_(&str->m, &str->p, &alpha, str->a + i * str->p, &incX, str->b + i, &str->n, buff, &str->m); | |
} | |
} | |
// Wait till all the threads are done with multiplication | |
pthread_barrier_wait(str->bar); | |
PRINT("Thread %d passed barrier\n", str->idx); | |
for (size_t stage = 2; stage / 2 < str->t; stage *= 2) { | |
// Join up correct threads and add their results together | |
int idxToJoin = str->idx + stage / 2; | |
if (str->idx % stage == 0 && idxToJoin < str->t) { | |
pthread_join(str->thrs[idxToJoin], NULL); | |
PRINT("Summing result of thread %d and thread %d\n", str->idx, idxToJoin); | |
sum(buff, str->buff + str->m * str->p * idxToJoin, str->m * str->p); | |
} | |
} | |
} | |
double *mult(int t, int m, int n, int p, double *a, double *b) { | |
// Every thread will get it's own matrix so it can do dger_s without conflicts. | |
double *tempBuffer = (double*)calloc(m * p * t, sizeof(double)); | |
thr_mult_struct thr_mult_structs[t]; | |
pthread_t thr_indexes[t]; | |
pthread_barrier_t bar; | |
pthread_barrier_init(&bar, NULL, t); | |
for (size_t i = 0; i < t; i++) { | |
thr_mult_structs[i].idx = i; | |
thr_mult_structs[i].m = m; | |
thr_mult_structs[i].n = n; | |
thr_mult_structs[i].p = p; | |
thr_mult_structs[i].t = t; | |
thr_mult_structs[i].bar = &bar; | |
thr_mult_structs[i].thrs = thr_indexes; | |
thr_mult_structs[i].a = a; | |
thr_mult_structs[i].b = b; | |
thr_mult_structs[i].buff = tempBuffer; | |
pthread_create(&thr_indexes[i], NULL, thr_mult, (void*)&thr_mult_structs[i]); | |
} | |
// Finally, join up with the first thread. It's buffer has the final result. | |
pthread_join(thr_indexes[0], NULL); | |
PRINT("All collected\n"); | |
double *result = (double*)malloc(sizeof(double) * m * p); | |
memcpy(result, tempBuffer, sizeof(double) * m * p); | |
free(tempBuffer); | |
return result; | |
} | |
/*
 * Reads `numElems` raw doubles from the binary file `file`.
 *
 * file     - path of the file to read
 * numElems - number of doubles expected; must be >= 0
 *
 * Returns a newly allocated array that the caller releases with free(), or
 * NULL when the arguments are invalid, the file cannot be opened, memory is
 * exhausted, or the file holds fewer than `numElems` doubles. Open and
 * short-read failures are reported on stderr.
 */
double *readMatrix(const char *file, int numElems) {
    if (file == NULL || numElems < 0) {
        return NULL;
    }

    FILE *f = fopen(file, "rb");
    if (f == NULL) {
        perror(file);
        return NULL;
    }

    const size_t count = (size_t)numElems;
    // Allocate at least one element so a NULL return always means failure.
    double *res = malloc((count > 0 ? count : 1) * sizeof *res);
    if (res != NULL && fread(res, sizeof *res, count, f) != count) {
        fprintf(stderr, "%s: could not read %d doubles\n", file, numElems);
        free(res);
        res = NULL;
    }

    fclose(f);
    return res;
}
/*
 * Writes `numElems` raw doubles from `matrix` to the binary file `file`,
 * creating or truncating it.
 *
 * file     - path of the file to write
 * matrix   - source array holding at least `numElems` doubles
 * numElems - number of doubles to write; must be >= 0
 *
 * Nothing is returned; invalid arguments and open, write or close failures
 * are reported on stderr and the function returns without crashing.
 */
void writeMatrix(const char *file, double* matrix, int numElems) {
    if (file == NULL || numElems < 0 || (matrix == NULL && numElems > 0)) {
        fprintf(stderr, "writeMatrix: invalid arguments\n");
        return;
    }

    FILE *f = fopen(file, "wb");
    if (f == NULL) {
        perror(file);
        return;
    }

    const size_t count = (size_t)numElems;
    if (count > 0 && fwrite(matrix, sizeof *matrix, count, f) != count) {
        fprintf(stderr, "%s: could not write %d doubles\n", file, numElems);
    }
    // fclose flushes buffered data, so its result matters for a written file.
    if (fclose(f) != 0) {
        perror(file);
    }
}
/*
 * Parses `text` as a base-10 integer in [1, INT_MAX].
 * Stores the value in *out and returns 0 on success; returns -1 and leaves
 * *out untouched when the text is empty, has trailing characters, or is out
 * of range.
 */
static int parse_positive_int(const char *text, int *out) {
    char *end = NULL;
    errno = 0;
    long value = strtol(text, &end, 10);
    if (end == text || *end != '\0' || errno == ERANGE || value < 1 || value > INT_MAX) {
        return -1;
    }
    *out = (int)value;
    return 0;
}

/*
 * Command-line driver.
 *
 * Arguments: <threads> <m> <n> <p> <a-file> <b-file> <out-file>
 * Reads m * n doubles from <a-file> and n * p doubles from <b-file>, runs
 * mult with the given thread count, and writes the m * p result doubles to
 * <out-file>.
 *
 * Returns 0 on success and EXIT_FAILURE on bad arguments or when reading or
 * multiplying fails.
 */
int main(int argc, char const *argv[]) {
    // Unbuffered stdout keeps trace lines from different threads timely.
    setbuf(stdout, NULL);

    if (argc < 8) {
        fprintf(stderr,
                "usage: %s <threads> <m> <n> <p> <a-file> <b-file> <out-file>\n",
                argc > 0 ? argv[0] : "mult");
        return EXIT_FAILURE;
    }

    int t, m, n, p;
    if (parse_positive_int(argv[1], &t) != 0 ||
        parse_positive_int(argv[2], &m) != 0 ||
        parse_positive_int(argv[3], &n) != 0 ||
        parse_positive_int(argv[4], &p) != 0) {
        fprintf(stderr, "threads, m, n and p must be positive integers\n");
        return EXIT_FAILURE;
    }
    // Element counts are passed around as int, so reject products that overflow.
    if (m > INT_MAX / n || n > INT_MAX / p || m > INT_MAX / p) {
        fprintf(stderr, "matrix dimensions are too large\n");
        return EXIT_FAILURE;
    }

    int status = EXIT_FAILURE;
    double *a = readMatrix(argv[5], m * n);
    double *b = readMatrix(argv[6], n * p);
    double *c = NULL;

    if (a != NULL && b != NULL) {
        c = mult(t, m, n, p, a, b);
        if (c != NULL) {
            writeMatrix(argv[7], c, m * p);
            status = 0;
        } else {
            fprintf(stderr, "multiplication failed\n");
        }
    } else {
        fprintf(stderr, "could not read the input matrices\n");
    }

    free(c);
    free(b);
    free(a);
    return status;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment