Skip to content

Instantly share code, notes, and snippets.

@mehalter
Last active April 5, 2016 14:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mehalter/2b102e4e0adb628e74e385abed0814a4 to your computer and use it in GitHub Desktop.
Save mehalter/2b102e4e0adb628e74e385abed0814a4 to your computer and use it in GitHub Desktop.
Matrix parallel processing
package com.mehalter
import java.util.concurrent.Executors
import scala.annotation.tailrec
import scalaz.concurrent.{Strategy, Task}
import scalaz.stream._
object matrix {
/**
 * Demo entry point. Prints, in order:
 *  1. the product of a chain of seven 2x2 matrices computed by `multiplyMatricesPara`,
 *  2. the product of the first two matrices computed by `multiplyMatrix`,
 *  3. the dot product of two 3-element vectors computed by `dotMultiply`.
 *
 * @param args command-line arguments (unused)
 */
def main(args: Array[String]): Unit = {
  // Four daemon worker threads, so the pool does not keep the JVM alive after main returns.
  val S = Strategy.Executor(Executors.newFixedThreadPool(4, Strategy.DefaultDaemonThreadFactory))

  val a: Array[Array[Int]] = Array(Array(5, 6), Array(7, 8))
  val b: Array[Array[Int]] = Array(Array(1, 2), Array(3, 4))
  val c: Array[Array[Int]] = Array(Array(9, 2), Array(7, 1))
  val d: Array[Array[Int]] = Array(Array(9, 2), Array(7, 1))
  val e: Array[Array[Int]] = Array(Array(9, 2), Array(7, 1))
  val f: Array[Array[Int]] = Array(Array(9, 2), Array(7, 1))
  val g: Array[Array[Int]] = Array(Array(9, 2), Array(7, 1))

  // S is not declared implicit, so it has to be passed explicitly; otherwise the implicit
  // parameter is resolved from implicit scope and the pool created above is not used here.
  val out = multiplyMatricesPara(a, b, c, d, e, f, g)(S).runLast.run
  // runLast yields None when the process emitted nothing; in that case nothing is printed.
  out.foreach(product => println(product.toList.map(_.toList)))

  val out2 = multiplyMatrix(a, b)(S)
  println(out2.toList.map(_.toList))

  val x: Array[Int] = Array(1, 2, 3)
  val y: Array[Int] = Array(4, 5, 6)
  val out3 = dotMultiply(x, y)(S)
  println(out3)
}
/**
 * Multiplies a sequence of matrices from left to right, one `multiplyMatrix` call at a time.
 *
 * @param x the matrices to multiply, in order
 * @param S strategy forwarded to every `multiplyMatrix` call
 * @return an empty array when no matrix is given, the matrix itself when exactly one is
 *         given, otherwise the product `x(0) * x(1) * ... * x(n - 1)`
 */
def multiplyMatrices(x: Array[Array[Int]]*)(implicit S: Strategy): Array[Array[Int]] =
  // A left reduction folds the running product with the next matrix, in argument order.
  x.reduceLeftOption[Array[Array[Int]]]((product, next) => multiplyMatrix(product, next)(S))
    .getOrElse(Array.empty[Array[Int]])
/**
 * Multiplies a sequence of matrices by recursively splitting it in half and combining
 * the products of the two halves with `wye.yipWith`.
 *
 * The left half always precedes the right half in the final `multiplyMatrix` call, so the
 * overall operand order is the same as in the argument list.
 *
 * @param matrices the matrices to multiply, in order
 * @param S        strategy used for the wye and forwarded to `multiplyMatrix`
 * @return a process emitting the product; it emits the single matrix unchanged when only
 *         one is given, and emits nothing when `matrices` is empty
 */
def multiplyMatricesPara(matrices: Array[Array[Int]]*)(implicit S: Strategy): Process[Task, Array[Array[Int]]] = {
  // No operands means no product: halt instead of failing on `matrices.head`.
  if (matrices.isEmpty) Process.halt
  else if (matrices.length == 1) Process.emit(matrices.head)
  else {
    // With at least two matrices both halves are non-empty, so the recursion never
    // reaches the empty case above.
    val (left, right) = matrices.splitAt(matrices.length / 2)

    // Builds the process that computes the product of one half.
    def splitting(moreMatrices: Seq[Array[Array[Int]]]): Process[Task, Array[Array[Int]]] =
      Process.emit(moreMatrices).flatMap(half => multiplyMatricesPara(half: _*)(S))

    splitting(left).wye(splitting(right))(wye.yipWith(multiplyMatrix(_, _)(S)))(S)
  }
}
/**
 * Multiplies matrix `a` by matrix `b`, computing every output cell in its own `Task`.
 *
 * One task per (row, column) pair is created through `calculate`; the resulting processes
 * are merged with `merge.mergeN` under `S`, and the collected cell values are then written
 * into the result.
 *
 * @param a left operand as an array of rows; every row must have `b.length` entries
 * @param b right operand as an array of rows; all rows must have the same length
 * @param S strategy passed to `merge.mergeN`
 * @return a matrix with `a.length` rows and as many columns as the rows of `b` have
 *         (zero columns when `b` is empty)
 * @throws IllegalArgumentException if `b` is ragged or the row length of `a` does not
 *                                  match the number of rows of `b`
 */
def multiplyMatrix(a: Array[Array[Int]], b: Array[Array[Int]])(implicit S: Strategy): Array[Array[Int]] = {
  val rows = a.length
  // headOption instead of b(0): an empty right operand simply has zero columns.
  val cols = b.headOption.fold(0)(_.length)
  require(b.forall(_.length == cols), "all rows of b must have the same length")
  require(
    a.forall(_.length == b.length),
    s"every row of a must have ${b.length} entries to match the number of rows of b"
  )

  val out = Array.ofDim[Int](rows, cols)
  // Nothing to compute for a result without cells.
  if (rows > 0 && cols > 0) {
    val cellTasks = Process.emitAll(getPairs(a, b)).map(pair => Process.eval(Task(calculate(pair, a, b))))
    // Collect all cells first and fill the array afterwards, so the output is only
    // written by the calling thread once the merged stream has finished.
    merge.mergeN(cellTasks)(S).runLog.run.foreach { case (value, (r, c)) => out(r)(c) = value }
  }
  out
}
/**
 * Lists the (row, column) coordinates of every cell of the product `a * b`.
 *
 * @param a left operand; contributes the row indices
 * @param b right operand; the length of its first row gives the column count
 * @return all coordinates in row-major order; empty when `a` or `b` is empty
 */
def getPairs(a: Array[Array[Int]], b: Array[Array[Int]]): Seq[(Int, Int)] = {
  // headOption instead of b(0): an empty b has no columns rather than no first row.
  val width = b.headOption.fold(0)(_.length)
  for {
    r <- a.indices
    c <- 0 until width
  } yield (r, c)
}
/**
 * Computes a single cell of the product `a * b`.
 *
 * @param p (row, column) coordinate of the cell
 * @param a left operand
 * @param b right operand; its row count is the number of terms that are summed
 * @return the cell value paired with the coordinate it belongs to
 */
def calculate(p: (Int, Int), a: Array[Array[Int]], b: Array[Array[Int]]): (Int, (Int, Int)) = {
  val (row, col) = p
  //println(f"Hello from ${Thread.currentThread().getId}%d")
  // Row `row` of a dotted with column `col` of b.
  val cell = b.indices.foldLeft(0)((acc, k) => acc + a(row)(k) * b(k)(col))
  (cell, p)
}
/**
 * Computes the dot product of two vectors, wrapping each element-wise multiplication
 * in its own `Task` and merging the results with `merge.mergeN`.
 *
 * @param a first vector
 * @param b second vector; must have the same length as `a`
 * @param S strategy passed to `merge.mergeN`
 * @return the sum of `a(i) * b(i)` over all indices; 0 for empty vectors
 * @throws IllegalArgumentException if the vectors differ in length
 */
def dotMultiply(a: Array[Int], b: Array[Int])(implicit S: Strategy): Int = {
  // Fail fast: a shorter b would otherwise fail inside a task, a longer one would be
  // silently truncated.
  require(a.length == b.length, s"vectors must have equal length, got ${a.length} and ${b.length}")
  val products = Process.emitAll(a.indices).map(i => Process.eval(Task(a(i) * b(i))))
  merge.mergeN(products)(S).runLog.run.sum
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment