Created
October 24, 2012 14:30
-
-
Save diegode/3946395 to your computer and use it in GitHub Desktop.
Strassen algorithm
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <string.h> | |
#include <stdlib.h> | |
#include <assert.h> | |
#include <omp.h> // OpenMP | |
#include "matrix.h" | |
#define MIN_STRASSEN (64*64) | |
#define OMP_DYNAMIC TRUE | |
matrix_t matrix_alloc(int h, int w) { | |
matrix_t m; | |
m.size1 = h; | |
m.size2 = w; | |
m.tda = w+1+!(w&1); | |
posix_memalign((void **)&m.data, 16, (h+1)*m.tda*sizeof(real_t)); | |
return m; | |
} | |
void matrix_free(matrix_t m) { | |
free(m.data); | |
} | |
void matrix_add(matrix_t c, matrix_t a, matrix_t b) { | |
int i,j; | |
for(i=0; i < a.size1; i++) | |
for(j=0; j < a.size2; j++) | |
matrix(c,i,j) = matrix(a,i,j)+matrix(b,i,j); | |
} | |
void matrix_sub(matrix_t c, matrix_t a, matrix_t b) { | |
int i,j; | |
for(i=0; i < a.size1; i++) | |
for(j=0; j < a.size2; j++) | |
matrix(c,i,j) = matrix(a,i,j)-matrix(b,i,j); | |
} | |
void matrix_mul_simple(matrix_t c, matrix_t a, matrix_t b) { | |
int i,j,k; | |
real_t x; | |
for(i=0; i < a.size1; i++) { | |
for(j=0; j < b.size2; j++) { | |
x = 0; | |
for(k=0; k < a.size2; k++) | |
x += matrix(a,i,k) * matrix(b,k,j); | |
matrix(c,i,j) = x; | |
} | |
} | |
} | |
void submatrices(matrix_t a, matrix_t *a11, matrix_t *a12, matrix_t *a21, matrix_t *a22) { | |
int hh = a.size1 >> 1; | |
a11->size1 = a12->size1 = hh; | |
a21->size1 = a22->size1 = hh; //a.size1-hh; | |
int wh = a.size2 >> 1; | |
a11->size2 = a21->size2 = wh; | |
a12->size2 = a22->size2 = wh; //a.size2-wh; | |
a11->tda = a12->tda = a21->tda = a22->tda = a.tda; | |
a11->data = &matrix(a,0,0); | |
a12->data = &matrix(a,0,wh); | |
a21->data = &matrix(a,hh,0); | |
a22->data = &matrix(a,hh,wh); | |
} | |
// P = (A1 + A2) * (B1 + B2) | |
void strassen_1(matrix_t a1, matrix_t a2, matrix_t b1, matrix_t b2, matrix_t p) { | |
matrix_t f = matrix_alloc(a1.size1, a1.size2); | |
matrix_t s = matrix_alloc(b1.size1, b1.size2); | |
#pragma omp parallel sections | |
{ | |
#pragma omp section | |
matrix_add(f,a1,a2); | |
#pragma omp section | |
matrix_add(s,b1,b2); | |
} | |
matrix_mul(p,f,s); | |
matrix_free(f); | |
matrix_free(s); | |
} | |
// P = (A1 + A2) * B | |
void strassen_2(matrix_t a1, matrix_t a2, matrix_t b, matrix_t p) { | |
matrix_t f = matrix_alloc(a1.size1, a1.size2); | |
matrix_add(f,a1,a2); | |
matrix_mul(p,f,b); | |
matrix_free(f); | |
} | |
// P = A * (B1 - B2) | |
void strassen_3(matrix_t a, matrix_t b1, matrix_t b2, matrix_t p) { | |
matrix_t s = matrix_alloc(b1.size1, b1.size2); | |
matrix_sub(s,b1,b2); | |
matrix_mul(p,a,s); | |
matrix_free(s); | |
} | |
// P = (A1 - A2) * (B1 + B2) | |
void strassen_4(matrix_t a1, matrix_t a2, matrix_t b1, matrix_t b2, matrix_t p) { | |
matrix_t f = matrix_alloc(a1.size1, a1.size2); | |
matrix_t s = matrix_alloc(b1.size1, b1.size2); | |
#pragma omp parallel sections | |
{ | |
#pragma omp section | |
matrix_sub(f,a1,a2); | |
#pragma omp section | |
matrix_add(s,b1,b2); | |
} | |
matrix_mul(p,f,s); | |
matrix_free(f); | |
matrix_free(s); | |
} | |
// P1 += P2 + (P3 - P4) | |
void strassen_5(matrix_t p1, matrix_t p2, matrix_t p3, matrix_t p4) { | |
matrix_t f = matrix_alloc(p1.size1, p1.size2); | |
#pragma omp parallel sections | |
{ | |
#pragma omp section | |
matrix_add(p1,p1,p2); | |
#pragma omp section | |
matrix_sub(f,p3,p4); | |
} | |
matrix_add(p1,p1,f); | |
matrix_free(f); | |
} | |
void matrix_mul_strassen(matrix_t c, matrix_t a, matrix_t b) { | |
matrix_t p1,p2,p3,c11,c12,c21,c22; | |
matrix_t a11,a12,a21,a22,b11,b12,b21,b22; | |
int hp = a.size1 >> 1; | |
int wp = b.size2 >> 1; | |
p1 = matrix_alloc(hp,wp); | |
p2 = matrix_alloc(hp,wp); | |
p3 = matrix_alloc(hp,wp); | |
submatrices(a,&a11,&a12,&a21,&a22); | |
submatrices(b,&b11,&b12,&b21,&b22); | |
submatrices(c,&c11,&c12,&c21,&c22); | |
#pragma omp parallel sections | |
{ | |
// P = (A1 + A2) * (B1 + B2) | |
#pragma omp section | |
strassen_1(a11,a22,b11,b22,p1); | |
// P = (A1 + A2) * B | |
#pragma omp section | |
strassen_2(a21,a22,b11,c21); | |
#pragma omp section | |
strassen_2(a11,a12,b22,c12); | |
// P = A * (B1 - B2) | |
#pragma omp section | |
strassen_3(a11,b12,b22,p2); | |
#pragma omp section | |
strassen_3(a22,b21,b11,p3); | |
// P = (A1 - A2) * (B1 + B2) | |
#pragma omp section | |
strassen_4(a21,a11,b11,b12,c22); | |
#pragma omp section | |
strassen_4(a12,a22,b21,b22,c11); | |
} | |
#pragma omp parallel sections | |
{ | |
// P1 += P2 + (P3 - P4) | |
#pragma omp section | |
strassen_5(c11,p1,p3,c12); | |
#pragma omp section | |
strassen_5(c22,p1,p2,c21); | |
} | |
#pragma omp parallel sections | |
{ | |
// P1 += P2 | |
#pragma omp section | |
matrix_add(c12,c12,p2); | |
#pragma omp section | |
matrix_add(c21,c21,p3); | |
} | |
int i,j,k; | |
real_t x; | |
#pragma omp parallel sections | |
{ | |
#pragma omp section | |
if(a.size1 & 1) { | |
// (ultima fila de C) = (ultima fila de A) * B | |
k = a.size1-1; | |
for(j=0; j<b.size2; j++) { | |
x = 0; | |
for(i=0; i<b.size1; i++) | |
x += matrix(a,k,i) * matrix(b,i,j); | |
matrix(c,k,j) = x; | |
} | |
} | |
#pragma omp section | |
if(b.size2 & 1) { | |
// (ultima col de C) = A * (ultima col de B) | |
k = b.size2-1; | |
for(i=0; i<a.size1; i++) { | |
x = 0; | |
for(j=0; j<a.size2; j++) | |
x += matrix(a,i,j) * matrix(b,j,k); | |
matrix(c,i,k) = x; | |
} | |
} | |
} | |
if(a.size2 & 1) { | |
// C += (ultima col de A) * (ultima fila de B) | |
k = a.size2-1; | |
for(i=0; i<(c.size1&~1); i++) { | |
for(j=0; j<(c.size2&~1); j++) | |
matrix(c,i,j) += matrix(a,i,k) * matrix(b,k,j); | |
} | |
} | |
matrix_free(p1); | |
matrix_free(p2); | |
matrix_free(p3); | |
} | |
void matrix_mul(matrix_t c, matrix_t a, matrix_t b) { | |
assert(a.size2==b.size1 && c.size1==a.size1 && c.size2==b.size2); | |
if(c.size1*c.size2 < MIN_STRASSEN) | |
matrix_mul_simple(c,a,b); | |
else | |
matrix_mul_strassen(c,a,b); | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
typedef double real_t; | |
struct matrix { | |
int size1, size2, tda; | |
real_t *data; | |
}; | |
typedef struct matrix matrix_t; | |
#define matrix(m,i,j) m.data[(i)*m.tda + j] | |
matrix_t matrix_alloc(int h, int w); | |
void matrix_free(matrix_t m); | |
void matrix_mul(matrix_t c, matrix_t a, matrix_t b); |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment