Skip to content

Instantly share code, notes, and snippets.

@sid24rane
Created February 4, 2017 16:11
Show Gist options
  • Star 4 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sid24rane/bdd8f579979fd79f98d7bb481ce79892 to your computer and use it in GitHub Desktop.
Save sid24rane/bdd8f579979fd79f98d7bb481ce79892 to your computer and use it in GitHub Desktop.
Strassen Matrix Multiplication in Java - NxN matrix
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
/**
 * Console program that multiplies two square integer matrices with Strassen's
 * divide-and-conquer algorithm.
 *
 * <p>All arithmetic is performed in {@code int}, so entries wrap around on overflow
 * exactly as ordinary {@code int} multiplication and addition do.
 */
public class Codechef {

    /**
     * Multiplies two square matrices of the same order using Strassen's algorithm.
     *
     * <p>Any order is accepted: whenever the current order is odd (and greater than 1)
     * the operands are zero-padded by one row and one column so that they can be split
     * into four equal quadrants, and the padding is cropped from the result.
     *
     * @param A left operand, an n x n matrix
     * @param B right operand, an n x n matrix
     * @return a newly allocated n x n matrix holding the product {@code A * B};
     *         an empty matrix when n is 0
     * @throws IllegalArgumentException if either matrix is not n x n, where n is
     *         the number of rows of {@code A}
     * @throws NullPointerException if {@code A} or {@code B} is null
     */
    public static int[][] StrassenMultiply(int[][] A, int[][] B) {
        int n = A.length;
        requireSquare(A, n, "A");
        requireSquare(B, n, "B");
        if (n == 0) {
            // Without this guard the recursion would never reach the 1x1 base case.
            return new int[0][0];
        }
        return strassen(A, B);
    }

    /** Recursive core of Strassen's algorithm; operands are already validated, order >= 1. */
    private static int[][] strassen(int[][] A, int[][] B) {
        int n = A.length;
        if (n == 1) {
            return new int[][] {{A[0][0] * B[0][0]}};
        }
        if (n % 2 != 0) {
            return multiplyPadded(A, B);
        }

        int half = n / 2;

        // quadrants of the first matrix: A = [[a, b], [c, d]]
        int[][] a = new int[half][half];
        int[][] b = new int[half][half];
        int[][] c = new int[half][half];
        int[][] d = new int[half][half];
        divideArray(A, a, 0, 0);
        divideArray(A, b, 0, half);
        divideArray(A, c, half, 0);
        divideArray(A, d, half, half);

        // quadrants of the second matrix: B = [[e, f], [g, h]]
        int[][] e = new int[half][half];
        int[][] f = new int[half][half];
        int[][] g = new int[half][half];
        int[][] h = new int[half][half];
        divideArray(B, e, 0, 0);
        divideArray(B, f, 0, half);
        divideArray(B, g, half, 0);
        divideArray(B, h, half, half);

        // The seven Strassen products:
        //   p1 = (a + d)(e + h)    p2 = (c + d)e    p3 = a(f - h)    p4 = d(g - e)
        //   p5 = (a + b)h          p6 = (c - a)(e + f)               p7 = (b - d)(g + h)
        int[][] p1 = strassen(addMatrices(a, d), addMatrices(e, h));
        int[][] p2 = strassen(addMatrices(c, d), e);
        int[][] p3 = strassen(a, subMatrices(f, h));
        int[][] p4 = strassen(d, subMatrices(g, e));
        int[][] p5 = strassen(addMatrices(a, b), h);
        int[][] p6 = strassen(subMatrices(c, a), addMatrices(e, f));
        int[][] p7 = strassen(subMatrices(b, d), addMatrices(g, h));

        // Quadrants of the product:
        //   C11 = p1 + p4 - p5 + p7    C12 = p3 + p5
        //   C21 = p2 + p4              C22 = p1 - p2 + p3 + p6
        int[][] c11 = addMatrices(subMatrices(addMatrices(p1, p4), p5), p7);
        int[][] c12 = addMatrices(p3, p5);
        int[][] c21 = addMatrices(p2, p4);
        int[][] c22 = addMatrices(subMatrices(addMatrices(p1, p3), p2), p6);

        // stitch the four quadrants back into one matrix
        int[][] res = new int[n][n];
        copySubArray(c11, res, 0, 0);
        copySubArray(c12, res, 0, half);
        copySubArray(c21, res, half, 0);
        copySubArray(c22, res, half, half);
        return res;
    }

    /**
     * Multiplies two matrices of odd order by embedding them in zero-padded matrices
     * of order n + 1 and cropping the product back to n x n.
     */
    private static int[][] multiplyPadded(int[][] A, int[][] B) {
        int n = A.length;
        int padded = n + 1;
        int[][] paddedA = new int[padded][padded];
        int[][] paddedB = new int[padded][padded];
        copySubArray(A, paddedA, 0, 0);
        copySubArray(B, paddedB, 0, 0);

        // The extra zero row/column contributes nothing to the top-left n x n part.
        int[][] paddedProduct = strassen(paddedA, paddedB);
        int[][] res = new int[n][n];
        divideArray(paddedProduct, res, 0, 0);
        return res;
    }

    /** Throws IllegalArgumentException unless {@code m} has exactly n rows of length n. */
    private static void requireSquare(int[][] m, int n, String name) {
        if (m.length != n) {
            throw new IllegalArgumentException(
                    "Matrix " + name + " must have " + n + " rows but has " + m.length);
        }
        for (int i = 0; i < n; i++) {
            if (m[i] == null || m[i].length != n) {
                throw new IllegalArgumentException(
                        "Matrix " + name + " must be " + n + "x" + n + "; row " + i + " is malformed");
            }
        }
    }

    /**
     * Reads the matrix order and two matrices from standard input (one integer per
     * line, row by row), multiplies them and prints the product.
     *
     * @param args unused
     * @throws NumberFormatException if a line is not a valid integer
     * @throws IOException if reading fails or the input ends early
     */
    public static void main(String[] args) throws NumberFormatException, IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        System.out.println("Enter the order of Matrices:");
        int order = readInt(br);
        if (order < 0) {
            throw new IllegalArgumentException("Matrix order must be non-negative: " + order);
        }

        // input : first matrix
        System.out.println("Enter first matrix:");
        int[][] A = readMatrix(br, order);

        // input : second matrix
        System.out.println("Enter second matrix:");
        int[][] B = readMatrix(br, order);

        printMatrix(StrassenMultiply(A, B));
    }

    /** Reads order * order integers, one per line, into a new row-major matrix. */
    private static int[][] readMatrix(BufferedReader in, int order) throws IOException {
        int[][] m = new int[order][order];
        for (int i = 0; i < order; i++) {
            for (int j = 0; j < order; j++) {
                m[i][j] = readInt(in);
            }
        }
        return m;
    }

    /** Reads one line and parses it as an int, ignoring surrounding whitespace. */
    private static int readInt(BufferedReader in) throws IOException {
        String line = in.readLine();
        if (line == null) {
            throw new IOException("Unexpected end of input while reading an integer");
        }
        return Integer.parseInt(line.trim());
    }

    // helper methods

    /**
     * Adds two square matrices element-wise.
     *
     * @param a first operand; its row count n defines the size of the result
     * @param b second operand, at least n x n
     * @return a new n x n matrix with {@code a[i][j] + b[i][j]}
     */
    public static int[][] addMatrices(int[][] a, int[][] b) {
        int n = a.length;
        int[][] res = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                res[i][j] = a[i][j] + b[i][j];
            }
        }
        return res;
    }

    /**
     * Subtracts two square matrices element-wise.
     *
     * @param a minuend; its row count n defines the size of the result
     * @param b subtrahend, at least n x n
     * @return a new n x n matrix with {@code a[i][j] - b[i][j]}
     */
    public static int[][] subMatrices(int[][] a, int[][] b) {
        int n = a.length;
        int[][] res = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                res[i][j] = a[i][j] - b[i][j];
            }
        }
        return res;
    }

    /**
     * Prints a matrix to standard output: a header line, then one tab-separated
     * line per row, then a blank line.
     *
     * @param a the matrix to print
     */
    public static void printMatrix(int[][] a) {
        System.out.println("Resultant Matrix ");
        for (int[] row : a) {
            for (int value : row) {
                System.out.print(value + "\t");
            }
            System.out.println();
        }
        System.out.println();
    }

    /**
     * Extracts a square sub-matrix: fills {@code C} (k x k, k = {@code C.length})
     * with the entries of {@code P} whose top-left corner is at {@code (iB, jB)}.
     *
     * @param P source matrix
     * @param C destination square matrix, overwritten completely
     * @param iB row offset into {@code P}
     * @param jB column offset into {@code P}
     */
    public static void divideArray(int[][] P, int[][] C, int iB, int jB)
    {
        int size = C.length;
        for (int row = 0; row < size; row++) {
            System.arraycopy(P[iB + row], jB, C[row], 0, size);
        }
    }

    /**
     * Writes the square matrix {@code C} (k x k, k = {@code C.length}) into
     * {@code P} with its top-left corner placed at {@code (iB, jB)}.
     *
     * @param C source square matrix
     * @param P destination matrix, modified in place
     * @param iB row offset into {@code P}
     * @param jB column offset into {@code P}
     */
    public static void copySubArray(int[][] C, int[][] P, int iB, int jB)
    {
        int size = C.length;
        for (int row = 0; row < size; row++) {
            System.arraycopy(C[row], 0, P[iB + row], jB, size);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment