Skip to content

Instantly share code, notes, and snippets.

@redknightlois
Last active August 29, 2015 14:02
Show Gist options
  • Save redknightlois/5aaa75b7125651d628b1 to your computer and use it in GitHub Desktop.
Save redknightlois/5aaa75b7125651d628b1 to your computer and use it in GitHub Desktop.
A real Matrix<T> for int, float and double types.
/// <summary>
/// Dense matrix whose elements are <c>int</c>, <c>float</c> or <c>double</c>.
/// Operations are dispatched on <typeparamref name="T"/> to the type-specialised overloads in
/// <c>MatrixHelper</c>; the constructors reject every other element type.
/// </summary>
public class Matrix<T> : IEquatable<Matrix<T>>
    where T : struct
{
    public readonly int Rows;
    public readonly int Columns;

    /// <summary>
    /// Data is stored in Column-Major to be compliant with the BLAS packages:
    /// element (row, column) lives at <c>Data[column, row]</c>.
    /// </summary>
    private readonly T[,] Data;

    /// <summary>
    /// Cached LU decomposition of the current contents, created on demand by <see cref="Decomposition"/>.
    /// Every element write resets it to null so that a stale result (or a cached failure such as
    /// "singular") is never reused after the matrix changes.
    /// </summary>
    private Lazy<Decomposition<T>> decomposition;

    /// <summary>
    /// LU decomposition of this matrix, computed on first access and recomputed after any element changes.
    /// </summary>
    /// <exception cref="InvalidOperationException">The matrix is not square or is singular.</exception>
    /// <exception cref="NotSupportedException"><typeparamref name="T"/> is not float or double.</exception>
    public Decomposition<T> Decomposition
    {
        get
        {
            var lazy = decomposition;
            if (lazy == null)
            {
                var created = new Lazy<Decomposition<T>>(() => MakeLU(this), true);
                // Publish a single instance so that concurrent readers share one computation.
                lazy = System.Threading.Interlocked.CompareExchange(ref decomposition, created, null) ?? created;
            }
            return lazy.Value;
        }
    }

    /// <summary>Additive identity of <typeparamref name="T"/> (default(T) for unsupported types).</summary>
    public static readonly T Zero;
    /// <summary>Multiplicative identity of <typeparamref name="T"/> (default(T) for unsupported types).</summary>
    public static readonly T One;
    /// <summary>Per-type comparison tolerance: 0 for int, 1e-5 for float and double.</summary>
    public static readonly T DefaultEpsilon;
    /// <summary>Tolerance used by the equality members that take no explicit epsilon.</summary>
    public static readonly T Epsilon;

    static Matrix()
    {
        if (typeof(T) == typeof(int))
        {
            Zero = (T)(object)0;
            One = (T)(object)1;
            DefaultEpsilon = (T)(object)0;
        }
        else if (typeof(T) == typeof(float))
        {
            Zero = (T)(object)0f;
            One = (T)(object)1f;
            DefaultEpsilon = (T)(object)0.00001f;
        }
        else if (typeof(T) == typeof(double))
        {
            Zero = (T)(object)0d;
            One = (T)(object)1d;
            DefaultEpsilon = (T)(object)0.00001d;
        }
        else
        {
            Zero = default(T);
            One = default(T);
            DefaultEpsilon = default(T);
        }
        Epsilon = DefaultEpsilon;
    }

    /// <summary>
    /// Creates an <paramref name="iRows"/> x <paramref name="iCols"/> matrix whose elements are all zero.
    /// </summary>
    /// <exception cref="NotSupportedException"><typeparamref name="T"/> is not int, float or double.</exception>
    /// <exception cref="ArgumentException">Either dimension is not positive.</exception>
    public Matrix(int iRows, int iCols)
    {
        Contract.Requires<NotSupportedException>(typeof(T) == typeof(int) || typeof(T) == typeof(float) || typeof(T) == typeof(double));
        Contract.Requires<ArgumentException>(iRows > 0 && iCols > 0);
        Rows = iRows;
        Columns = iCols;
        Data = new T[Columns, Rows];
    }

    /// <summary>
    /// Creates an <paramref name="iRows"/> x <paramref name="iCols"/> matrix with every element set to
    /// <paramref name="value"/>. Validation is performed by the chained constructor.
    /// </summary>
    public Matrix(int iRows, int iCols, T value)
        : this(iRows, iCols)
    {
        for (int col = 0; col < Columns; col++)
            for (int row = 0; row < Rows; row++)
                Data[col, row] = value;
    }

    /// <summary>True when the matrix has as many rows as columns.</summary>
    public bool IsSquare
    {
        get { return Rows == Columns; }
    }

    /// <summary>
    /// Element at (<paramref name="iRow"/>, <paramref name="iCol"/>). Writing an element discards the
    /// cached LU decomposition.
    /// </summary>
    public T this[int iRow, int iCol]
    {
        get { return Data[iCol, iRow]; }
        set
        {
            Data[iCol, iRow] = value;
            decomposition = null;
        }
    }

    /// <summary>Returns column <paramref name="k"/> as a new Rows x 1 matrix.</summary>
    public Matrix<T> GetColumn(int k)
    {
        Matrix<T> m = new Matrix<T>(Rows, 1);
        for (int i = 0; i < Rows; i++)
            m[i, 0] = this[i, k];
        return m;
    }

    /// <summary>Overwrites column <paramref name="k"/> with the first column of <paramref name="v"/>.</summary>
    public void SetColumn(Matrix<T> v, int k)
    {
        for (int i = 0; i < Rows; i++)
            this[i, k] = v[i, 0];
    }

    /// <summary>Returns row <paramref name="k"/> as a new 1 x Columns matrix.</summary>
    public Matrix<T> GetRow(int k)
    {
        Matrix<T> m = new Matrix<T>(1, Columns);
        for (int i = 0; i < Columns; i++)
            m[0, i] = this[k, i];
        return m;
    }

    /// <summary>Overwrites row <paramref name="k"/> with the first row of <paramref name="v"/>.</summary>
    public void SetRow(Matrix<T> v, int k)
    {
        for (int i = 0; i < Columns; i++)
            this[k, i] = v[0, i];
    }

    /// <summary>
    /// Function returns the copy of this matrix (element data only; the LU cache is not shared).
    /// </summary>
    public Matrix<T> Clone()
    {
        var copy = new Matrix<T>(Rows, Columns);
        // Both arrays have identical bounds, so a flat copy preserves the layout.
        Array.Copy(Data, copy.Data, Data.Length);
        return copy;
    }

    /// <summary>
    /// Creates a zero matrix
    /// </summary>
    public static Matrix<T> Zeroes(int iRows, int iCols)
    {
        Contract.Requires<ArgumentException>(iRows > 0 && iCols > 0);
        return new Matrix<T>(iRows, iCols);
    }

    /// <summary>
    /// Creates an identity matrix: <see cref="One"/> on the main diagonal, zero elsewhere
    /// (rectangular shapes are allowed).
    /// </summary>
    public static Matrix<T> Identity(int iRows, int iCols)
    {
        Contract.Requires<ArgumentException>(iRows > 0 && iCols > 0);
        var matrix = new Matrix<T>(iRows, iCols);
        for (int i = 0; i < Math.Min(iRows, iCols); i++)
            matrix[i, i] = Matrix<T>.One;
        return matrix;
    }

    /// <summary>
    /// Formats the matrix one row per line ("\r\n"-terminated), each element right-aligned in a
    /// 5-character "0.00" field followed by a space.
    /// </summary>
    /// <param name="format">Currently ignored; elements always use the "0.00" layout.</param>
    /// <param name="formatProvider">Culture used to format each element (current culture when null).</param>
    public string ToString(string format, IFormatProvider formatProvider)
    {
        var text = new System.Text.StringBuilder();
        for (int i = 0; i < Rows; i++)
        {
            for (int j = 0; j < Columns; j++)
                text.AppendFormat(formatProvider, "{0,5:0.00} ", this[i, j]);
            text.Append("\r\n");
        }
        return text.ToString();
    }

    /// <summary>Formats the matrix using the current culture.</summary>
    public string ToString(string format)
    {
        return this.ToString(format, CultureInfo.CurrentCulture);
    }

    /// <summary>Function returns matrix as a string, formatted with the current culture.</summary>
    public override string ToString()
    {
        return this.ToString("G", CultureInfo.CurrentCulture);
    }

    /// <summary>
    /// Matrix transpose, for any rectangular matrix
    /// </summary>
    public static Matrix<T> Transpose(Matrix<T> m)
    {
        Contract.Requires<ArgumentNullException>(m != null);
        Matrix<T> t = new Matrix<T>(m.Columns, m.Rows);
        for (int i = 0; i < m.Rows; i++)
            for (int j = 0; j < m.Columns; j++)
                t[j, i] = m[i, j];
        return t;
    }

    /// <summary>
    /// Matrix transpose, for any rectangular matrix
    /// </summary>
    public Matrix<T> Transpose()
    {
        return Matrix<T>.Transpose(this);
    }

    /// <summary>Inverse of this matrix (float and double only).</summary>
    public Matrix<T> Invert()
    {
        return Matrix<T>.Invert(this);
    }

    /// <summary>
    /// Inverse of <paramref name="m"/>, solved column by column from its LU decomposition.
    /// </summary>
    /// <exception cref="NotSupportedException"><typeparamref name="T"/> is not float or double.</exception>
    public static Matrix<T> Invert(Matrix<T> m)
    {
        Contract.Requires<ArgumentNullException>(m != null);
        if (typeof(T) == typeof(float))
            return MatrixHelper.Invert(m as Matrix<float>) as Matrix<T>;
        if (typeof(T) == typeof(double))
            return MatrixHelper.Invert(m as Matrix<double>) as Matrix<T>;
        throw UnsupportedType();
    }

    /// <summary>
    /// Raises <paramref name="m"/> to the integer power <paramref name="pow"/> by repeated squaring.
    /// Zero yields the identity; negative exponents invert first.
    /// </summary>
    public static Matrix<T> Power(Matrix<T> m, int pow)
    {
        Contract.Requires<ArgumentNullException>(m != null);
        if (pow == 0)
            return Identity(m.Rows, m.Columns);
        if (pow == 1)
            return m.Clone();
        if (pow == -1)
            return m.Invert();

        // Work in a long: negating int.MinValue overflows back to a negative int, and a negative
        // exponent never reaches zero under the arithmetic right shift below.
        long exponent = pow;
        Matrix<T> x;
        if (exponent < 0)
        {
            x = m.Invert();
            exponent = -exponent;
        }
        else
        {
            // The operators below never mutate their operands, so m can be used directly.
            x = m;
        }

        Matrix<T> ret = Identity(m.Rows, m.Columns);
        while (exponent != 0)
        {
            if ((exponent & 1) == 1)
                ret *= x;
            exponent >>= 1;
            if (exponent != 0)
                x *= x; // skip the final squaring, whose result would be unused
        }
        return ret;
    }

    /// <summary>Raises this matrix to the integer power <paramref name="pow"/>.</summary>
    public Matrix<T> Power(int pow)
    {
        return Matrix<T>.Power(this, pow);
    }

    /// <summary>
    /// Function returns permutation matrix "P" due to permutation vector "pi"
    /// of the LU decomposition.
    /// </summary>
    public Matrix<T> Permutation()
    {
        var pi = Decomposition.Permutation;
        Matrix<T> matrix = Matrix<T>.Zeroes(Rows, Columns);
        for (int i = 0; i < Rows; i++)
            matrix[pi[i], i] = Matrix<T>.One;
        return matrix;
    }

    /// <summary>Element-wise comparison within <see cref="Epsilon"/>.</summary>
    public bool Equals(Matrix<T> other)
    {
        if (ReferenceEquals(other, null))
            return false;
        return Equals(this, other, Matrix<T>.Epsilon);
    }

    /// <summary>Element-wise comparison within <see cref="Epsilon"/>; false for non-matrix objects.</summary>
    public override bool Equals(object obj)
    {
        return Equals(obj as Matrix<T>);
    }

    /// <summary>Element-wise comparison within <see cref="Epsilon"/>.</summary>
    public static bool Equals(Matrix<T> m1, Matrix<T> m2)
    {
        return Equals(m1, m2, Matrix<T>.Epsilon);
    }

    /// <summary>
    /// Hash of the matrix shape only. Equality is tolerance-based, so element values cannot take part
    /// in the hash without breaking "equal objects have equal hash codes".
    /// </summary>
    public override int GetHashCode()
    {
        unchecked // Overflow is fine, just wrap
        {
            int hash = 17;
            hash = hash * 23 + this.Rows.GetHashCode();
            hash = hash * 23 + this.Columns.GetHashCode();
            return hash;
        }
    }

    /// <summary>
    /// True when both matrices have the same shape and every pair of elements differs by at most
    /// <paramref name="epsilon"/>. False when either argument is null.
    /// </summary>
    public static bool Equals(Matrix<T> m1, Matrix<T> m2, T epsilon)
    {
        if (ReferenceEquals(m1, null) || ReferenceEquals(m2, null))
            return false;
        if (m1.Columns != m2.Columns || m1.Rows != m2.Rows)
            return false;
        if (typeof(T) == typeof(int))
            return MatrixHelper.Equals(m1 as Matrix<int>, m2 as Matrix<int>, (int)(object)epsilon);
        if (typeof(T) == typeof(float))
            return MatrixHelper.Equals(m1 as Matrix<float>, m2 as Matrix<float>, (float)(object)epsilon);
        if (typeof(T) == typeof(double))
            return MatrixHelper.Equals(m1 as Matrix<double>, m2 as Matrix<double>, (double)(object)epsilon);
        throw UnsupportedType();
    }

    /// <summary>Computes the LU decomposition of a square float/double matrix.</summary>
    private static Decomposition<T> MakeLU(Matrix<T> m)
    {
        Contract.Requires<ArgumentNullException>(m != null);
        Contract.Requires<InvalidOperationException>(m.IsSquare, "The matrix is not square!");
        if (typeof(T) == typeof(float))
            return MatrixHelper.LUDecomposition(m as Matrix<float>) as Decomposition<T>;
        if (typeof(T) == typeof(double))
            return MatrixHelper.LUDecomposition(m as Matrix<double>) as Decomposition<T>;
        throw UnsupportedType();
    }

    /// <summary>Builds the exception reported when <typeparamref name="T"/> is not supported by an operation.</summary>
    private static NotSupportedException UnsupportedType()
    {
        return new NotSupportedException(string.Format(CultureInfo.InvariantCulture,
            "Type: {0} is not supported by the Matrix<T> class.", typeof(T)));
    }

    /// <summary>Determinant of this matrix (float and double only).</summary>
    public T Determinant()
    {
        return Determinant(this);
    }

    /// <summary>Determinant of <paramref name="m"/>, computed from its LU decomposition.</summary>
    /// <exception cref="NotSupportedException"><typeparamref name="T"/> is not float or double.</exception>
    public static T Determinant(Matrix<T> m)
    {
        Contract.Requires<ArgumentNullException>(m != null);
        if (typeof(T) == typeof(float))
            return (T)(object)MatrixHelper.Determinant(m as Matrix<float>);
        if (typeof(T) == typeof(double))
            return (T)(object)MatrixHelper.Determinant(m as Matrix<double>);
        throw UnsupportedType();
    }

    /// <summary>Matrix product; requires <c>m1.Columns == m2.Rows</c>.</summary>
    public static Matrix<T> operator *(Matrix<T> m1, Matrix<T> m2)
    {
        Contract.Requires<ArgumentNullException>(m1 != null && m2 != null);
        if (typeof(T) == typeof(int))
            return MatrixHelper.StrassenMultiply(m1 as Matrix<int>, m2 as Matrix<int>) as Matrix<T>;
        if (typeof(T) == typeof(float))
            return MatrixHelper.StrassenMultiply(m1 as Matrix<float>, m2 as Matrix<float>) as Matrix<T>;
        if (typeof(T) == typeof(double))
            return MatrixHelper.StrassenMultiply(m1 as Matrix<double>, m2 as Matrix<double>) as Matrix<T>;
        throw UnsupportedType();
    }

    /// <summary>Scales every element by <paramref name="c"/>.</summary>
    public static Matrix<T> operator *(T c, Matrix<T> m)
    {
        Contract.Requires<ArgumentNullException>(m != null);
        if (typeof(T) == typeof(int))
            return MatrixHelper.Multiply((int)(object)c, m as Matrix<int>) as Matrix<T>;
        if (typeof(T) == typeof(float))
            return MatrixHelper.Multiply((float)(object)c, m as Matrix<float>) as Matrix<T>;
        if (typeof(T) == typeof(double))
            return MatrixHelper.Multiply((double)(object)c, m as Matrix<double>) as Matrix<T>;
        throw UnsupportedType();
    }

    /// <summary>Scales every element by <paramref name="c"/>.</summary>
    public static Matrix<T> operator *(Matrix<T> m, T c)
    {
        Contract.Requires<ArgumentNullException>(m != null);
        return c * m;
    }

    /// <summary>Element-wise negation.</summary>
    public static Matrix<T> operator -(Matrix<T> m)
    {
        Contract.Requires<ArgumentNullException>(m != null);
        if (typeof(T) == typeof(int))
            return MatrixHelper.Negative(m as Matrix<int>) as Matrix<T>;
        if (typeof(T) == typeof(float))
            return MatrixHelper.Negative(m as Matrix<float>) as Matrix<T>;
        if (typeof(T) == typeof(double))
            return MatrixHelper.Negative(m as Matrix<double>) as Matrix<T>;
        throw UnsupportedType();
    }

    /// <summary>Element-wise sum; both matrices must have the same shape.</summary>
    public static Matrix<T> operator +(Matrix<T> m1, Matrix<T> m2)
    {
        Contract.Requires<ArgumentNullException>(m1 != null && m2 != null);
        Contract.Requires<ArgumentException>(m1.Rows == m2.Rows && m1.Columns == m2.Columns, "Matrix dimensions must agree.");
        if (typeof(T) == typeof(int))
            return MatrixHelper.Add(m1 as Matrix<int>, m2 as Matrix<int>) as Matrix<T>;
        if (typeof(T) == typeof(float))
            return MatrixHelper.Add(m1 as Matrix<float>, m2 as Matrix<float>) as Matrix<T>;
        if (typeof(T) == typeof(double))
            return MatrixHelper.Add(m1 as Matrix<double>, m2 as Matrix<double>) as Matrix<T>;
        throw UnsupportedType();
    }

    /// <summary>Element-wise difference <c>m1 - m2</c>; both matrices must have the same shape.</summary>
    public static Matrix<T> operator -(Matrix<T> m1, Matrix<T> m2)
    {
        Contract.Requires<ArgumentNullException>(m1 != null && m2 != null);
        Contract.Requires<ArgumentException>(m1.Rows == m2.Rows && m1.Columns == m2.Columns, "Matrix dimensions must agree.");
        // MatrixHelper.Substract(n, m) computes m - n, so the operands are passed in reverse order.
        if (typeof(T) == typeof(int))
            return MatrixHelper.Substract(m2 as Matrix<int>, m1 as Matrix<int>) as Matrix<T>;
        if (typeof(T) == typeof(float))
            return MatrixHelper.Substract(m2 as Matrix<float>, m1 as Matrix<float>) as Matrix<T>;
        if (typeof(T) == typeof(double))
            return MatrixHelper.Substract(m2 as Matrix<double>, m1 as Matrix<double>) as Matrix<T>;
        throw UnsupportedType();
    }

    /// <summary>Adds <paramref name="c"/> to every element.</summary>
    public static Matrix<T> operator +(T c, Matrix<T> m)
    {
        Contract.Requires<ArgumentNullException>(m != null);
        if (typeof(T) == typeof(int))
            return MatrixHelper.Add((int)(object)c, m as Matrix<int>) as Matrix<T>;
        if (typeof(T) == typeof(float))
            return MatrixHelper.Add((float)(object)c, m as Matrix<float>) as Matrix<T>;
        if (typeof(T) == typeof(double))
            return MatrixHelper.Add((double)(object)c, m as Matrix<double>) as Matrix<T>;
        throw UnsupportedType();
    }

    /// <summary>Adds <paramref name="c"/> to every element.</summary>
    public static Matrix<T> operator +(Matrix<T> m, T c)
    {
        Contract.Requires<ArgumentNullException>(m != null);
        return c + m;
    }

    /// <summary>Element-wise <c>c - m[i, j]</c>.</summary>
    public static Matrix<T> operator -(T c, Matrix<T> m)
    {
        Contract.Requires<ArgumentNullException>(m != null);
        // MatrixHelper.Substract(c, m) computes m - c; negating it gives c - m.
        if (typeof(T) == typeof(int))
            return MatrixHelper.Negative(MatrixHelper.Substract((int)(object)c, m as Matrix<int>)) as Matrix<T>;
        if (typeof(T) == typeof(float))
            return MatrixHelper.Negative(MatrixHelper.Substract((float)(object)c, m as Matrix<float>)) as Matrix<T>;
        if (typeof(T) == typeof(double))
            return MatrixHelper.Negative(MatrixHelper.Substract((double)(object)c, m as Matrix<double>)) as Matrix<T>;
        throw UnsupportedType();
    }

    /// <summary>Element-wise <c>m[i, j] - c</c>.</summary>
    public static Matrix<T> operator -(Matrix<T> m, T c)
    {
        Contract.Requires<ArgumentNullException>(m != null);
        // MatrixHelper.Substract(c, m) computes m - c.
        if (typeof(T) == typeof(int))
            return MatrixHelper.Substract((int)(object)c, m as Matrix<int>) as Matrix<T>;
        if (typeof(T) == typeof(float))
            return MatrixHelper.Substract((float)(object)c, m as Matrix<float>) as Matrix<T>;
        if (typeof(T) == typeof(double))
            return MatrixHelper.Substract((double)(object)c, m as Matrix<double>) as Matrix<T>;
        throw UnsupportedType();
    }
}
/// <summary>
/// Result of an LU decomposition with row pivoting: lower factor <see cref="L"/>, upper factor
/// <see cref="U"/>, the row permutation that was applied, and that permutation's determinant
/// (the LU routine in MatrixHelper stores +1 or -1 here).
/// </summary>
public class Decomposition<T> where T : struct
{
    /// <summary>Lower-triangular factor.</summary>
    public Matrix<T> L;

    /// <summary>Upper-triangular factor.</summary>
    public Matrix<T> U;

    /// <summary>Row permutation; entry i is the original row index placed at row i.</summary>
    public int[] Permutation;

    /// <summary>Determinant of the permutation matrix.</summary>
    public T DeterminantOfP;

    /// <summary>Captures the factors; L and U must share a shape whose row count matches the permutation.</summary>
    public Decomposition(Matrix<T> l, Matrix<T> u, int[] permutation, T detOfP)
    {
        Contract.Requires(l != null);
        Contract.Requires(u != null);
        Contract.Requires(permutation != null);
        Contract.Requires(l.Rows == permutation.Length);
        Contract.Requires(l.Rows == u.Rows && l.Columns == u.Columns);

        L = l;
        U = u;
        Permutation = permutation;
        DeterminantOfP = detOfP;
    }
}
<#@ template debug="false" hostspecific="false" language="C#" #>
<#@ assembly name="System.Core" #>
<#@ import namespace="System.Linq" #>
<#@ import namespace="System.Text" #>
<#@ import namespace="System.Collections.Generic" #>
<#@ output extension=".cs" #>
<#
// Element types for which MatrixHelper overloads are generated below. The LU-based routines
// (SolveWith, SubsForth, SubsBack, Invert, Determinant, LUDecomposition) are emitted for every
// entry except "int"; the remaining routines are emitted for all of them.
var types = new string[] { "int", "float", "double" };
#>
// Based on code from Ivan Kuckir.
// Original license notice:
/*
Matrix class in C#
Written by Ivan Kuckir (ivan.kuckir@gmail.com, http://blog.ivank.net)
Faculty of Mathematics and Physics
Charles University in Prague
(C) 2010
- updated on 1. 6.2014 - Trimming the string before parsing
- updated on 14.6.2012 - parsing improved. Thanks to Andy!
- updated on 3.10.2012 - there was a terrible bug in LU, SoLE and Inversion. Thanks to Danilo Neves Cruz for reporting that!
This code is distributed under MIT licence.
Permission is hereby granted, free of charge, to any person
obtaining a copy of this software and associated documentation
files (the "Software"), to deal in the Software without
restriction, including without limitation the rights to use,
copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following
conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
OTHER DEALINGS IN THE SOFTWARE.
*/
using System;
using System.Diagnostics.Contracts;
namespace CudaLearn
{
    /// <summary>
    /// Element-type-specific matrix routines used by Matrix&lt;T&gt;. One overload set is generated for
    /// every entry of the template's types list; the LU-based routines are generated for every type
    /// except int.
    /// </summary>
    internal static class MatrixHelper
    {
<# foreach (var type in types)
   { #>
<#     if (type != "int")
       { #>
        /// <summary>
        /// Solves A x = v for x using the cached LU decomposition of A.
        /// </summary>
        /// <exception cref="InvalidOperationException">
        /// A is not square, v has the wrong number of rows, or A is singular.
        /// </exception>
        internal static Matrix<<#=type#>> SolveWith(Matrix<<#=type#>> A, Matrix<<#=type#>> v)
        {
            if (A.Rows != A.Columns)
                throw new InvalidOperationException("The matrix is not square!");
            if (A.Rows != v.Rows)
                throw new InvalidOperationException("Wrong number of results in solution vector!");

            var decomposition = A.Decomposition;
            var U = decomposition.U;
            // LUDecomposition only rejects an all-zero pivot column for the first n - 1 columns, so a
            // zero in the last diagonal entry of U (including any 1x1 zero matrix) would otherwise be
            // divided by in SubsBack and produce infinities/NaN instead of an error.
            for (int i = 0; i < A.Rows; i++)
            {
                if (U[i, i] == 0)
                    throw new InvalidOperationException("The matrix is singular!");
            }

            var pi = decomposition.Permutation;
            Matrix<<#=type#>> b = new Matrix<<#=type#>>(A.Rows, 1);
            for (int i = 0; i < A.Rows; i++)
                b[i, 0] = v[pi[i], 0]; // apply the row permutation to v
            Matrix<<#=type#>> z = SubsForth(decomposition.L, b);
            return SubsBack(U, z);
        }

        /// <summary>
        /// Function solves Ax = b for A as a lower triangular matrix (forward substitution).
        /// </summary>
        private static Matrix<<#=type#>> SubsForth(Matrix<<#=type#>> A, Matrix<<#=type#>> b)
        {
            int n = A.Rows;
            Matrix<<#=type#>> x = new Matrix<<#=type#>>(n, 1);
            for (int i = 0; i < n; i++)
            {
                x[i, 0] = b[i, 0];
                for (int j = 0; j < i; j++)
                    x[i, 0] -= A[i, j] * x[j, 0];
                x[i, 0] = x[i, 0] / A[i, i];
            }
            return x;
        }

        /// <summary>
        /// Function solves Ax = b for A as an upper triangular matrix (back substitution).
        /// </summary>
        private static Matrix<<#=type#>> SubsBack(Matrix<<#=type#>> A, Matrix<<#=type#>> b)
        {
            int n = A.Rows;
            Matrix<<#=type#>> x = new Matrix<<#=type#>>(n, 1);
            for (int i = n - 1; i > -1; i--)
            {
                x[i, 0] = b[i, 0];
                for (int j = n - 1; j > i; j--)
                    x[i, 0] -= A[i, j] * x[j, 0];
                x[i, 0] = x[i, 0] / A[i, i];
            }
            return x;
        }

        /// <summary>
        /// Inverse of m, built by solving m x = e_i for each unit column vector e_i.
        /// </summary>
        internal static Matrix<<#=type#>> Invert(Matrix<<#=type#>> m)
        {
            Matrix<<#=type#>> inv = new Matrix<<#=type#>>(m.Rows, m.Columns);
            for (int i = 0; i < m.Rows; i++)
            {
                Matrix<<#=type#>> Ei = Matrix<<#=type#>>.Zeroes(m.Rows, 1);
                Ei[i, 0] = 1;
                Matrix<<#=type#>> col = SolveWith(m, Ei);
                inv.SetColumn(col, i);
            }
            return inv;
        }

        /// <summary>
        /// Determinant of m: the permutation sign times the product of U's diagonal. Returns 0 for a
        /// singular matrix.
        /// </summary>
        /// <exception cref="InvalidOperationException">m is not square.</exception>
        internal static <#=type#> Determinant(Matrix<<#=type#>> m)
        {
            if (m.Rows != m.Columns)
                throw new InvalidOperationException("The matrix is not square!");

            Decomposition<<#=type#>> decomposition;
            try
            {
                decomposition = m.Decomposition;
            }
            catch (InvalidOperationException)
            {
                // Squareness was checked above, so this is LUDecomposition reporting an all-zero
                // pivot column: the matrix is singular and its determinant is zero.
                return 0;
            }

            var u = decomposition.U;
            var det = decomposition.DeterminantOfP;
            for (int i = 0; i < m.Rows; i++)
                det *= u[i, i];
            return det;
        }

        /// <summary>
        /// LU decomposition with partial (row) pivoting: P m = L U.
        /// </summary>
        /// <exception cref="InvalidOperationException">
        /// A pivot column among the first n - 1 columns is entirely zero (the matrix is singular).
        /// </exception>
        internal static Decomposition<<#=type#>> LUDecomposition(Matrix<<#=type#>> m)
        {
            var L = Matrix<<#=type#>>.Identity(m.Rows, m.Columns);
            var U = m.Clone();
            var pi = new int[m.Rows];
            for (int i = 0; i < m.Rows; i++)
                pi[i] = i;
            <#=type#> detOfP = Matrix<<#=type#>>.One;
            <#=type#> p = 0;
            <#=type#> pom2;
            int k0 = 0;
            int pom1 = 0;
            for (int k = 0; k < m.Columns - 1; k++)
            {
                p = 0;
                for (int i = k; i < m.Rows; i++) // find the row with the biggest pivot
                {
                    if (Math.Abs(U[i, k]) > p)
                    {
                        p = Math.Abs(U[i, k]);
                        k0 = i;
                    }
                }
                if (p == 0) // only zeros in this column
                    throw new InvalidOperationException("The matrix is singular!");
                pom1 = pi[k]; pi[k] = pi[k0]; pi[k0] = pom1; // switch two rows in permutation matrix
                for (int i = 0; i < k; i++)
                {
                    pom2 = L[k, i]; L[k, i] = L[k0, i]; L[k0, i] = pom2;
                }
                if (k != k0)
                    detOfP *= -1; // every row swap flips the sign of det(P)
                for (int i = 0; i < m.Columns; i++) // Switch rows in U
                {
                    pom2 = U[k, i]; U[k, i] = U[k0, i]; U[k0, i] = pom2;
                }
                for (int i = k + 1; i < m.Rows; i++)
                {
                    L[i, k] = U[i, k] / U[k, k];
                    for (int j = k; j < m.Columns; j++)
                        U[i, j] = U[i, j] - L[i, k] * U[k, j];
                }
            }
            return new Decomposition<<#=type#>>(L, U, pi, detOfP);
        }
<#     } #>
<#     var wideType = type == "int" ? "long" : type; #>
        /// <summary>
        /// True when m1 and m2 have the same shape and every pair of elements is identical or differs
        /// by at most epsilon.
        /// </summary>
        internal static bool Equals(Matrix<<#=type#>> m1, Matrix<<#=type#>> m2, <#=type#> epsilon)
        {
            if (m1.Rows != m2.Rows || m1.Columns != m2.Columns)
                return false;
            for (int i = 0; i < m1.Rows; i++)
            {
                for (int j = 0; j < m1.Columns; j++)
                {
                    <#=type#> a = m1[i, j];
                    <#=type#> b = m2[i, j];
                    // Identical values are equal; for floating types this includes NaN vs NaN and
                    // matching infinities, whose difference would be NaN.
                    if (a.Equals(b))
                        continue;
                    // For int the difference is taken in long so it cannot wrap around
                    // (e.g. int.MaxValue - (-1)).
                    <#=wideType#> difference = (<#=wideType#>)a - b;
                    if (difference < 0)
                        difference = -difference;
                    // Written as !(<=) so that a NaN difference (NaN against a number) is unequal.
                    if (!(difference <= epsilon))
                        return false;
                }
            }
            return true;
        }

        /// <summary>Element-wise negation of m.</summary>
        internal static Matrix<<#=type#>> Negative(Matrix<<#=type#>> m)
        {
            Matrix<<#=type#>> r = new Matrix<<#=type#>>(m.Rows, m.Columns);
            for (int i = 0; i < m.Rows; i++)
                for (int j = 0; j < m.Columns; j++)
                    r[i, j] = -m[i, j];
            return r;
        }

        /// <summary>Element-wise sum n + m; the result takes m's shape.</summary>
        internal static Matrix<<#=type#>> Add(Matrix<<#=type#>> n, Matrix<<#=type#>> m)
        {
            Matrix<<#=type#>> r = new Matrix<<#=type#>>(m.Rows, m.Columns);
            for (int i = 0; i < m.Rows; i++)
                for (int j = 0; j < m.Columns; j++)
                    r[i, j] = m[i, j] + n[i, j];
            return r;
        }

        /// <summary>
        /// Element-wise difference m - n. Note the order: the SECOND argument minus the FIRST.
        /// The result takes m's shape.
        /// </summary>
        internal static Matrix<<#=type#>> Substract(Matrix<<#=type#>> n, Matrix<<#=type#>> m)
        {
            Matrix<<#=type#>> r = new Matrix<<#=type#>>(m.Rows, m.Columns);
            for (int i = 0; i < m.Rows; i++)
                for (int j = 0; j < m.Columns; j++)
                    r[i, j] = m[i, j] - n[i, j];
            return r;
        }

        /// <summary>Adds the scalar n to every element of m.</summary>
        internal static Matrix<<#=type#>> Add(<#=type#> n, Matrix<<#=type#>> m)
        {
            Matrix<<#=type#>> r = new Matrix<<#=type#>>(m.Rows, m.Columns);
            for (int i = 0; i < m.Rows; i++)
                for (int j = 0; j < m.Columns; j++)
                    r[i, j] = m[i, j] + n;
            return r;
        }

        /// <summary>
        /// Element-wise m[i, j] - n. Note the order: the matrix minus the scalar.
        /// </summary>
        internal static Matrix<<#=type#>> Substract(<#=type#> n, Matrix<<#=type#>> m)
        {
            Matrix<<#=type#>> r = new Matrix<<#=type#>>(m.Rows, m.Columns);
            for (int i = 0; i < m.Rows; i++)
                for (int j = 0; j < m.Columns; j++)
                    r[i, j] = m[i, j] - n;
            return r;
        }

        /// <summary>Multiplies every element of m by the constant n.</summary>
        internal static Matrix<<#=type#>> Multiply(<#=type#> n, Matrix<<#=type#>> m)
        {
            Matrix<<#=type#>> r = new Matrix<<#=type#>>(m.Rows, m.Columns);
            for (int i = 0; i < m.Rows; i++)
                for (int j = 0; j < m.Columns; j++)
                    r[i, j] = m[i, j] * n;
            return r;
        }

        /// <summary>
        /// Matrix product A * B. Below 32 on every side it uses the naive triple loop; otherwise it pads
        /// to the next power of two and applies Strassen's seven-product scheme recursively.
        /// </summary>
        internal static Matrix<<#=type#>> StrassenMultiply(Matrix<<#=type#>> A, Matrix<<#=type#>> B)
        {
            Contract.Requires<ArgumentException>(A.Columns == B.Rows, "Wrong dimension of matrix!");
            Matrix<<#=type#>> R;
            int msize = Math.Max(Math.Max(A.Rows, A.Columns), Math.Max(B.Rows, B.Columns));
            if (msize < 32)
            {
                R = Matrix<<#=type#>>.Zeroes(A.Rows, B.Columns);
                for (int i = 0; i < R.Rows; i++)
                    for (int j = 0; j < R.Columns; j++)
                        for (int k = 0; k < A.Columns; k++)
                            R[i, j] += A[i, k] * B[k, j];
                return R;
            }
            int size = 1; int n = 0;
            while (msize > size) { size *= 2; n++; }
            int h = size / 2;
            Matrix<<#=type#>>[,] mField = new Matrix<<#=type#>>[n, 9];
            /*
             * 8x8, 8x8, 8x8, ...
             * 4x4, 4x4, 4x4, ...
             * 2x2, 2x2, 2x2, ...
             * . . .
             */
            int z;
            // Scratch matrices are only needed for levels whose operand size is still >= 32.
            for (int i = 0; i < n - 4; i++) // rows
            {
                z = (int)Math.Pow(2, n - i - 1);
                for (int j = 0; j < 9; j++)
                    mField[i, j] = new Matrix<<#=type#>>(z, z);
            }
            SafeAplusBintoC(A, 0, 0, A, h, h, mField[0, 0], h);
            SafeAplusBintoC(B, 0, 0, B, h, h, mField[0, 1], h);
            StrassenMultiplyRun(mField[0, 0], mField[0, 1], mField[0, 1 + 1], 1, mField); // (A11 + A22) * (B11 + B22);
            SafeAplusBintoC(A, 0, h, A, h, h, mField[0, 0], h);
            SafeACopytoC(B, 0, 0, mField[0, 1], h);
            StrassenMultiplyRun(mField[0, 0], mField[0, 1], mField[0, 1 + 2], 1, mField); // (A21 + A22) * B11;
            SafeACopytoC(A, 0, 0, mField[0, 0], h);
            SafeAminusBintoC(B, h, 0, B, h, h, mField[0, 1], h);
            StrassenMultiplyRun(mField[0, 0], mField[0, 1], mField[0, 1 + 3], 1, mField); //A11 * (B12 - B22);
            SafeACopytoC(A, h, h, mField[0, 0], h);
            SafeAminusBintoC(B, 0, h, B, 0, 0, mField[0, 1], h);
            StrassenMultiplyRun(mField[0, 0], mField[0, 1], mField[0, 1 + 4], 1, mField); //A22 * (B21 - B11);
            SafeAplusBintoC(A, 0, 0, A, h, 0, mField[0, 0], h);
            SafeACopytoC(B, h, h, mField[0, 1], h);
            StrassenMultiplyRun(mField[0, 0], mField[0, 1], mField[0, 1 + 5], 1, mField); //(A11 + A12) * B22;
            SafeAminusBintoC(A, 0, h, A, 0, 0, mField[0, 0], h);
            SafeAplusBintoC(B, 0, 0, B, h, 0, mField[0, 1], h);
            StrassenMultiplyRun(mField[0, 0], mField[0, 1], mField[0, 1 + 6], 1, mField); //(A21 - A11) * (B11 + B12);
            SafeAminusBintoC(A, h, 0, A, h, h, mField[0, 0], h);
            SafeAplusBintoC(B, 0, h, B, h, h, mField[0, 1], h);
            StrassenMultiplyRun(mField[0, 0], mField[0, 1], mField[0, 1 + 7], 1, mField); // (A12 - A22) * (B21 + B22);
            R = new Matrix<<#=type#>>(A.Rows, B.Columns); // result
            // C11
            for (int i = 0; i < Math.Min(h, R.Rows); i++) // rows
                for (int j = 0; j < Math.Min(h, R.Columns); j++) // cols
                    R[i, j] = mField[0, 1 + 1][i, j] + mField[0, 1 + 4][i, j] - mField[0, 1 + 5][i, j] + mField[0, 1 + 7][i, j];
            // C12
            for (int i = 0; i < Math.Min(h, R.Rows); i++) // rows
                for (int j = h; j < Math.Min(2 * h, R.Columns); j++) // cols
                    R[i, j] = mField[0, 1 + 3][i, j - h] + mField[0, 1 + 5][i, j - h];
            // C21
            for (int i = h; i < Math.Min(2 * h, R.Rows); i++) // rows
                for (int j = 0; j < Math.Min(h, R.Columns); j++) // cols
                    R[i, j] = mField[0, 1 + 2][i - h, j] + mField[0, 1 + 4][i - h, j];
            // C22
            for (int i = h; i < Math.Min(2 * h, R.Rows); i++) // rows
                for (int j = h; j < Math.Min(2 * h, R.Columns); j++) // cols
                    R[i, j] = mField[0, 1 + 1][i - h, j - h] - mField[0, 1 + 2][i - h, j - h] + mField[0, 1 + 3][i - h, j - h] + mField[0, 1 + 6][i - h, j - h];
            return R;
        }

        /// <summary>
        /// Strassen step for square 2^N x 2^N operands: A * B into C, using scratch row l of f
        /// (level of recursion, matrix field).
        /// </summary>
        private static void StrassenMultiplyRun(Matrix<<#=type#>> A, Matrix<<#=type#>> B, Matrix<<#=type#>> C, int l, Matrix<<#=type#>>[,] f)
        {
            int size = A.Rows;
            int h = size / 2;
            if (size < 32)
            {
                for (int i = 0; i < C.Rows; i++)
                    for (int j = 0; j < C.Columns; j++)
                    {
                        C[i, j] = 0;
                        for (int k = 0; k < A.Columns; k++) C[i, j] += A[i, k] * B[k, j];
                    }
                return;
            }
            AplusBintoC(A, 0, 0, A, h, h, f[l, 0], h);
            AplusBintoC(B, 0, 0, B, h, h, f[l, 1], h);
            StrassenMultiplyRun(f[l, 0], f[l, 1], f[l, 1 + 1], l + 1, f); // (A11 + A22) * (B11 + B22);
            AplusBintoC(A, 0, h, A, h, h, f[l, 0], h);
            ACopytoC(B, 0, 0, f[l, 1], h);
            StrassenMultiplyRun(f[l, 0], f[l, 1], f[l, 1 + 2], l + 1, f); // (A21 + A22) * B11;
            ACopytoC(A, 0, 0, f[l, 0], h);
            AminusBintoC(B, h, 0, B, h, h, f[l, 1], h);
            StrassenMultiplyRun(f[l, 0], f[l, 1], f[l, 1 + 3], l + 1, f); //A11 * (B12 - B22);
            ACopytoC(A, h, h, f[l, 0], h);
            AminusBintoC(B, 0, h, B, 0, 0, f[l, 1], h);
            StrassenMultiplyRun(f[l, 0], f[l, 1], f[l, 1 + 4], l + 1, f); //A22 * (B21 - B11);
            AplusBintoC(A, 0, 0, A, h, 0, f[l, 0], h);
            ACopytoC(B, h, h, f[l, 1], h);
            StrassenMultiplyRun(f[l, 0], f[l, 1], f[l, 1 + 5], l + 1, f); //(A11 + A12) * B22;
            AminusBintoC(A, 0, h, A, 0, 0, f[l, 0], h);
            AplusBintoC(B, 0, 0, B, h, 0, f[l, 1], h);
            StrassenMultiplyRun(f[l, 0], f[l, 1], f[l, 1 + 6], l + 1, f); //(A21 - A11) * (B11 + B12);
            AminusBintoC(A, h, 0, A, h, h, f[l, 0], h);
            AplusBintoC(B, 0, h, B, h, h, f[l, 1], h);
            StrassenMultiplyRun(f[l, 0], f[l, 1], f[l, 1 + 7], l + 1, f); // (A12 - A22) * (B21 + B22);
            // C11
            for (int i = 0; i < h; i++) // rows
                for (int j = 0; j < h; j++) // cols
                    C[i, j] = f[l, 1 + 1][i, j] + f[l, 1 + 4][i, j] - f[l, 1 + 5][i, j] + f[l, 1 + 7][i, j];
            // C12
            for (int i = 0; i < h; i++) // rows
                for (int j = h; j < size; j++) // cols
                    C[i, j] = f[l, 1 + 3][i, j - h] + f[l, 1 + 5][i, j - h];
            // C21
            for (int i = h; i < size; i++) // rows
                for (int j = 0; j < h; j++) // cols
                    C[i, j] = f[l, 1 + 2][i - h, j] + f[l, 1 + 4][i - h, j];
            // C22
            for (int i = h; i < size; i++) // rows
                for (int j = h; j < size; j++) // cols
                    C[i, j] = f[l, 1 + 1][i - h, j - h] - f[l, 1 + 2][i - h, j - h] + f[l, 1 + 3][i - h, j - h] + f[l, 1 + 6][i - h, j - h];
        }

        /// <summary>
        /// C = A-block + B-block over a size x size window, where (x, y) are the column and row offsets;
        /// cells outside an operand's bounds count as zero (implicit padding).
        /// </summary>
        private static void SafeAplusBintoC(Matrix<<#=type#>> A, int xa, int ya, Matrix<<#=type#>> B, int xb, int yb, Matrix<<#=type#>> C, int size)
        {
            for (int i = 0; i < size; i++) // rows
                for (int j = 0; j < size; j++) // cols
                {
                    C[i, j] = 0;
                    if (xa + j < A.Columns && ya + i < A.Rows)
                        C[i, j] += A[ya + i, xa + j];
                    if (xb + j < B.Columns && yb + i < B.Rows)
                        C[i, j] += B[yb + i, xb + j];
                }
        }

        /// <summary>
        /// C = A-block - B-block over a size x size window; out-of-bounds cells count as zero.
        /// </summary>
        private static void SafeAminusBintoC(Matrix<<#=type#>> A, int xa, int ya, Matrix<<#=type#>> B, int xb, int yb, Matrix<<#=type#>> C, int size)
        {
            for (int i = 0; i < size; i++) // rows
                for (int j = 0; j < size; j++) // cols
                {
                    C[i, j] = 0;
                    if (xa + j < A.Columns && ya + i < A.Rows)
                        C[i, j] += A[ya + i, xa + j];
                    if (xb + j < B.Columns && yb + i < B.Rows)
                        C[i, j] -= B[yb + i, xb + j];
                }
        }

        /// <summary>
        /// C = A-block over a size x size window; out-of-bounds cells count as zero.
        /// </summary>
        private static void SafeACopytoC(Matrix<<#=type#>> A, int xa, int ya, Matrix<<#=type#>> C, int size)
        {
            for (int i = 0; i < size; i++) // rows
                for (int j = 0; j < size; j++) // cols
                {
                    C[i, j] = 0;
                    if (xa + j < A.Columns && ya + i < A.Rows)
                        C[i, j] += A[ya + i, xa + j];
                }
        }

        /// <summary>C = A-block + B-block; both windows must lie inside their matrices.</summary>
        private static void AplusBintoC(Matrix<<#=type#>> A, int xa, int ya, Matrix<<#=type#>> B, int xb, int yb, Matrix<<#=type#>> C, int size)
        {
            for (int i = 0; i < size; i++) // rows
                for (int j = 0; j < size; j++)
                    C[i, j] = A[ya + i, xa + j] + B[yb + i, xb + j];
        }

        /// <summary>C = A-block - B-block; both windows must lie inside their matrices.</summary>
        private static void AminusBintoC(Matrix<<#=type#>> A, int xa, int ya, Matrix<<#=type#>> B, int xb, int yb, Matrix<<#=type#>> C, int size)
        {
            for (int i = 0; i < size; i++) // rows
                for (int j = 0; j < size; j++)
                    C[i, j] = A[ya + i, xa + j] - B[yb + i, xb + j];
        }

        /// <summary>C = A-block; the window must lie inside A.</summary>
        private static void ACopytoC(Matrix<<#=type#>> A, int xa, int ya, Matrix<<#=type#>> C, int size)
        {
            for (int i = 0; i < size; i++) // rows
                for (int j = 0; j < size; j++)
                    C[i, j] = A[ya + i, xa + j];
        }
<# } #>
    }
}
[Fact]
public void CreateGenericMatrixWithNumericTypes()
{
    // Each supported element type can be constructed; the shape is preserved and elements start at zero.
    var matrixInt = new Matrix<int>(2, 2);
    var matrixDouble = new Matrix<double>(2, 2);
    var matrixFloat = new Matrix<float>(2, 2);

    Assert.Equal(2, matrixInt.Rows);
    Assert.Equal(2, matrixInt.Columns);
    Assert.Equal(0, matrixInt[1, 1]);

    Assert.Equal(2, matrixDouble.Rows);
    Assert.Equal(2, matrixDouble.Columns);
    Assert.Equal(0d, matrixDouble[1, 1]);

    Assert.Equal(2, matrixFloat.Rows);
    Assert.Equal(2, matrixFloat.Columns);
    Assert.Equal(0f, matrixFloat[1, 1]);
}
[Fact]
public void CreateGenericMatrixWithInvalidTypes()
{
    // The constructor accepts only int, float and double; any other struct element type is rejected.
    Assert.Throws<NotSupportedException>(() => { new Matrix<byte>(2, 2); });
    Assert.Throws<NotSupportedException>(() => { new Matrix<SampleStruct>(2, 2); });
    Assert.Throws<NotSupportedException>(() => { new Matrix<sbyte>(2, 2); });
}
[Fact]
public void MultiplyAndAddMatrix()
{
    // m1 is the 2x2 identity; m2 holds 1, 3 in row 0 and 2, 4 in row 1 (indexer is [row, col]).
    var m1 = new Matrix<int>(2, 2);
    m1[0, 0] = 1;
    m1[1, 1] = 1;

    var m2 = new Matrix<int>(2, 2);
    m2[0, 0] = 1;
    m2[1, 0] = 2;
    m2[0, 1] = 3;
    m2[1, 1] = 4;

    var result = m1 * m2 + 2 * m1;

    // I * m2 + 2 * I == m2 + 2 * I, spelled out literally so that a defect in '+' or scalar '*'
    // cannot cancel out between the expected and the actual side.
    var expected = new Matrix<int>(2, 2);
    expected[0, 0] = 3;
    expected[0, 1] = 3;
    expected[1, 0] = 2;
    expected[1, 1] = 6;
    Assert.Equal(expected, result);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment