Skip to content

Instantly share code, notes, and snippets.

@aldente39
Last active December 14, 2015 11:58
Show Gist options
  • Save aldente39/5082972 to your computer and use it in GitHub Desktop.
Save aldente39/5082972 to your computer and use it in GitHub Desktop.
/**
 * Dense matrix multiplication on `Array[Array[Double]]` using Strassen's
 * algorithm, plus two naive triple-loop multipliers used as the recursion
 * base case and as a timing baseline.
 *
 * Matrices are row-major and assumed rectangular (every row has the same
 * length as row 0); jagged input is not validated.
 */
object Strassen {

  /** Smallest power of two `>= n` (0 for non-positive `n`), computed with integer arithmetic. */
  private def nextPowerOfTwo(n: Int): Int =
    if (n <= 1) math.max(n, 0) else Integer.highestOneBit(n - 1) << 1

  /**
   * Copies `m` into the top-left corner of a zero-filled square matrix whose
   * side is the smallest power of two `>= size`.
   *
   * @param m    matrix to pad; must fit inside the padded square
   * @param size minimum side length of the result
   * @return a fresh square matrix; `m` itself is not modified
   */
  def add0(m: Array[Array[Double]], size: Int): Array[Array[Double]] = {
    // Integer arithmetic instead of log(size)/log(2): the floating-point
    // quotient can land just above an exact integer and double the padding.
    val next = nextPowerOfTwo(size)
    val res = Array.ofDim[Double](next, next)
    for (i <- m.indices) System.arraycopy(m(i), 0, res(i), 0, m(i).length)
    res
  }

  /** Applies `op` element by element; the result has the shape of `a1`. */
  private def elementwise(a1: Array[Array[Double]], a2: Array[Array[Double]])
                         (op: (Double, Double) => Double): Array[Array[Double]] = {
    val row = a1.length
    val col = if (row == 0) 0 else a1(0).length
    val res = Array.ofDim[Double](row, col)
    var i = 0
    while (i < row) {
      val x = a1(i)
      val y = a2(i)
      val out = res(i)
      var j = 0
      while (j < col) {
        out(j) = op(x(j), y(j))
        j += 1
      }
      i += 1
    }
    res
  }

  /** Element-wise sum `a1 + a2`; `a2` must be at least as large as `a1`. */
  def adda(a1: Array[Array[Double]], a2: Array[Array[Double]]): Array[Array[Double]] =
    elementwise(a1, a2)(_ + _)

  /** Element-wise difference `a1 - a2`; `a2` must be at least as large as `a1`. */
  def suba(a1: Array[Array[Double]], a2: Array[Array[Double]]): Array[Array[Double]] =
    elementwise(a1, a2)(_ - _)

  /**
   * Naive matrix product with loops ordered i-k-j, so the innermost loop walks
   * one row of `a2` and one row of the result sequentially.
   *
   * @return a matrix with `a1.length` rows and `a2(0).length` columns
   *         (0 columns when `a2` is empty)
   */
  def ikj(a1: Array[Array[Double]], a2: Array[Array[Double]]): Array[Array[Double]] = {
    val row = a1.length
    val inner = a2.length
    val col = if (inner == 0) 0 else a2(0).length
    val res = Array.ofDim[Double](row, col)
    var i = 0
    while (i < row) {
      val aRow = a1(i)
      val out = res(i)
      var k = 0
      while (k < inner) {
        val aik = aRow(k)
        val bRow = a2(k)
        var j = 0
        while (j < col) {
          out(j) += aik * bRow(j)
          j += 1
        }
        k += 1
      }
      i += 1
    }
    res
  }

  /**
   * Naive matrix product with the textbook i-j-k loop order.
   *
   * @return a matrix with `a1.length` rows and `a2(0).length` columns
   *         (0 columns when `a2` is empty)
   */
  def ijk(a1: Array[Array[Double]], a2: Array[Array[Double]]): Array[Array[Double]] = {
    val row = a1.length
    val inner = a2.length
    val col = if (inner == 0) 0 else a2(0).length
    val res = Array.ofDim[Double](row, col)
    var i = 0
    while (i < row) {
      val aRow = a1(i)
      val out = res(i)
      var j = 0
      while (j < col) {
        var sum = out(j)
        var k = 0
        while (k < inner) {
          sum += aRow(k) * a2(k)(j)
          k += 1
        }
        out(j) = sum
        j += 1
      }
      i += 1
    }
    res
  }

  /** Copies the `ns x ns` sub-block of `m` whose top-left corner is `(rowOff, colOff)`. */
  private def quadrant(m: Array[Array[Double]], rowOff: Int, colOff: Int, ns: Int): Array[Array[Double]] = {
    val q = Array.ofDim[Double](ns, ns)
    var i = 0
    while (i < ns) {
      System.arraycopy(m(i + rowOff), colOff, q(i), 0, ns)
      i += 1
    }
    q
  }

  /** Writes `block` into `target` with its top-left corner at `(rowOff, colOff)`. */
  private def place(block: Array[Array[Double]], target: Array[Array[Double]], rowOff: Int, colOff: Int): Unit = {
    var i = 0
    while (i < block.length) {
      System.arraycopy(block(i), 0, target(i + rowOff), colOff, block(i).length)
      i += 1
    }
  }

  /**
   * Recursive Strassen product of two square matrices of equal size.
   *
   * @param a    left operand (n x n)
   * @param b    right operand (n x n)
   * @param leaf side length at or below which the naive `ikj` product is used
   * @return the n x n product `a * b`
   */
  def strassen_r(a: Array[Array[Double]], b: Array[Array[Double]], leaf: Int): Array[Array[Double]] = {
    val n = a.length
    // Fall back to the naive product when the block is small enough, when it
    // cannot be halved evenly (odd n would drop the last row/column), or when
    // n < 2 (guards against non-positive `leaf` recursing forever).
    if (n <= leaf || n < 2 || n % 2 != 0) {
      ikj(a, b)
    } else {
      val ns = n / 2
      val a11 = quadrant(a, 0, 0, ns)
      val a12 = quadrant(a, 0, ns, ns)
      val a21 = quadrant(a, ns, 0, ns)
      val a22 = quadrant(a, ns, ns, ns)
      val b11 = quadrant(b, 0, 0, ns)
      val b12 = quadrant(b, 0, ns, ns)
      val b21 = quadrant(b, ns, 0, ns)
      val b22 = quadrant(b, ns, ns, ns)

      val p1 = strassen_r(adda(a11, a22), adda(b11, b22), leaf)
      val p2 = strassen_r(adda(a21, a22), b11, leaf)
      val p3 = strassen_r(a11, suba(b12, b22), leaf)
      val p4 = strassen_r(a22, suba(b21, b11), leaf)
      val p5 = strassen_r(adda(a11, a12), b22, leaf)
      val p6 = strassen_r(suba(a21, a11), adda(b11, b12), leaf)
      val p7 = strassen_r(suba(a12, a22), adda(b21, b22), leaf)

      // c11 = p1 + p4 - p5 + p7 and c22 = p1 - p2 + p3 + p6.
      // p7 and p6 are ADDED: subtracting them yields a wrong product.
      val c11 = adda(suba(adda(p1, p4), p5), p7)
      val c12 = adda(p3, p5)
      val c21 = adda(p2, p4)
      val c22 = adda(adda(suba(p1, p2), p3), p6)

      val res = Array.ofDim[Double](n, b(0).length)
      place(c11, res, 0, 0)
      place(c12, res, 0, ns)
      place(c21, res, ns, 0)
      place(c22, res, ns, ns)
      res
    }
  }

  /**
   * Multiplies `m1` (r x k) by `m2` (k x c) with Strassen's algorithm.
   *
   * Both operands are zero-padded to a common power-of-two square, multiplied
   * recursively, and the padding is cropped from the result.
   *
   * @param m1   left operand
   * @param m2   right operand; its row count must equal the column count of `m1`
   * @param leaf block size at or below which the naive product is used
   * @return the r x c product
   * @throws IllegalArgumentException if the inner dimensions do not match
   */
  def strassen(m1: Array[Array[Double]], m2: Array[Array[Double]], leaf: Int = 128): Array[Array[Double]] = {
    val rows = m1.length
    val inner = m2.length
    val cols = if (inner == 0) 0 else m2(0).length
    val m1Cols = if (rows == 0) 0 else m1(0).length
    require(rows == 0 || m1Cols == inner,
      s"incompatible dimensions: ${rows}x${m1Cols} * ${inner}x${cols}")

    if (rows == 0 || inner == 0 || cols == 0) {
      Array.ofDim[Double](rows, cols)
    } else {
      val maxSize = List(rows, m1Cols, inner, cols).max
      val padded = strassen_r(add0(m1, maxSize), add0(m2, maxSize), leaf)
      // Drop the zero padding so the caller gets the true r x c product.
      padded.take(rows).map(_.take(cols))
    }
  }

  /** Runs `proc` once and prints the elapsed time in milliseconds. */
  def time(proc: => Unit): Unit = {
    // nanoTime is monotonic; currentTimeMillis can jump with clock adjustments.
    val start = System.nanoTime
    proc
    val elapsedMs = (System.nanoTime - start) / 1000000L
    println(s"${elapsedMs}msec.")
  }

  /** Times Strassen against the naive i-k-j product on a 1024 x 1024 matrix. */
  def main(args: Array[String]): Unit = {
    val a = Array.ofDim[Double](1024, 1024)
    time(strassen(a, a, 256))
    time(ikj(a, a))
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment