A polymorphic way of converting generic Scala collections to TensorFlow tensors

Recently, I worked on a project that uses neural networks to learn and recognize certain patterns from text. The project itself is an interesting problem, which I ended up solving with a conditional random field aided by a recurrent neural network. Once the model had been trained, I needed to write a Java library to make the model accessible from a JVM process.

The deep learning library I have been using is TensorFlow. I know PyTorch has been picking up steam steadily in recent years, but to me TensorFlow is really the best choice for building applications of industrial strength, given the incredibly diversified set of client languages one can choose from. Luckily Java is one of those client languages.

TensorFlow defines a class, SavedModelBundle, that one can use to export a model to disk so that any client language TensorFlow supports can read it back. Once the model is imported into the process, using it is as simple as: 1. feeding the input tensors with the actual input values, 2. fetching the values of the output tensors. How the graph is constructed is entirely transparent to the user.
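
For orientation, here is a minimal sketch of what that looks like from the JVM side. The export path, the "serve" tag and the operation names are assumptions for illustration; the Tensor values are the ones built later in this post.

import org.tensorflow.{SavedModelBundle, Tensor}

// load the exported model; "serve" is the usual serving tag, and the export
// path as well as the operation names below are assumptions for illustration
val bundle = SavedModelBundle.load("/path/to/export", "serve")

val results = bundle.session().runner()
  .feed("word", wordTensor)             // 1. feed the input tensors
  .feed("wordlength", wordLengthTensor)
  .feed("char", charTensor)
  .feed("charlength", charLengthTensor)
  .fetch("output")                      // 2. fetch the values of the output tensors
  .run()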

For my model, the input tensors are just a bunch of lists of embedding indices. In Python, they look pretty much like the following,

word = tf.placeholder(dtype=tf.int32, shape=[None, None])
wordlength = tf.placeholder(dtype=tf.int32, shape=[None])
char = tf.placeholder(dtype=tf.int32, shape=[None, None, None])
charlength = tf.placeholder(dtype=tf.int32, shape=[None, None])

The reason I have embedding indices for chars as input to my model is that there is a non-negligible amount of typos/errors in my texts. Char embeddings help the model alleviate this problem.

The JVM language I decided to write the library in was Scala. After tokenizing the text into words, the words into chars, and finally mapping the tokens to their embedding indices, I ended up with the following, corresponding to the inputs above,

val word: Vector[Vector[Int]] = ???
val wordLength: Vector[Int] = ???
val char: Vector[Vector[Vector[Int]]] = ???
val charLength: Vector[Vector[Int]] = ???
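
For illustration, a hypothetical single-sentence input could look like the following; the indices are made up and only meant to show the nesting,

// hypothetical: one sentence of two words, with made-up embedding indices
val word: Vector[Vector[Int]] = Vector(Vector(17, 42))        // 1 sentence x 2 words
val wordLength: Vector[Int] = Vector(2)                       // words per sentence
val char: Vector[Vector[Vector[Int]]] =
  Vector(Vector(Vector(3, 1, 20), Vector(4, 15, 7)))          // char indices per word
val charLength: Vector[Vector[Int]] = Vector(Vector(3, 3))    // chars per word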

To convert the above vectors into tensors, my first draft was the following,

import java.nio.IntBuffer
import org.tensorflow.Tensor

def toTensor1(vector: Vector[Int]): Tensor[java.lang.Integer] = {
  val shape = Array(vector.size.toLong)
  val buf = IntBuffer.wrap(vector.toArray)
  Tensor.create(shape, buf)
}
def toTensor2(vector: Vector[Vector[Int]]): Tensor[java.lang.Integer] = {
  val shapeX = vector.size.toLong
  val shapeY = vector(0).size.toLong
  val shape = Array(shapeX, shapeY)
  val buf = IntBuffer.wrap(vector.flatten.toArray)
  Tensor.create(shape, buf)
}
def toTensor3(vector: Vector[Vector[Vector[Int]]]): Tensor[java.lang.Integer] = {
  val shapeX = vector.size.toLong
  val shapeY = vector(0).size.toLong
  val shapeZ = vector(0)(0).size.toLong
  val shape = Array(shapeX, shapeY, shapeZ)
  val buf = IntBuffer.wrap(vector.flatMap(_.flatten).toArray)
  Tensor.create(shape, buf)
}

val wordTensor = toTensor2(word)
val wordLengthTensor = toTensor1(wordLength)
val charTensor = toTensor3(char)
val charLengthTensor = toTensor2(charLength)

There are two things to be noted above:

  1. the way TensorFlow creates a tensor in Java is by asking for an Array[Long] for the shape of the tensor as well as a "flattened" view of the actual collection that carries the data.
  2. there is quite a bit of repeated code to build these tensors.

As a functional / type-level programmer, code duplication is something I find frustrating yet quite intriguing at the same time, because it hints at the possibility of a rewrite that eliminates the boilerplate with higher-order functions and types. After thinking about this, I created the following data structure,

sealed trait Rec[A]
object Rec {
  private case class Node[A](a: A) extends Rec[A]
  private case class Cons[A](a: Rec[A], as: Vector[Rec[A]]) extends Rec[A]
  
  def node[A](a: A): Rec[A] = Node(a)
  def cons[A](hd: Rec[A], tl: Rec[A]*): Rec[A] = Cons(hd, tl.toVector)
  def unsafe[A](vector: Vector[Rec[A]]): Rec[A] = Cons(vector.head, vector.tail)
}
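
For instance, the nested vector Vector(Vector(1, 2), Vector(3, 4)) can be built by hand as,

// each leaf becomes a Node, each nesting level becomes a Cons
val example: Rec[Int] =
  Rec.cons(
    Rec.cons(Rec.node(1), Rec.node(2)),
    Rec.cons(Rec.node(3), Rec.node(4))
  )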

Rec[A] is sort of like the Fix type, which looks somewhat like the following,

case class Fix[F[_]](unfix: F[Fix[F]])

The difference is that the type F[_] in my case is already known to be Vector[_]. Rec[_] also flattens the co-recursion between the inner and outer type constructors, the reason for which will become obvious in a moment. It's time to convert the vectors into Rec,

val wordR = Rec.unsafe(word.map(xs => Rec.unsafe(xs.map(Rec.node))))
val wordLengthR = Rec.unsafe(wordLength.map(Rec.node))
val charR = Rec.unsafe(char.map(xs => Rec.unsafe(xs.map(xxs => Rec.unsafe(xxs.map(Rec.node))))))
val charLengthR = Rec.unsafe(charLength.map(xs => Rec.unsafe(xs.map(Rec.node))))

Now remember the goal is to have a unified function that can convert a Vector[_] of any depth to a TensorFlow tensor. I chose to be a little awesome here by creating implicits to make the usage more convenient:

// this will be put in the companion object of Rec, so the private child classes can be accessed.
implicit final class Ops[A](val rec: Rec[A]) extends AnyVal {
  def shape: Array[Long] = {
    import scala.collection.mutable.ArrayBuffer
    @scala.annotation.tailrec
    def go(rec: Rec[A], buf: ArrayBuffer[Long] = ArrayBuffer.empty): Array[Long] = 
      rec match {
        case Node(_) => buf.toArray
        case Cons(x, xs) => go(x, buf += (1 + xs.size).toLong) 
      }
    go(rec)
  }
  // keep in mind our implementation must be stack safe!
  def flatten(implicit C : scala.reflect.ClassTag[A]): Array[A] = {
    import scala.util.control.TailCalls._
    def go(rec: Rec[A]): TailRec[Vector[A]] = 
      rec match {
        case Node(a) => done(Vector(a))
        case Cons(a, as) =>
          as.foldLeft(go(a))( (x, y) =>
            for {
              xv <- x
              yv <- go(y)
            } yield xv ++ yv
          )
      }
    go(rec).result.toArray
  }
  
  def toIntTensor(implicit N : Numeric[A], C : scala.reflect.ClassTag[A]): Tensor[java.lang.Integer] = {
    val buf = IntBuffer.wrap(flatten.map(N.toInt))
    Tensor.create(shape, buf)
  }
}

Now inputs can be converted to tensors,

val wordTensor = wordR.toIntTensor
val wordLengthTensor = wordLengthR.toIntTensor
val charTensor = charR.toIntTensor
val charLengthTensor = charLengthR.toIntTensor

As promised - one function to handle the conversion for all cases. Keep in mind this is only possible because all inputs, after conversion, are represented by a unified type Rec[_]; therefore a higher-order function can accept all of them and apply pattern matching to uncover their internal structure. That said, this brings another problem... Can you spot what it is?

Look at the types of wordR and charR, for example: is there any difference between them on the surface of their types? There is not, as they both have the same type, Rec[Int]! If I happen to switch wordR and charR as inputs, the compiler won't be able to tell me it's wrong, because they indeed have the same type. I need to find a way to make them look slightly different from the perspective of the type system, but not so different that they can no longer be unified under the umbrella of a more generic type.
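
To make the problem concrete, consider a hypothetical consumer of the word indices,

// feedWord is a hypothetical helper that expects the word indices,
// yet it happily accepts the char indices as well
def feedWord(r: Rec[Int]): Unit = ???

feedWord(wordR) // intended
feedWord(charR) // compiles just as well, but is semantically wrong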

My final attempt revolves around encoding Rec[_] with type-level natural numbers. shapeless provides an excellent implementation of them, so I just shamelessly copied the relevant part into my code,

sealed trait Nat { type N <: Nat }
final case class Succ[P <: Nat]() extends Nat { type N = Succ[P] }
final class Zero extends Nat with Serializable { type N = Zero }
object Nat {
  type _0 = Zero
  type _1 = Succ[_0]
  type _2 = Succ[_1]
  type _3 = Succ[_2]
}

Rec is now redefined as the following,

sealed trait Rec[A, N <: Nat]
object Rec {
  private case class Node[A](a: A) extends Rec[A, Zero]
  private case class Cons[A, N <: Nat](a: Rec[A, N], as: Vector[Rec[A, N]]) extends Rec[A, Succ[N]]
  
  def node[A](a: A): Rec[A, Zero] = Node(a)
  def cons[A, N <: Nat](a: Rec[A, N], as: Rec[A, N]*): Rec[A, Succ[N]] = Cons(a, as.toVector)
  def unsafe[A, N <: Nat](vector: Vector[Rec[A, N]]): Rec[A, Succ[N]] = Cons(vector.head, vector.tail)
}

The added type parameter N <: Nat encodes the depth of self-recursion of Rec. For example, Rec[Int, _2] has a self-recursion depth of 2, which is consistent with what _2 suggests. It is now impossible to mix or switch vectors of different depths, since they are different types. The way to convert the inputs to the refined Rec is identical to before,

val wordR = Rec.unsafe(word.map(xs => Rec.unsafe(xs.map(Rec.node))))
val wordLengthR = Rec.unsafe(wordLength.map(Rec.node))
val charR = Rec.unsafe(char.map(xs => Rec.unsafe(xs.map(xxs => Rec.unsafe(xxs.map(Rec.node))))))
val charLengthR = Rec.unsafe(charLength.map(xs => Rec.unsafe(xs.map(Rec.node))))
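
With the refined Rec, the mix-up from before is rejected at compile time; here is the same hypothetical feedWord, now written against the refined type,

// wordR now has type Rec[Int, Nat._2] while charR has type Rec[Int, Nat._3]
def feedWord(r: Rec[Int, Nat._2]): Unit = ???

feedWord(wordR)    // compiles
// feedWord(charR) // does not compile: Rec[Int, Nat._3] is not a Rec[Int, Nat._2]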

To compute the shape and the flattened view of the refined Rec, it is no longer possible to use a single regular function like I did previously, since the type parameter N now changes at every level of the recursion. Instead, I need type classes,

trait Shape[N <: Nat] {
  def apply[A](rec: Rec[A, N]): Array[Long]
}
object Shape {
  implicit val zero : Shape[Zero] = 
    new Shape[Zero] { def apply[A](rec: Rec[A, Zero]): Array[Long] = Array.empty }
  implicit def succ[N <: Nat](implicit prev: Shape[N]): Shape[Succ[N]] = 
    new Shape[Succ[N]] {
      def apply[A](rec: Rec[A, Succ[N]]): Array[Long] = 
        rec match {
          case Cons(x, xs) =>
            val dim = (1 + xs.size).toLong
            val dims = prev(x)
            dim +: dims
        }
    }
}

import scala.util.control.TailCalls._
trait Flatten[N <: Nat] {
  def apply[A](rec: Rec[A, N]): TailRec[Vector[A]]
}
object Flatten {
  implicit val zero: Flatten[Zero] = 
    new Flatten[Zero] {
      def apply[A](rec: Rec[A, Zero]): TailRec[Vector[A]] = 
        rec match {
          case Node(a) => done(Vector(a))
        }
    }
  implicit def succ[N <: Nat](implicit prev: Flatten[N]): Flatten[Succ[N]] = 
    new Flatten[Succ[N]] {
      def apply[A](rec: Rec[A, Succ[N]]): TailRec[Vector[A]] = 
        rec match {
          case Cons(a, as) =>
            as.foldLeft(prev(a))( (x, y) =>
              for {
                xv <- x
                yv <- prev(y)
              } yield xv ++ yv
            )
        }
    }
}
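
To see the induction at work, here is a sketch of how the instances resolve for the depth-2 wordR; Shape[Nat._2] is assembled as succ(succ(zero)), and likewise for Flatten,

// the compiler derives the instances by peeling off one Succ at a time
val wordShape: Array[Long] = implicitly[Shape[Nat._2]].apply(wordR)
val wordFlat: Vector[Int] = implicitly[Flatten[Nat._2]].apply(wordR).result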

The above type classes are quite similar to what shapeless provides for a vast number of operations on HList and Coproduct. The difference is that my type classes work on the type recursion and type-level natural numbers.

With the type classes, the convenience implicits for converting a vector to a tensor are quite simple. In fact, the function toIntTensor is almost the same as the previous implementation.

import scala.reflect.ClassTag

implicit final class Ops[A, N <: Nat](val rec: Rec[A, N]) extends AnyVal {
  def shape(implicit S : Shape[N]): Array[Long] = S(rec)
  def flatten(implicit F : Flatten[N], C : ClassTag[A]): Array[A] = F(rec).result.toArray
  def toIntTensor(implicit N : Numeric[A], S : Shape[N], F : Flatten[N], C : ClassTag[A]): Tensor[java.lang.Integer] = {
    val buf = IntBuffer.wrap(flatten.map(N.toInt))
    Tensor.create(shape, buf)
  }
}

The way of converting inputs to tensors stays the same,

val wordTensor = wordR.toIntTensor
val wordLengthTensor = wordLengthR.toIntTensor
val charTensor = charR.toIntTensor
val charLengthTensor = charLengthR.toIntTensor

Does the rewrite actually reduce the sheer amount of code? Not really... in my case the number of LOC actually increased. However, the art of reducing boilerplate, in my opinion, is not about counting the LOC before and after the refactoring. Rather, it is about relieving the psychological burden by replacing the tedious copy/paste of code with one or a few simple rules, which makes the program simpler to reason about and easier to maintain. To that end, I am quite happy with the work I had to put in to make this happen.
