Skip to content

Instantly share code, notes, and snippets.

@biboudis
Last active January 11, 2024 14:50
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save biboudis/3c4823d905dbadd6d61a7ce3fa47ff64 to your computer and use it in GitHub Desktop.
Save biboudis/3c4823d905dbadd6d61a7ce3fa47ff64 to your computer and use it in GitHub Desktop.
Type your matrices for great good
import scala.compiletime.ops._
import scala.compiletime.ops.int._
import scala.compiletime.ops.any._
/**
* Type your matrices for great good: a Haskell library of typed matrices and applications (functional pearl)
* https://dl.acm.org/doi/10.1145/3406088.3409019
*/
object Test {

  /** Core representation: matrices whose dimensions are encoded as *types*.
    *
    * A dimension type stands for a finite index set, and its size is given by
    * [[Count]]: `Unit` has size 1, `Either[A, B]` has size `|A| + |B|`,
    * `(A, B)` has size `|A| * |B|`, and `Null` is used for the size-0 dimension.
    */
  object Internal {

    /** A matrix with element type `E`, column dimension `C` and row dimension `R`.
      *
      *  - `One(e)`: the 1x1 matrix holding `e`.
      *  - `Join(m1, m2)`: concatenation along the columns; the column dimension
      *    becomes `Either[A, B]` and both halves share the row dimension `R`.
      *  - `Fork(m1, m2)`: concatenation along the rows; the row dimension
      *    becomes `Either[A, B]` and both halves share the column dimension `C`.
      */
    enum Matrix[E, C, R] {
      case One[E](e: E) extends Matrix[E, Unit, Unit]
      case Join[E, R, A, B](m1: Matrix[E, A, R], m2: Matrix[E, B, R]) extends Matrix[E, Either[A, B], R]
      case Fork[E, C, A, B](m1: Matrix[E, C, A], m2: Matrix[E, C, B]) extends Matrix[E, C, Either[A, B]]
    }

    /** Compile-time size of a dimension type. */
    type Count[D] <: Int = D match {
      case Null         => 0
      case Unit         => 1
      // Sizes add for a sum type and multiply for a pair; recurse into the
      // components, which are themselves dimension types (not Int literals).
      case Either[a, b] => Count[a] + Count[b]
      case (a, b)       => Count[a] * Count[b]
    }

    /** Canonical dimension type of size `N`, built by halving.
      *
      * With `M = FromNat[N / 2]`, an even `N` becomes `Either[M, M]` and an odd
      * `N` becomes `Either[Unit, Either[M, M]]` (see [[FromNatB]]).
      */
    type FromNat[N <: Int] = N match {
      case 0 => Null
      case 1 => Unit
      case _ => FromNatB[N % 2 == 0, FromNat[N / 2]]
    }

    /** Helper for [[FromNat]]: `B` says whether the size is even; `M` is the half. */
    type FromNatB[B <: Boolean, M] = B match {
      case true  => Either[M, M]
      case false => Either[Unit, Either[M, M]]
    }

    /** Keeps the `Either` spine of `D` and replaces every other component with
      * the canonical [[FromNat]] encoding of its size.
      */
    type Normalize[D] = D match {
      case Either[a, b] => Either[Normalize[a], Normalize[b]]
      case _            => FromNat[Count[D]]
    }

    /** Applies the abide law
      * `Join(Fork(a, c), Fork(b, d))  ==>  Fork(Join(a, b), Join(c, d))`
      * recursively.
      *
      * Where a column split has two row-split halves, it becomes a row split.
      * All other nodes are rebuilt with their children rewritten. Both sides of
      * the law denote the same block matrix, so only the block structure changes.
      */
    def abideJF[Cols, Rows](m: Matrix[Int, Cols, Rows]): Matrix[Int, Cols, Rows] = {
      import Matrix._
      m match {
        case Join(Fork(a, c), Fork(b, d)) =>
          Fork(Join(abideJF(a), abideJF(b)), Join(abideJF(c), abideJF(d)))
        case One(e)     => One(e)
        case Join(a, b) => Join(abideJF(a), abideJF(b))
        case Fork(a, b) => Fork(abideJF(a), abideJF(b))
      }
    }

    /** Combines two matrices of the same shape element-wise with `f`.
      *
      * If the operands are split along different axes (one a `Fork`, the other
      * a `Join`), the `Join` operand is first rewritten with [[abideJF]] so the
      * two block structures line up.
      */
    def zipWith[Cols, Rows](f: Int => Int => Int, m1: Matrix[Int, Cols, Rows], m2: Matrix[Int, Cols, Rows]): Matrix[Int, Cols, Rows] = {
      import Matrix._
      (m1, m2) match {
        case (One(a), One(b))         => One(f(a)(b))
        case (Join(a, b), Join(c, d)) => Join(zipWith(f, a, c), zipWith(f, b, d))
        case (Fork(a, b), Fork(c, d)) => Fork(zipWith(f, a, c), zipWith(f, b, d))
        // Constructor patterns are required here: a bare `Fork`/`Join` is a
        // stable-identifier pattern that compares against the companion object
        // and never matches a node. Re-using m1/m2 keeps their exact static type.
        case (Fork(_, _), Join(_, _)) => zipWith(f, m1, abideJF(m2))
        case (Join(_, _), Fork(_, _)) => zipWith(f, abideJF(m1), m2)
      }
    }

    /** Element-wise sum of two matrices of the same shape. */
    extension [Cols, Rows](m1: Matrix[Int, Cols, Rows]) {
      def +(m2: Matrix[Int, Cols, Rows]): Matrix[Int, Cols, Rows] =
        zipWith(a => b => a + b, m1, m2)
    }

    /** Matrix product `m1 · m2`.
      *
      * `m1` has `CR` columns and `Rows` rows, and `m2` has `Cols` columns and
      * `CR` rows, so the result has `Cols` columns and `Rows` rows.
      */
    def comp[CR, Rows, Cols](m1: Matrix[Int, CR, Rows], m2: Matrix[Int, Cols, CR]): Matrix[Int, Cols, Rows] = {
      import Matrix._
      (m1, m2) match {
        case (One(a), One(b))         => One(a * b)
        // Block inner product: [a | b] · [c ; d] = a·c + b·d.
        case (Join(a, b), Fork(c, d)) => comp(a, c) + comp(b, d)
        case (Fork(a, b), c)          => Fork(comp(a, c), comp(b, c))
        case (c, Join(a, b))          => Join(comp(c, a), comp(c, b))
      }
    }

    /** Builds a matrix of the given shape from nested arrays.
      * NOTE(review): no implementation is in view; the expected layout of `arr`
      * (row-major or column-major) is unspecified — confirm with implementers.
      */
    trait FromLists[Cols, Rows] {
      def fromLists(arr: Array[Array[Int]]): Matrix[Int, Cols, Rows]
    }
  }

  /** Matrix indexed by numeric dimensions: `Matrix[Cols, Rows]` is an
    * [[Internal.Matrix]] of `Int` whose dimension types are `FromNat[Cols]` and
    * `FromNat[Rows]`. The representation is only visible inside `Test`.
    */
  opaque type Matrix[Cols <: Int, Rows <: Int] = Internal.Matrix[Int, Internal.FromNat[Cols], Internal.FromNat[Rows]]

  /** Small fixed matrices that check the dimension encoding type-checks.
    * NOTE(review): the `e` parameter is currently unused by every method.
    */
  trait Tests {
    import Internal.Matrix._

    /** 2x2 identity. */
    def iden2x2(e: Int): Matrix[2, 2] = {
      Fork(Join(One(1), One(0)),
           Join(One(0), One(1)))
    }

    /** A single column of three ones (1 column, 3 rows). */
    def ones1x3(e: Int): Matrix[1, 3] = {
      Fork(One(1), Fork(One(1), One(1)))
    }

    /** A single row of three ones (3 columns, 1 row). */
    def ones3x1(e: Int): Matrix[3, 1] = {
      Join(One(1), Join(One(1), One(1)))
    }

    /** 3x3 identity. */
    def iden3x3(e: Int): Matrix[3, 3] = {
      Fork(Join(One(1), Join(One(0), One(0))),
           Fork(Join(One(0), Join(One(1), One(0))),
                Join(One(0), Join(One(0), One(1)))))
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment