/** A simple parallel matrix-multiplication program based on Strassen's
 *  algorithm. Generates two random dense n*n matrices, where n=2k, and
 *  multiplies them together.
 *
 *  Written by: Alex Stachnik <stachnik@udel.edu>
 *
 *  This program is open source, no rights reserved.
 *
 *  LinBox is free software released under the LGPL; I do not own it.
*/

#include "linbox/matrix/sparse-matrix.h"
#include "linbox/field/modular.h"
#include "linbox/matrix/dense-matrix.h"
#include "linbox/matrix/matrix-domain.h"

#include <iostream>

// Bring the LinBox names into scope so library types can be written unqualified.
using namespace LinBox;

// Coefficient field used throughout this program (main() constructs it with modulus 7).
typedef Modular<double> Field;
// Element type of Field; used for the individual matrix entries.
typedef typename Field::Element Element;
// Matrix type used for the operands and results of strassen() and main().
typedef BlasMatrix<Field> Matrix;
// Matrix type exposed by MatrixDomain<Field>; strassen() uses it to address
// quadrants of a larger matrix through submatrix().
typedef typename MatrixDomain<Field>::Matrix Submat;

/**
 * Draws a pseudo-random integer using rand().
 *
 * rand() is first scaled into [0, 1), then stretched over the width
 * (end - start); the fractional part is truncated and the result is
 * shifted by start. For end > start the result lies in [start, end).
 *
 * @param start  Lower bound (inclusive) of the range.
 * @param end    Upper bound (exclusive) of the range.
 * @return start plus the truncated, scaled random offset.
 */
int randRange(int start, int end)
{
        // Maps rand()'s output range [0, RAND_MAX] onto [0, 1).
        static const double UNIT_SCALE = 1.0/(1.0+RAND_MAX);

        const double sample = rand();
        const double unitSample = sample*UNIT_SCALE;
        const double width = end-start;
        return start+static_cast<int>(width*unitSample);
}

/**
 * Computes C = A*B using one level of Strassen's algorithm.
 *
 * A, B and C are viewed as 2x2 grids of (n/2)x(n/2) quadrants. The seven
 * Strassen products M1..M7 are evaluated in an OpenMP loop (one product per
 * iteration) and each is then folded into the quadrants of C that use it:
 *
 *   C11 = M1 + M4 - M5 + M7      C12 = M3 + M5
 *   C21 = M2 + M4                C22 = M1 - M2 + M3 + M6
 *
 * Quadrant names below (X11, X12, X21, X22) take submatrix(X,i,j,r,c) to
 * select the r x c block whose top-left corner is at row i, column j.
 *
 * The splitting is only valid when A, B and C are all n x n with n even.
 * For any other shape the product is computed with MatrixDomain::mul
 * instead, so the result is still correct.
 *
 * @param C  Output matrix; its previous contents are overwritten.
 * @param A  Left operand.
 * @param B  Right operand.
 */
void strassen(Matrix& C, Matrix& A, Matrix& B)
{
  MatrixDomain<Field> MD(A.field());

  const size_t n = A.coldim();

  // Halving an odd n would drop the last row/column, and non-square or
  // mismatched operands cannot be split into equal quadrants: use the
  // ordinary product in those cases.
  if (n%2 != 0 || A.rowdim() != n ||
      B.rowdim() != n || B.coldim() != n ||
      C.rowdim() != n || C.coldim() != n) {
    MD.mul(C,A,B);
    return;
  }

  const size_t halfN = n/2;

  // Clear C (C -= C) so the products can be accumulated into it.
  MD.subin(C,C);

  // NOTE(review): MD, A, B and C are shared by all threads; this assumes the
  // MatrixDomain operations used here may be called concurrently - confirm.
#pragma omp parallel for schedule (static,1)
  for (size_t p=0;p<7;++p) {
    // Per-iteration scratch space: M holds the product M(p+1).
    Matrix M(A.field(),halfN,halfN);
    Matrix temp1(A.field(),halfN,halfN);
    Matrix temp2(A.field(),halfN,halfN);
    Submat Asub1,Asub2,Bsub1,Bsub2;
    Submat Csub;

    switch(p) {
    case 0:
      // M1 = (A11 + A22) * (B11 + B22)
      Asub1.submatrix(A,0,0,halfN,halfN);
      Asub2.submatrix(A,halfN,halfN,halfN,halfN);
      Bsub1.submatrix(B,0,0,halfN,halfN);
      Bsub2.submatrix(B,halfN,halfN,halfN,halfN);

      MD.add(temp1,Asub1,Asub2);
      MD.add(temp2,Bsub1,Bsub2);
      MD.mul(M,temp1,temp2);
      break;

    case 1:
      // M2 = (A21 + A22) * B11
      Asub1.submatrix(A,halfN,0,halfN,halfN);
      Asub2.submatrix(A,halfN,halfN,halfN,halfN);
      Bsub1.submatrix(B,0,0,halfN,halfN);

      MD.add(temp1,Asub1,Asub2);
      MD.mul(M,temp1,Bsub1);
      break;

    case 2:
      // M3 = A11 * (B12 - B22)
      Asub1.submatrix(A,0,0,halfN,halfN);
      Bsub1.submatrix(B,0,halfN,halfN,halfN);
      Bsub2.submatrix(B,halfN,halfN,halfN,halfN);

      MD.sub(temp1,Bsub1,Bsub2);
      MD.mul(M,Asub1,temp1);
      break;

    case 3:
      // M4 = A22 * (B21 - B11)
      Asub1.submatrix(A,halfN,halfN,halfN,halfN);
      Bsub1.submatrix(B,halfN,0,halfN,halfN);
      Bsub2.submatrix(B,0,0,halfN,halfN);

      MD.sub(temp1,Bsub1,Bsub2);
      MD.mul(M,Asub1,temp1);
      break;

    case 4:
      // M5 = (A11 + A12) * B22
      Asub1.submatrix(A,0,0,halfN,halfN);
      Asub2.submatrix(A,0,halfN,halfN,halfN);
      Bsub1.submatrix(B,halfN,halfN,halfN,halfN);

      MD.add(temp1,Asub1,Asub2);
      MD.mul(M,temp1,Bsub1);
      break;

    case 5:
      // M6 = (A21 - A11) * (B11 + B12)
      Asub1.submatrix(A,halfN,0,halfN,halfN);
      Asub2.submatrix(A,0,0,halfN,halfN);
      Bsub1.submatrix(B,0,0,halfN,halfN);
      Bsub2.submatrix(B,0,halfN,halfN,halfN);

      MD.sub(temp1,Asub1,Asub2);
      MD.add(temp2,Bsub1,Bsub2);
      MD.mul(M,temp1,temp2);
      break;

    case 6:
      // M7 = (A12 - A22) * (B21 + B22)
      Asub1.submatrix(A,0,halfN,halfN,halfN);
      Asub2.submatrix(A,halfN,halfN,halfN,halfN);
      Bsub1.submatrix(B,halfN,0,halfN,halfN);
      Bsub2.submatrix(B,halfN,halfN,halfN,halfN);

      MD.sub(temp1,Asub1,Asub2);
      MD.add(temp2,Bsub1,Bsub2);
      MD.mul(M,temp1,temp2);
      break;

    default:
      break;
    }

    // Fold this product into C. Several iterations update the same quadrant,
    // so the accumulation is serialized; a single critical region covers all
    // four quadrants so the lock is taken once per product.
#pragma omp critical
    {
      // C11 = M1 + M4 - M5 + M7
      Csub.submatrix(C,0,0,halfN,halfN);
      if (p==0 || p==3 || p==6) {
        MD.addin(Csub,M);
      }
      if (p==4) {
        MD.subin(Csub,M);
      }

      // C12 = M3 + M5
      Csub.submatrix(C,0,halfN,halfN,halfN);
      if (p==2 || p==4) {
        MD.addin(Csub,M);
      }

      // C21 = M2 + M4
      Csub.submatrix(C,halfN,0,halfN,halfN);
      if (p==1 || p==3) {
        MD.addin(Csub,M);
      }

      // C22 = M1 - M2 + M3 + M6
      Csub.submatrix(C,halfN,halfN,halfN,halfN);
      if (p==0 || p==2 || p==5) {
        MD.addin(Csub,M);
      }
      if (p==1) {
        MD.subin(Csub,M);
      }
    }

  }
}

int main(int argc, char** argv)
{
  size_t q=7,n=128;
  Field F(q);
  Element d;
  MatrixDomain<Field> MD(F);
  Matrix A(F,n,n),B(F,n,n),C(F,n,n),D(F,n,n);

  for (size_t i=0;i<n;++i) {
    for (size_t j=0;j<n;++j) {
      F.init(d,randRange(0,q));
      A.setEntry(i,j,d);
      F.init(d,randRange(0,q));
      B.setEntry(i,j,d);
    }
  }

  strassen(C,A,B);

  MD.mul(D,A,B);

  if (MD.areEqual(C,D)) {
    std::cout << "Are equal" << std::endl;
  } else {
    std::cout << "Not equal" << std::endl;
  }

  return 0;
}