Skip to content

Instantly share code, notes, and snippets.

@heyrutvik
Created December 5, 2020 10:56
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save heyrutvik/988544440c0ed26a126e32ccebb7097c to your computer and use it in GitHub Desktop.
Save heyrutvik/988544440c0ed26a126e32ccebb7097c to your computer and use it in GitHub Desktop.
matrix multiplication
package math
import cats.instances.list._
import cats.syntax.all._
import io.circe.generic.auto._
import io.circe.parser._
import scala.io.Source
/**
 * A dense `row` x `column` matrix whose `elements` are stored in row-major order.
 *
 * `Demo` decodes a list of these from the `input.json` resource, for example:
 *
 * {{{
 * [
 *   { "row": 3, "column": 2, "elements": [1,2,3,4,2,1] },
 *   { "row": 2, "column": 3, "elements": [1,2,3,4,2,1] }
 * ]
 * }}}
 *
 * @param row      number of rows (must be non-negative)
 * @param column   number of columns (must be non-negative)
 * @param elements the `row * column` entries, listed row by row
 * @tparam A       numeric element type
 * @throws IllegalArgumentException ("invalid matrix") if a dimension is negative or
 *         `elements.length != row * column`
 */
final case class Matrix[A: Numeric](row: Int, column: Int, elements: List[A]) {
  // Compare as Long so that a huge row * column cannot overflow into a false match.
  require(
    row >= 0 && column >= 0 && elements.length.toLong == row.toLong * column,
    "invalid matrix"
  )

  /**
   * The matrix split into its rows.
   *
   * {{{
   * |1 2|
   * |3 4| =(rowVectors)=> [[1 2] [3 4] [2 1]]
   * |2 1|
   * }}}
   */
  private lazy val rowVectors: List[List[A]] =
    // `grouped` rejects a group size of 0, so an `n x 0` matrix is built as n empty rows.
    if (column == 0) List.fill(row)(List.empty[A]) else elements.grouped(column).toList

  /**
   * The matrix split into its columns (the transpose of `rowVectors`).
   *
   * {{{
   * |1 2 3|
   * |4 2 1| =(columnVectors)=> [[1 4] [2 2] [3 1]]
   * }}}
   */
  private lazy val columnVectors: List[List[A]] =
    // Transposing zero rows yields zero columns, so a `0 x n` matrix needs n empty columns.
    if (row == 0) List.fill(column)(List.empty[A]) else rowVectors.transpose

  /**
   * Dot product of two vectors of equal length.
   *
   * {{{
   * ([1 2], [1 4]) => 1x1 + 2x4 => 1 + 8 => 9
   * }}}
   */
  private def dot(as: List[A], bs: List[A]): A = {
    val num = implicitly[Numeric[A]]
    as.zip(bs).map { case (a, b) => num.times(a, b) }.sum
  }

  /**
   * Matrix product `this * other`, a `row` x `other.column` matrix.
   *
   * @throws IllegalArgumentException if `this.column != other.row`. Without this check,
   *         `zip` in `dot` would silently truncate and yield a wrong result.
   */
  def *(other: Matrix[A]): Matrix[A] = {
    require(
      column == other.row,
      s"cannot multiply a $row x $column matrix by a ${other.row} x ${other.column} matrix"
    )
    Matrix[A](
      row,
      other.column,
      for {
        r <- rowVectors
        c <- other.columnVectors
      } yield dot(r, c)
    )
  }

  /** Renders `n` left-aligned in a field `max` characters wide. */
  private def padding(n: A, max: Int): String = n.toString.padTo(max, ' ')

  /**
   * Renders a `"<row> x <column>"` header line, then one `|...|` line per row.
   * Every entry is padded to the width of the widest entry.
   */
  override def toString: String = {
    // foldLeft rather than `max`, so an empty matrix renders instead of throwing.
    val width = elements.foldLeft(0)((w, n) => math.max(w, n.toString.length))
    val header = s"$row x $column\n"
    val body = rowVectors.map(r => r.map(padding(_, width)).mkString("|", " ", "|")).mkString("\n")
    header + body
  }
}
/**
 * Reads the matrices from the `input.json` resource, multiplies them left to right,
 * and prints the product.
 */
object Demo extends App {
  // Read the resource and close its stream afterwards; `mkString` does not close the Source.
  val inputJson: String = {
    val source = Source.fromResource("input.json")
    try source.mkString
    finally source.close()
  }

  // Rethrow circe's parsing/decoding error if the JSON does not match List[Matrix[Int]].
  val matrices: List[Matrix[Int]] =
    decode[List[Matrix[Int]]](inputJson).fold(error => throw error, identity)

  // An empty list has no product, so report it instead of letting `reduce` throw.
  println(matrices.reduceOption(_ * _).fold("input.json contains no matrices")(_.toString))

  /*
   * Expected output for the example input:
   *
   * 3 x 3
   * |9  6  5 |
   * |19 14 13|
   * |6  6  7 |
   */
}
/*
 * sbt libraryDependencies needed to build this file:
 *
 *   "org.typelevel" %% "cats-core" % "2.1.1",
 *   "io.circe" %% "circe-core" % "0.12.3",
 *   "io.circe" %% "circe-generic" % "0.12.3",
 *   "io.circe" %% "circe-parser" % "0.12.3",
 */
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment