Skip to content

Instantly share code, notes, and snippets.

@fatemehkarimi
Last active March 10, 2018 18:03
Show Gist options
  • Save fatemehkarimi/83e8292e32592655d51156222d64753d to your computer and use it in GitHub Desktop.
Save fatemehkarimi/83e8292e32592655d51156222d64753d to your computer and use it in GitHub Desktop.
C++ matrix multiplication (Strassen's algorithm)
//Strassen's algorithm for square matrix multiplication, as described on page 79 of
//"Introduction to Algorithms" (Cormen, Leiserson, Rivest, Stein).
//NOTE: expects square matrices whose size is a power of 2.
#include <iostream>
using namespace std;
// Forward declarations of the matrix routines defined after main().
// Each routine works on size x size sub-blocks: (rowX, colX) is the top-left
// corner of the sub-block of X that takes part in the operation.
// Each returns a freshly allocated size x size matrix stored as an array of row
// pointers whose element 0 owns one contiguous block of size*size ints, so the
// caller has to delete[] both result[0] and result.
int** matrix_addition(int** A, int** B, int rowA, int colA, int rowB, int colB, int size);
int** matrix_subtraction(int** A, int** B, int rowA, int colA, int rowB, int colB, int size);
// The spelling "matirx" is kept because the definition and its callers use it.
int** matirx_multiply(int** A, int** B, int rowA, int colA, int rowB, int colB, int size);
int main(void)
{
cout << "enter size of matrices. the size should be a power of 2. if not, unpredictable result." << endl;
int size = 0;
cin >> size;
int ** matrix1 = new int *[size];
*matrix1 = new int[size * size];
int ** matrix2 = new int *[size];
*matrix2 = new int[size * size];
for (int i = 0; i < size; ++i){
matrix1[i] = matrix1[0] + i * size;
matrix2[i] = matrix2[0] + i * size;
}
cout << "enter matrix A elements" << endl;
for (int i = 0; i < size; ++i)
for (int j = 0; j < size; ++j)
cin >> matrix1[i][j];
cout << "enter matrix B elements" << endl;
for (int i = 0; i < size; ++i)
for (int j = 0; j < size; ++j)
cin >> matrix2[i][j];
int ** tmp = matirx_multiply(matrix1, matrix2, 0, 0, 0, 0, size);
cout << endl << "the result of matrix multiplication is:" << endl;
for (int i = 0; i < size; ++i) {
for (int j = 0; j < size; ++j)
cout << tmp[i][j] << ' ';
cout << endl;
}
cout << endl;
return 0;
}
// Allocates a size x size matrix in the layout used throughout this file: an
// array of row pointers whose element 0 owns one contiguous block of
// size*size ints. Requires size >= 1. Release it with strassen_release().
static int ** strassen_allocate(int size)
{
    int ** m = new int *[size];
    m[0] = new int[size * size];
    for (int i = 1; i < size; ++i)
        m[i] = m[0] + i * size;
    return m;
}

// Frees a matrix stored in the layout above: first the contiguous data block
// held in m[0], then the row-pointer array. A null pointer is ignored.
static void strassen_release(int ** m)
{
    if (m == nullptr)
        return;
    delete[] m[0];
    delete[] m;
}

// Multiplies two sub-blocks with Strassen's method (CLRS 3rd ed., section 4.2).
// The left factor is the size x size sub-block of A whose top-left corner is
// (rowA, colA). The right factor is the size x size sub-block of B starting at
// (rowB, colB).
//
// Returns a newly allocated size x size product in the strassen_allocate layout.
// The caller must delete[] both result[0] and result. Returns nullptr when
// size <= 0.
//
// Even sizes are split into quadrants and recursed on. Odd sizes, including the
// size == 1 base case, are multiplied directly, so any positive size works, not
// only powers of 2.
//
// The misspelled name ("matirx") is kept because callers use it.
int ** matirx_multiply(int ** A, int ** B, int rowA, int colA, int rowB, int colB, int size)
{
    if (size <= 0)
        return nullptr;  // nothing to multiply; also prevents endless size/2 == 0 recursion

    if (size % 2 != 0) {
        // Quadrants need an even split, so compute the schoolbook product.
        int ** result = strassen_allocate(size);
        for (int i = 0; i < size; ++i)
            for (int j = 0; j < size; ++j) {
                int acc = 0;
                for (int k = 0; k < size; ++k)
                    acc += A[rowA + i][colA + k] * B[rowB + k][colB + j];
                result[i][j] = acc;
            }
        return result;
    }

    const int half = size / 2;
    // Quadrant corners of a block X at (rowX, colX):
    //   X11 = (rowX, colX)          X12 = (rowX, colX + half)
    //   X21 = (rowX + half, colX)   X22 = (rowX + half, colX + half)
    int ** S1 = matrix_subtraction(B, B, rowB, colB + half, rowB + half, colB + half, half);   // B12 - B22
    int ** S2 = matrix_addition(A, A, rowA, colA, rowA, colA + half, half);                    // A11 + A12
    int ** S3 = matrix_addition(A, A, rowA + half, colA, rowA + half, colA + half, half);      // A21 + A22
    int ** S4 = matrix_subtraction(B, B, rowB + half, colB, rowB, colB, half);                 // B21 - B11
    int ** S5 = matrix_addition(A, A, rowA, colA, rowA + half, colA + half, half);             // A11 + A22
    int ** S6 = matrix_addition(B, B, rowB, colB, rowB + half, colB + half, half);             // B11 + B22
    int ** S7 = matrix_subtraction(A, A, rowA, colA + half, rowA + half, colA + half, half);   // A12 - A22
    int ** S8 = matrix_addition(B, B, rowB + half, colB, rowB + half, colB + half, half);      // B21 + B22
    int ** S9 = matrix_subtraction(A, A, rowA, colA, rowA + half, colA, half);                 // A11 - A21
    int ** S10 = matrix_addition(B, B, rowB, colB, rowB, colB + half, half);                   // B11 + B12

    int ** P1 = matirx_multiply(A, S1, rowA, colA, 0, 0, half);                 // A11 * S1
    int ** P2 = matirx_multiply(S2, B, 0, 0, rowB + half, colB + half, half);   // S2 * B22
    int ** P3 = matirx_multiply(S3, B, 0, 0, rowB, colB, half);                 // S3 * B11
    int ** P4 = matirx_multiply(A, S4, rowA + half, colA + half, 0, 0, half);   // A22 * S4
    int ** P5 = matirx_multiply(S5, S6, 0, 0, 0, 0, half);
    int ** P6 = matirx_multiply(S7, S8, 0, 0, 0, 0, half);
    int ** P7 = matirx_multiply(S9, S10, 0, 0, 0, 0, half);

    // C11 = P5 + P4 - P2 + P6, C12 = P1 + P2, C21 = P3 + P4, C22 = P5 + P1 - P3 - P7
    int ** sum1 = matrix_addition(P5, P4, 0, 0, 0, 0, half);
    int ** sum2 = matrix_addition(sum1, P6, 0, 0, 0, 0, half);
    int ** C11 = matrix_subtraction(sum2, P2, 0, 0, 0, 0, half);
    int ** C12 = matrix_addition(P1, P2, 0, 0, 0, 0, half);
    int ** C21 = matrix_addition(P3, P4, 0, 0, 0, 0, half);
    int ** sum3 = matrix_addition(P5, P1, 0, 0, 0, 0, half);
    int ** sum4 = matrix_addition(P3, P7, 0, 0, 0, 0, half);
    int ** C22 = matrix_subtraction(sum3, sum4, 0, 0, 0, 0, half);

    // Stitch the four quadrants into the full-size result.
    int ** result = strassen_allocate(size);
    for (int i = 0; i < half; ++i)
        for (int j = 0; j < half; ++j) {
            result[i][j] = C11[i][j];
            result[i][j + half] = C12[i][j];
            result[i + half][j] = C21[i][j];
            result[i + half][j + half] = C22[i][j];
        }

    // Every temporary owns a data block (m[0]) as well as its row array;
    // deleting only the row array would leak the data at every recursion level.
    int ** temporaries[] = {S1, S2, S3, S4, S5, S6, S7, S8, S9, S10,
                            P1, P2, P3, P4, P5, P6, P7,
                            sum1, sum2, sum3, sum4,
                            C11, C12, C21, C22};
    for (int ** m : temporaries)
        strassen_release(m);

    return result;
}
// Element-wise sum of two size x size sub-blocks: the block of A whose top-left
// corner is (rowA, colA) plus the block of B whose top-left corner is (rowB, colB).
// The result is a new size x size matrix: an array of row pointers whose first
// entry owns one contiguous run of size*size ints. The caller frees both
// result[0] and result with delete[].
int ** matrix_addition(int ** A, int ** B, int rowA, int colA, int rowB, int colB, int size)
{
    int ** sum = new int *[size];
    sum[0] = new int[size * size];
    for (int r = 1; r < size; ++r)
        sum[r] = sum[0] + r * size;

    for (int r = 0; r < size; ++r) {
        const int * lhs = A[rowA + r] + colA;
        const int * rhs = B[rowB + r] + colB;
        int * out = sum[r];
        for (int c = 0; c < size; ++c)
            out[c] = lhs[c] + rhs[c];
    }
    return sum;
}
// Element-wise difference of two size x size sub-blocks: the block of A at
// (rowA, colA) minus the block of B at (rowB, colB).
// Returns a freshly allocated size x size matrix whose row pointers all index
// into one contiguous buffer stored at result[0]. The caller releases both
// result[0] and result with delete[].
int ** matrix_subtraction(int ** A, int ** B, int rowA, int colA, int rowB, int colB, int size)
{
    int ** diff = new int *[size];
    diff[0] = new int[size * size];
    int * cell = diff[0];  // walks the buffer in row-major order
    for (int r = 0; r < size; ++r) {
        diff[r] = diff[0] + r * size;
        for (int c = 0; c < size; ++c)
            *cell++ = A[rowA + r][colA + c] - B[rowB + r][colB + c];
    }
    return diff;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment