Skip to content

Instantly share code, notes, and snippets.

@DirkToewe
Created February 20, 2019 23:22
Show Gist options
  • Save DirkToewe/4f5a5cf5d0c65a2e49059713bc0e6206 to your computer and use it in GitHub Desktop.
Save DirkToewe/4f5a5cf5d0c65a2e49059713bc0e6206 to your computer and use it in GitHub Desktop.
Tensorflow4Scala for-comprehension experiments
package tf4s_experiments
import org.platanios.tensorflow.api._
import org.platanios.tensorflow.api.core.client.FeedMap
/** Experiments that let TensorFlow-for-Scala `Output`s take part in Scala
  * for-comprehensions.
  *
  * Scala desugars `for (a <- xs; if p(a); b <- f(a)) yield g(b)` into calls to
  * `withFilter`, `flatMap` and `map`; [[OutputOps]] supplies those three
  * methods for any `Output[A]`, each built as a symbolic `tf.whileLoop` over
  * dimension 0 of the wrapped tensor.
  */
object ForComprehension_experiments
{
  /** Adds `map`, `withFilter` and `flatMap` to an `Output[A]`.
    *
    * Every combinator treats the wrapped tensor as a sequence of slices along
    * dimension 0 and only builds graph ops; nothing is evaluated until the
    * graph is run in a `Session` (as `main` does).
    *
    * @param tensor the symbolic tensor whose dimension-0 slices are iterated.
    * @tparam A element type of `tensor`.
    */
  implicit class OutputOps[A: TF]( val tensor: Output[A] )
  {
    /** Splits `t` along dimension 0.
      *
      * @return the symbolic length of dimension 0 of `t`, and a `TensorArray`
      *         holding one entry per dimension-0 slice of `t`.
      */
    private def unstackFirstDim[T: TF]( t: Output[T] ): (Output[Int], TensorArray[T]) =
    {
      val len: Output[Int] = (tf shape t) apply 0
      (len, TensorArray.create[T](len) unstack t)
    }

    /** Applies `map_fn` to every dimension-0 slice of `tensor` and stacks the
      * results along a new dimension 0.
      *
      * @param map_fn graph-building function applied to each slice.
      * @tparam B element type produced by `map_fn`.
      * @return one `map_fn` result per input slice, stacked along dimension 0.
      */
    def map[B: TF]( map_fn: Output[A] => Output[B] ): Output[B] =
    {
      val (len, input) = unstackFirstDim(tensor)
      // Loop state: (results written so far, current index i).
      tf.whileLoop[(TensorArray[B], Output[Int]), (Shape, Shape)](
        predicateFn = _._2 < len,
        bodyFn = {
          case (res, i) => (
            res write ( i, map_fn(input read i) ),
            i+1
          )
        },
        // The Int literal must become an Output[Int], the declared loop-variable type.
        loopVariables = ( TensorArray.create[B](len), 0 )
      ) match {
        case (result, _) => result.stack()
      }
    }

    /** Keeps only the dimension-0 slices of `tensor` for which `filter_fn`
      * yields `true`, stacked in their original order.
      *
      * The number of kept slices is data-dependent, so the result
      * `TensorArray` starts at size 0 and is created with dynamic sizing.
      *
      * NOTE(review): if no slice passes, `stack()` runs on a zero-size
      * TensorArray created without an element shape; TensorFlow may refuse to
      * stack that — confirm, and pass an element shape if needed.
      *
      * @param filter_fn graph-building predicate evaluated on each slice.
      * @return the selected slices stacked along dimension 0.
      */
    def withFilter( filter_fn: Output[A] => Output[Boolean] ): Output[A] =
    {
      val (len, input) = unstackFirstDim(tensor)
      // Loop state: (kept slices, read index i, write index j).
      tf.whileLoop[(TensorArray[A], Output[Int], Output[Int]), (Shape, Shape, Shape)](
        predicateFn = _._2 < len,
        bodyFn = {
          case (res, i, j) =>
            val input_i = input read i
            tf.cond(
              filter_fn(input_i),
              () => (res write (j, input_i), i+1, j+1),
              () => (res                   , i+1, j  )
            )
        },
        loopVariables = ( TensorArray.create[A](0, /* dynamicSize = */ true), 0, 0 )
      ) match {
        case (result, _, _) => result.stack()
      }
    }

    /** Applies `flatMap_fn` to every dimension-0 slice of `tensor` and
      * concatenates the dimension-0 slices of all returned tensors, in order.
      *
      * NOTE(review): as in `withFilter`, stacking an empty result may fail
      * because no element shape is given — confirm.
      *
      * @param flatMap_fn graph-building function mapping one slice to a tensor
      *                   whose dimension-0 slices are emitted.
      * @tparam B element type produced by `flatMap_fn`.
      * @return all emitted slices stacked along dimension 0.
      */
    def flatMap[B: TF]( flatMap_fn: Output[A] => Output[B] ): Output[B] =
    {
      val (outerLen, input) = unstackFirstDim(tensor)
      // Outer loop state: (emitted slices, outer read index i, write index j).
      tf.whileLoop[(TensorArray[B], Output[Int], Output[Int]), (Shape, Shape, Shape)](
        predicateFn = _._2 < outerLen,
        bodyFn = {
          case (res, i, j) =>
            val (innerLen, inner) = unstackFirstDim(flatMap_fn(input read i))
            // Inner loop state: (emitted slices, write index, inner read index k).
            // Appends every slice of the inner tensor to `res`, starting at write index `j`.
            tf.whileLoop[(TensorArray[B], Output[Int], Output[Int]), (Shape, Shape, Shape)](
              predicateFn = _._3 < innerLen,
              bodyFn = {
                case (acc, writeIdx, k) => (
                  acc write (writeIdx, inner read k),
                  writeIdx+1,
                  k+1
                )
              },
              loopVariables = (res, j, 0)
            ) match {
              case (acc, nextJ, _) => (acc, i+1, nextJ)
            }
        },
        loopVariables = ( TensorArray.create[B](0, /* dynamicSize = */ true), 0, 0 )
      ) match {
        case (result, _, _) => result.stack()
      }
    }
  }

  /** Demo: for each row of a placeholder matrix whose maximum exceeds 3, emit
    * `2 * x` for every element `x` of that row, then run it on a 3x3 example.
    *
    * For the example matrix below only rows 2 and 3 pass the filter, so the
    * result should contain 8, 10, 12, 14, 16, 18.
    */
  def main( args: Array[String] ): Unit =
  {
    val in = tf.placeholder[Double]()
    // Desugars to: in.withFilter(row => row.max() > 3).flatMap(row => row.map(x => 2*x))
    val out = for( row <- in;
                   if row.max() > 3;
                   x <- row )
              yield 2*x
    val sess = Session()
    try {
      // Double literals so each row is an Array[Double], matching Tensor[Double].
      val a = Tensor[Double](
        Array(1.0, 2.0, 3.0),
        Array(4.0, 5.0, 6.0),
        Array(7.0, 8.0, 9.0)
      )
      val Seq(b) = sess.run( FeedMap( in -> a ), Seq(out) )
      println( b.summarize() )
    } finally {
      sess.close()
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment