Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Typed tensors in Scala (tensor arity checked at the type level)
/** A multilinear form over the vector type `V`, with its arities tracked only
  * at runtime.
  *
  * Implementations consume a prefix of the vectors in `vs` and a prefix of the
  * covectors (`V => Double`) in `vds`, and produce a single `Double`.
  */
sealed trait Tensor[V] {

  /** How many covectors this tensor reads from `vds`. */
  val n: Int

  /** How many vectors this tensor reads from `vs`. */
  val m: Int

  /** Evaluates the tensor on the supplied vectors and covectors. */
  def apply(vs: List[V], vds: List[V => Double]): Double
}
/** Tensor built from a single vector `v`.
  *
  * It reads exactly one covector and no vectors, and evaluates to that
  * covector applied to `v`.
  */
case class TUnit[V](v: V) extends Tensor[V] {
  val n: Int = 1
  val m: Int = 0

  /** Applies the first covector in `vds` to `v`; `vs` is ignored.
    *
    * @throws NoSuchElementException if `vds` is empty
    */
  def apply(vs: List[V], vds: List[V => Double]): Double =
    // Ordinary selection + application instead of the infix form `vds head(v)`,
    // which only works because a nullary method's result is then applied.
    vds.head(v)
}
/** Tensor built from a single covector `vd`.
  *
  * It reads exactly one vector and no covectors, and evaluates to `vd`
  * applied to that vector.
  */
case class TCoUnit[V](vd: V => Double) extends Tensor[V] {
  val n: Int = 0
  val m: Int = 1

  /** Applies `vd` to the first vector in `vs`; `vds` is ignored.
    *
    * @throws NoSuchElementException if `vs` is empty
    */
  def apply(vs: List[V], vds: List[V => Double]): Double =
    // Regular `.head` instead of postfix `vs head`, which requires
    // `scala.language.postfixOps` and is rejected without it by newer compilers.
    vd(vs.head)
}
/** Tensor product of `t` and `u`.
  *
  * The arities of the two factors add up, and the value is the product of
  * each factor evaluated on its own slice of the arguments.
  */
case class TProduct[V](t: Tensor[V], u: Tensor[V]) extends Tensor[V] {
  val n: Int = t.n + u.n
  val m: Int = t.m + u.m

  /** Gives `t` the first `t.m` vectors and first `t.n` covectors, and gives
    * `u` everything after those prefixes, then multiplies the two results.
    */
  def apply(vs: List[V], vds: List[V => Double]): Double = {
    val leftValue = t(vs.take(t.m), vds.take(t.n))
    val rightValue = u(vs.drop(t.m), vds.drop(t.n))
    leftValue * rightValue
  }
}
/** Demo: evaluates TUnit(vec) x TUnit(vec) x TCoUnit(duec) on the vector `vec`
  * and two copies of the covector `duec`.
  *
  * Every factor contributes dot(vec, uec) = 1*(-1) + 2*(-1) + 3*(-1) = -6, so
  * the printed value is (-6)^3 = -216.0.
  */
def main(args: Array[String]): Unit = {
  // Explicit `.0`: a trailing-dot literal such as `1.` is not a Double literal
  // since Scala 2.11 and no longer parses.
  val vec = List(1.0, 2.0, 3.0)
  val uec = List(-1.0, -1.0, -1.0)
  // The covector "dot product with uec".
  val duec = (v: List[Double]) => dot(v, uec)
  val tensor = TProduct(TProduct(TUnit(vec), TUnit(vec)), TCoUnit(duec))
  println(tensor(Sized(vec), Sized(duec, duec)))
  //=> -216.0
}
/** Dot product of two vectors represented as lists.
  *
  * Elements are paired positionally with `zip`. If the lists differ in length,
  * the trailing elements of the longer one are ignored. Empty input yields 0.0.
  *
  * @param v first vector
  * @param u second vector
  * @return the sum of the pairwise products
  */
def dot(v: List[Double], u: List[Double]): Double =
  // Plain method calls: the original trailing postfix `sum` needs
  // `scala.language.postfixOps` to compile cleanly.
  v.zip(u).map { case (a, b) => a * b }.sum
/** Typed variant of the tensor, with its arities carried as `Nat` type
  * parameters instead of runtime `Int`s.
  *
  * NOTE(review): `Sized` / `Nat` appear to be shapeless types (the `type A`
  * refinement suggests a 1.x-era API); `Sized[List[X], K]` is presumably a
  * list statically known to hold exactly K elements — confirm.
  *
  * @tparam V the vector type
  * @tparam N number of covectors (`V => Double`) consumed via `vds`
  * @tparam M number of vectors consumed via `vs`
  */
sealed trait Tensor[V, N <: Nat, M <: Nat] {
// Evaluates the tensor on exactly M vectors and N covectors.
def apply(vs: Sized[List[V], M] {type A = V},
vds: Sized[List[V => Double], N] {type A = V => Double}): Double
}
/** Typed tensor built from a single vector `v`: it consumes one covector
  * (N = 1) and no vectors (M = 0), and evaluates to that covector applied
  * to `v`.
  */
case class TUnit[V](v: V) extends Tensor [V, Nat._1, Nat._0]{
def apply(vs: Sized[List[V], Nat._0] {type A = V},
vds: Sized[List[V => Double], Nat._1] {type A = V => Double}): Double = {
// The head is bound first and applied separately. NOTE(review): presumably
// because `Sized#head` takes an implicit parameter list, so `vds.head(v)`
// would pass `v` as that implicit argument instead of applying the
// covector — confirm before collapsing these two lines.
val hd = vds.head
hd(v)
}
}
/** Typed tensor built from a single covector `vd`.
  *
  * It consumes one vector (M = 1) and no covectors (N = 0), and evaluates to
  * `vd` applied to that vector.
  */
case class TCoUnit[V](vd: V => Double) extends Tensor[V, Nat._0, Nat._1] {
  def apply(vs: Sized[List[V], Nat._1] {type A = V},
            vds: Sized[List[V => Double], Nat._0] {type A = V => Double}): Double =
    // Ordinary `.head` instead of postfix `vs head`, which requires
    // `scala.language.postfixOps` and is rejected without it by newer compilers.
    vd(vs.head)
}
/** Typed tensor product of `t` and `u`.
  *
  * `t` consumes X covectors and Y vectors, and `u` consumes S covectors and
  * T vectors. The result consumes K covectors and L vectors.
  *
  * NOTE(review): assuming shapeless's `DiffAux[A, B, C]` witnesses A - B = C,
  * `k` fixes K = X + S and `l` fixes L = Y + T — confirm. `toX` / `toY`
  * are presumably required implicitly by `splitAt` to turn the type-level
  * split point into a runtime index.
  */
case class TProduct[V, X <: Nat, Y <: Nat, S <: Nat, T <: Nat, K <: Nat, L <: Nat](t: Tensor[V, X, Y],
u: Tensor[V, S, T])
(implicit val k: DiffAux[K, X, S],
implicit val l: DiffAux[L, Y, T],
implicit val toX: ToInt[X],
implicit val toY: ToInt[Y]) extends Tensor[V, K, L] {
// Evaluates `t` and `u` on disjoint slices of the arguments and multiplies them.
def apply(vs: Sized[List[V], L] {type A = V},
vds: Sized[List[V => Double], K] {type A = V => Double}): Double = {
// First Y vectors go to `t`, the remainder to `u`.
val (fvs, lvs) = vs.splitAt[Y]
// First X covectors go to `t`, the remainder to `u`.
val (fvds, lvds) = vds.splitAt[X]
t.apply(fvs, fvds) * u.apply(lvs, lvds)
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment