Recently, I worked on a project that uses neural networks to learn and recognize certain patterns in text. The project itself is an interesting problem, which I ended up solving with a conditional random field aided by a recurrent neural network. Once the model had been trained, I needed to write a Java library to make the model accessible from a JVM process.
The deep learning library I have been using is TensorFlow. I know PyTorch has been picking up steam steadily in recent years, but to me TensorFlow is really the best choice for building applications of industrial strength, given the incredibly diverse set of client languages one can choose from. Luckily, Java is one of those client languages.
TensorFlow defines the SavedModel format for exporting a model to disk, which any client language TensorFlow supports can then read back; on the Java side the class for this is SavedModelBundle. Once the model is imported into the process, using it is as simple as: 1. feeding the input tensors with the actual input values, 2. fetching the values of the output tensors. How the graph was constructed is entirely hidden from the user.
For my model, the input tensors are just a bunch of lists of embedding indices. In Python, they look pretty much like the following,
word = tf.placeholder(dtype=tf.int32, shape=[None, None])
wordlength = tf.placeholder(dtype=tf.int32, shape=[None])
char = tf.placeholder(dtype=tf.int32, shape=[None, None, None])
charlength = tf.placeholder(dtype=tf.int32, shape=[None, None])
The reason I feed embedding indices for chars into my model is that there is a non-negligible amount of typos and errors in my texts. Char embeddings help the model alleviate this problem.
The JVM language I decided to write the library in was Scala. After tokenizing the text into words, then words into chars, and finally mapping the tokens to their embedding indices, I ended up with the following values corresponding to the above inputs,
val word: Vector[Vector[Int]] = ???
val wordLength: Vector[Int] = ???
val char: Vector[Vector[Vector[Int]]] = ???
val charLength: Vector[Vector[Int]] = ???
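One way the four values above might be produced is sketched below. This is only an illustration, not the pipeline used in the project: the vocabularies `wordVocab` and `charVocab`, the use of index 0 for padding and unknown tokens, and the whitespace tokenizer are all assumptions. The sketch pads every sentence and every word to a common length, which is what makes the nested vectors rectangular, and keeps the true lengths in `wordLength` and `charLength`.

```scala
object EncodeSketch {
  // Hypothetical vocabularies; index 0 is assumed to mean "padding or unknown".
  val wordVocab: Map[String, Int] = Map("hello" -> 1, "world" -> 2, "hi" -> 3)
  val charVocab: Map[Char, Int] =
    ('a' to 'z').zipWithIndex.map { case (c, i) => c -> (i + 1) }.toMap

  final case class Encoded(
    word: Vector[Vector[Int]],
    wordLength: Vector[Int],
    char: Vector[Vector[Vector[Int]]],
    charLength: Vector[Vector[Int]]
  )

  def encode(sentences: Vector[String]): Encoded = {
    // Naive whitespace tokenizer, for illustration only.
    val tokens: Vector[Vector[String]] =
      sentences.map(_.toLowerCase.split("\\s+").filter(_.nonEmpty).toVector)

    val maxWords = tokens.map(_.size).foldLeft(0)(_ max _)
    val maxChars = tokens.flatten.map(_.length).foldLeft(0)(_ max _)
    val emptyWord = Vector.fill(maxChars)(0)

    // [batch, maxWords]: word indices, padded with 0.
    val word = tokens.map(ws => ws.map(w => wordVocab.getOrElse(w, 0)).padTo(maxWords, 0))
    // [batch]: true number of words per sentence.
    val wordLength = tokens.map(_.size)
    // [batch, maxWords, maxChars]: char indices, padded with 0.
    val char = tokens.map { ws =>
      ws.map(w => w.toVector.map(c => charVocab.getOrElse(c, 0)).padTo(maxChars, 0))
        .padTo(maxWords, emptyWord)
    }
    // [batch, maxWords]: true number of chars per word.
    val charLength = tokens.map(ws => ws.map(_.length).padTo(maxWords, 0))

    Encoded(word, wordLength, char, charLength)
  }
}
```

For example, `EncodeSketch.encode(Vector("hello world", "hi"))` pads the second sentence to two words and the word "hi" to five chars.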
To convert the above vectors into tensors, I had my first draft as the following,
import java.nio.IntBuffer
import org.tensorflow.Tensor

// Rank 1: the shape is just the length of the vector.
def toTensor1(vector: Vector[Int]): Tensor[java.lang.Integer] = {
  val shape = Array(vector.size.toLong)
  val buf = IntBuffer.wrap(vector.toArray)
  Tensor.create(shape, buf)
}

// Rank 2: assumes a non-empty vector whose rows all have the same length.
def toTensor2(vector: Vector[Vector[Int]]): Tensor[java.lang.Integer] = {
  val shapeX = vector.size.toLong
  val shapeY = vector(0).size.toLong
  val shape = Array(shapeX, shapeY)
  val buf = IntBuffer.wrap(vector.flatten.toArray)
  Tensor.create(shape, buf)
}

// Rank 3: same assumptions as above, one level deeper.
def toTensor3(vector: Vector[Vector[Vector[Int]]]): Tensor[java.lang.Integer] = {
  val shapeX = vector.size.toLong
  val shapeY = vector(0).size.toLong
  val shapeZ = vector(0)(0).size.toLong
  val shape = Array(shapeX, shapeY, shapeZ)
  val buf = IntBuffer.wrap(vector.flatMap(_.flatten).toArray)
  Tensor.create(shape, buf)
}
val wordTensor = toTensor2(word)
val wordLengthTensor = toTensor1(wordLength)
val charTensor = toTensor3(char)
val charLengthTensor = toTensor2(charLength)
There are two things to note above:
- the way TensorFlow creates a tensor in Java is by asking for an Array[Long] for the shape of the tensor as well as a "flattened" view of the actual collection that carries the data.
- there is quite a bit of repeated code to get these tensors.
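To make the first point concrete, here is a small sketch of what "shape plus flattened view" means for a rank-2 input, using only the Scala standard library. The names `shape2`, `flat2` and `at` are made up for this illustration. The flattened view is in row-major order, so element (i, j) of the original matrix sits at position i * columns + j of the flat array.

```scala
object RowMajorSketch {
  // Shape of a rectangular 2-D vector: (rows, columns).
  def shape2(m: Vector[Vector[Int]]): Array[Long] = {
    val cols = m.headOption.map(_.size).getOrElse(0)
    require(m.forall(_.size == cols), "all rows must have the same length")
    Array(m.size.toLong, cols.toLong)
  }

  // Row-major flattening: rows are laid out one after another.
  def flat2(m: Vector[Vector[Int]]): Array[Int] = m.flatten.toArray

  // Recover element (i, j) from the flat view and the shape.
  def at(flat: Array[Int], shape: Array[Long], i: Int, j: Int): Int =
    flat(i * shape(1).toInt + j)
}
```

With `Vector(Vector(1, 2, 3), Vector(4, 5, 6))` the shape is (2, 3), the flat view is 1 through 6 in order, and `at(flat, shape, 1, 2)` reads back the last element.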
As a functional / type-level programmer, I find duplicated code frustrating yet intriguing at the same time, because it signals an opportunity to eliminate the boilerplate with higher-order functions and types. After thinking about this, I created the following data structure,
sealed trait Rec[A]
object Rec {
  private case class Node[A](a: A) extends Rec[A]
  private case class Cons[A](a: Rec[A], as: Vector[Rec[A]]) extends Rec[A]
  def node[A](a: A): Rec[A] = Node(a)
  def cons[A](hd: Rec[A], tl: Rec[A]*): Rec[A] = Cons(hd, tl.toVector)
  def unsafe[A](vector: Vector[Rec[A]]): Rec[A] = Cons(vector.head, vector.tail)
}
Rec[A] is sort of like the Fix type, which looks roughly like the following,
case class Fix[F[_]](unfix: F[Fix[F]])
The difference is that the type F[_] in my case is already known to be Vector[_]. Rec[_] also flattens the co-recursion between the inner and outer type constructors, the reason for which will become obvious in a moment. It's time to convert the vectors into Rec,
val wordR = Rec.unsafe(word.map(xs => Rec.unsafe(xs.map(Rec.node))))
val wordLengthR = Rec.unsafe(wordLength.map(Rec.node))
val charR = Rec.unsafe(char.map(xs => Rec.unsafe(xs.map(xxs => Rec.unsafe(xxs.map(Rec.node))))))
val charLengthR = Rec.unsafe(charLength.map(xs => Rec.unsafe(xs.map(Rec.node))))
Now remember that the goal is a single function that can convert a Vector[_] of any depth to a TensorFlow tensor. I chose to be a little fancy here by creating implicits to make the usage more convenient:
// This goes in the companion object of Rec, so the private child classes can be accessed.
// It also needs java.nio.IntBuffer, org.tensorflow.Tensor and scala.reflect.ClassTag in scope.
implicit final class Ops[A](val rec: Rec[A]) extends AnyVal {
  def shape: Array[Long] = {
    import scala.collection.mutable.ArrayBuffer
    // Walk down the first element of each level and record that level's size.
    @scala.annotation.tailrec
    def go(rec: Rec[A], buf: ArrayBuffer[Long] = ArrayBuffer.empty): Array[Long] =
      rec match {
        case Node(_)     => buf.toArray
        case Cons(x, xs) => go(x, buf += (1 + xs.size).toLong)
      }
    go(rec)
  }
  // keep in mind our implementation must be stack safe!
  // A ClassTag is required to build an Array[A].
  def flatten(implicit C: ClassTag[A]): Array[A] = {
    import scala.util.control.TailCalls._
    def go(rec: Rec[A]): TailRec[Vector[A]] =
      rec match {
        case Node(a) => done(Vector(a))
        case Cons(a, as) =>
          as.foldLeft(go(a))( (x, y) =>
            for {
              xv <- x
              yv <- go(y)
            } yield xv ++ yv
          )
      }
    go(rec).result.toArray
  }
  def toIntTensor(implicit N: Numeric[A], C: ClassTag[A]): Tensor[java.lang.Integer] = {
    val buf = IntBuffer.wrap(flatten.map(N.toInt))
    Tensor.create(shape, buf)
  }
}
Now inputs can be converted to tensors,
val wordTensor = wordR.toIntTensor
val wordLengthTensor = wordLengthR.toIntTensor
val charTensor = charR.toIntTensor
val charLengthTensor = charLengthR.toIntTensor
As promised: one function to handle conversion for all cases. Keep in mind this is only possible because all inputs, after conversion, are represented by a unified type Rec[_]. Therefore a higher-order function can accept all of them and pattern match to uncover their internal structure. That said, this brings another problem... Can you spot what it is?
Look at the types of wordR and charR, for example. Is there any difference between them on the surface of their types? There is not, as they both have the same type, Rec[Int]! If I happen to swap wordR and charR as inputs, the compiler won't be able to tell me it's wrong, because they do indeed have the same type. I need to find a way to make them look slightly different from the perspective of types, but not so different that they can no longer be unified under the umbrella of a more generic type.
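The problem can be reproduced in a few lines. The sketch below uses a simplified, public copy of the untyped Rec (the names `fromVector`, `depth` and `feedWord` are made up for this illustration). The nesting depth is only known at runtime, so a function meant for rank-2 word indices happily accepts a rank-3 value.

```scala
object UntypedDepthSketch {
  sealed trait Rec[A]
  final case class Node[A](a: A) extends Rec[A]
  final case class Cons[A](head: Rec[A], tail: Vector[Rec[A]]) extends Rec[A]

  def node[A](a: A): Rec[A] = Node(a)
  // Like Rec.unsafe in the post: throws on an empty vector.
  def fromVector[A](v: Vector[Rec[A]]): Rec[A] = Cons(v.head, v.tail)

  // The depth can only be discovered by inspecting the value at runtime.
  def depth[A](rec: Rec[A]): Int = {
    @scala.annotation.tailrec
    def go(r: Rec[A], acc: Int): Int = r match {
      case Node(_)    => acc
      case Cons(h, _) => go(h, acc + 1)
    }
    go(rec, 0)
  }

  val row: Rec[Int]    = fromVector[Int](Vector(node(1), node(2)))
  val word2d: Rec[Int] = fromVector[Int](Vector(row, row))
  val char3d: Rec[Int] = fromVector[Int](Vector(word2d))

  // Hypothetical consumer that is supposed to receive rank-2 word indices.
  def feedWord(word: Rec[Int]): Int = depth(word)
}
```

Both `feedWord(word2d)` and `feedWord(char3d)` compile; only the second result, 3 instead of 2, reveals the mix-up, and only at runtime.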
My final attempt revolves around encoding Rec[_] with type-level natural numbers. shapeless provides an excellent implementation of them, so I just shamelessly copied the relevant part into my code,
sealed trait Nat { type N <: Nat }
final case class Succ[P <: Nat]() extends Nat { type N = Succ[P] }
final class Zero extends Nat with Serializable { type N = Zero }
object Nat {
  type _0 = Zero
  type _1 = Succ[_0]
  type _2 = Succ[_1]
  type _3 = Succ[_2]
}
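For readers who have not met type-level naturals before, the sketch below shows the general trick for computing with them: a type class with one instance for Zero and one for Succ, resolved recursively by the compiler. It uses a simplified copy of Nat without the `N` type member, and the type class `ToInt` and helper `toInt` are written here purely for illustration.

```scala
object NatSketch {
  sealed trait Nat
  final class Zero extends Nat
  final class Succ[P <: Nat] extends Nat

  type _0 = Zero
  type _1 = Succ[_0]
  type _2 = Succ[_1]
  type _3 = Succ[_2]

  // Type class that turns a type-level number into a runtime Int.
  trait ToInt[N <: Nat] { def value: Int }

  object ToInt {
    // Base case: Zero is 0.
    implicit val zero: ToInt[Zero] = new ToInt[Zero] { val value = 0 }
    // Inductive case: Succ[P] is one more than P.
    implicit def succ[P <: Nat](implicit prev: ToInt[P]): ToInt[Succ[P]] =
      new ToInt[Succ[P]] { val value = prev.value + 1 }
  }

  def toInt[N <: Nat](implicit ev: ToInt[N]): Int = ev.value
}
```

Calling `NatSketch.toInt[NatSketch._3]` makes the compiler chain `succ` three times on top of `zero`. The same base-case plus inductive-case pattern is what lets a type class follow a recursive structure one level at a time.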
Rec is now redefined as the following,
// The second parameter is the depth index; it must be bounded by Nat, not merely named Nat.
sealed trait Rec[A, N <: Nat]
object Rec {
  private case class Node[A](a: A) extends Rec[A, Zero]
  private case class Cons[A, N <: Nat](a: Rec[A, N], as: Vector[Rec[A, N]]) extends Rec[A, Succ[N]]
  def node[A](a: A): Rec[A, Zero] = Node(a)
  def cons[A, N <: Nat](a: Rec[A, N], as: Rec[A, N]*): Rec[A, Succ[N]] = Cons(a, as.toVector)
  def unsafe[A, N <: Nat](vector: Vector[Rec[A, N]]): Rec[A, Succ[N]] = Cons(vector.head, vector.tail)
}
The added type parameter N <: Nat encodes the depth of self-recursion of Rec. For example, Rec[Int, _2] has a self-recursion depth of 2, which is consistent with what _2 suggests. It is no longer possible to mix up or swap vectors of different depths, since they now have different types. The way to convert inputs to the refined Rec is identical to the previous one,
val wordR = Rec.unsafe(word.map(xs => Rec.unsafe(xs.map(Rec.node))))
val wordLengthR = Rec.unsafe(wordLength.map(Rec.node))
val charR = Rec.unsafe(char.map(xs => Rec.unsafe(xs.map(xxs => Rec.unsafe(xxs.map(Rec.node))))))
val charLengthR = Rec.unsafe(charLength.map(xs => Rec.unsafe(xs.map(Rec.node))))
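A short sketch of what the depth index buys: with the rank in the type, a consumer can state which rank it expects. The code below is a simplified, public copy of the definitions (no private constructors), and `describeWord` and `describeChar` are hypothetical consumers invented for this illustration.

```scala
object TypedDepthSketch {
  sealed trait Nat
  final class Zero extends Nat
  final class Succ[P <: Nat] extends Nat

  type _0 = Zero
  type _1 = Succ[_0]
  type _2 = Succ[_1]
  type _3 = Succ[_2]

  sealed trait Rec[A, N <: Nat]
  final case class Node[A](a: A) extends Rec[A, Zero]
  final case class Cons[A, N <: Nat](head: Rec[A, N], tail: Vector[Rec[A, N]])
      extends Rec[A, Succ[N]]

  def node[A](a: A): Rec[A, Zero] = Node(a)
  // Like Rec.unsafe in the post: throws on an empty vector.
  def fromVector[A, N <: Nat](v: Vector[Rec[A, N]]): Rec[A, Succ[N]] =
    Cons(v.head, v.tail)

  // Each level of nesting bumps the type-level depth by one.
  val row: Rec[Int, _1]  = fromVector[Int, _0](Vector(node(1), node(2)))
  val word: Rec[Int, _2] = fromVector[Int, _1](Vector(row, row))
  val char: Rec[Int, _3] = fromVector[Int, _2](Vector(word))

  // Hypothetical consumers that pin down the rank they expect.
  def describeWord(w: Rec[Int, _2]): String = "rank-2 word indices"
  def describeChar(c: Rec[Int, _3]): String = "rank-3 char indices"

  // describeWord(char) // would not compile: Rec[Int, _3] is not a Rec[Int, _2]
}
```

Swapping the two inputs is now rejected by the compiler instead of surfacing as a shape error at inference time.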
To compute the shape and flattened view of the refined Rec, it is no longer possible to just use a regular function like I did previously, since the depth index N in the type of Rec now changes at every level of the recursion. Instead, I need type classes,
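// Shape and Flatten pattern match on Node and Cons, so they must be defined
// where those private classes are visible (for example inside object Rec).
trait Shape[N <: Nat] {
  def apply[A](rec: Rec[A, N]): Array[Long]
}
object Shape {
  // Depth 0: a single value has an empty shape.
  implicit val zero : Shape[Zero] =
    new Shape[Zero] { def apply[A](rec: Rec[A, Zero]): Array[Long] = Array.empty }
  // Depth n + 1: this level's size, followed by the shape of the first child.
  implicit def succ[N <: Nat](implicit prev: Shape[N]): Shape[Succ[N]] =
    new Shape[Succ[N]] {
      def apply[A](rec: Rec[A, Succ[N]]): Array[Long] =
        rec match {
          case Cons(x, xs) =>
            val dim = (1 + xs.size).toLong
            val dims = prev(x)
            dim +: dims
        }
    }
}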
import scala.util.control.TailCalls._
trait Flatten[N <: Nat] {
  def apply[A](rec: Rec[A, N]): TailRec[Vector[A]]
}
object Flatten {
  implicit val zero: Flatten[Zero] =
    new Flatten[Zero] {
      def apply[A](rec: Rec[A, Zero]): TailRec[Vector[A]] =
        rec match {
          case Node(a) => done(Vector(a))
        }
    }
  implicit def succ[N <: Nat](implicit prev: Flatten[N]): Flatten[Succ[N]] =
    new Flatten[Succ[N]] {
      def apply[A](rec: Rec[A, Succ[N]]): TailRec[Vector[A]] =
        rec match {
          case Cons(a, as) =>
            as.foldLeft(prev(a))( (x, y) =>
              for {
                xv <- x
                yv <- prev(y)
              } yield xv ++ yv
            )
        }
    }
}
The above type classes are quite similar to what shapeless provides for its vast number of operations on HList and Coproduct. The difference, however, is that my type classes work on type recursion and type-level natural numbers.
With the type classes in place, the convenience implicits for converting a vector to a tensor are quite simple. In truth, the function toIntTensor is almost the same as the previous implementation.
import scala.reflect.ClassTag
implicit final class Ops[A, N <: Nat](val rec: Rec[A, N]) extends AnyVal {
  def shape(implicit S : Shape[N]): Array[Long] = S(rec)
  def flatten(implicit F : Flatten[N], C : ClassTag[A]): Array[A] = F(rec).result.toArray
  def toIntTensor(implicit N : Numeric[A], S : Shape[N], F : Flatten[N], C : ClassTag[A]): Tensor[java.lang.Integer] = {
    val buf = IntBuffer.wrap(flatten.map(N.toInt))
    Tensor.create(shape, buf)
  }
}
The way of converting inputs to tensors stays the same,
val wordTensor = wordR.toIntTensor
val wordLengthTensor = wordLengthR.toIntTensor
val charTensor = charR.toIntTensor
val charLengthTensor = charLengthR.toIntTensor
Do the rewrites actually reduce the sheer amount of code? Not really... in my case the number of LOC actually increased. However, the art of reducing boilerplate, in my opinion, is not about counting the LOC before and after the refactoring. Rather, it is about relieving the psychological burden by replacing tedious copy/paste of code with one or a few simple rules, which makes the program simpler to reason about and easier to maintain. To that end, I am quite happy with the work I had to put in to make this happen.