Created February 20, 2019 23:22
-
-
Save DirkToewe/4f5a5cf5d0c65a2e49059713bc0e6206 to your computer and use it in GitHub Desktop.
Tensorflow4Scala for-comprehension experiments
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package tf4s_experiments | |
import org.platanios.tensorflow.api._ | |
import org.platanios.tensorflow.api.core.client.FeedMap | |
object ForComprehension_experiments
{
  /** Adds `map`, `withFilter` and `flatMap` to a symbolic [[Output]], so that a Scala
    * `for`-comprehension over a tensor is built as graph-side `tf.whileLoop`s that
    * iterate over the tensor's first (outermost) dimension.
    *
    * Every method unstacks `tensor` along dimension 0 via `TensorArray.unstack`, so
    * `tensor` is expected to have rank >= 1.
    *
    * NOTE(review): every result is assembled with `TensorArray.stack()` on an array
    * created without a static element shape. TensorFlow refuses to stack a zero-size
    * TensorArray whose element shape is not fully defined, so an empty input to `map`,
    * a `withFilter` that rejects every slice, or a `flatMap` that yields no slices is
    * likely to fail at run time — confirm before relying on those cases.
    */
  implicit class OutputOps[A: TF]( val tensor: Output[A] )
  {
    /** Applies `map_fn` to every slice `tensor(i)` along dimension 0 and stacks the
      * results, in order, into a new tensor.
      *
      * @param map_fn graph-building function, applied once per slice inside the loop body
      * @return the stacked per-slice results (presumably every `map_fn` result must
      *         have the same shape for the stack to succeed — confirm)
      */
    def map[B: TF]( map_fn: Output[A] => Output[B] ): Output[B] =
    {
      val len: Output[Int] = (tf shape tensor) apply 0
      val input = TensorArray.create[A](len) unstack tensor

      // Loop state: (results written so far, index of the next slice to process).
      val (results, _) = tf.whileLoop[(TensorArray[B], Output[Int]), (Shape, Shape)](
        predicateFn = _._2 < len,
        bodyFn = {
          case (acc, i) => (acc write (i, map_fn(input read i)), i + 1)
        },
        loopVariables = (TensorArray.create[B](len), 0)
      )
      results.stack()
    }

    /** Keeps the slices `tensor(i)` along dimension 0 for which `filter_fn` yields
      * `true`, and stacks the survivors in their original order.
      *
      * The survivors go into a dynamically sized TensorArray, because their count is
      * only known when the graph runs.
      *
      * @param filter_fn graph-building predicate, evaluated once per slice and used as
      *                  the condition of `tf.cond` (presumably must be a scalar — confirm)
      * @return the stacked surviving slices
      */
    def withFilter( filter_fn: Output[A] => Output[Boolean] ): Output[A] =
    {
      val len: Output[Int] = (tf shape tensor) apply 0
      val input = TensorArray.create[A](len) unstack tensor

      // Loop state: (survivors so far, next input index, next output index).
      val (kept, _, _) = tf.whileLoop[(TensorArray[A], Output[Int], Output[Int]), (Shape, Shape, Shape)](
        predicateFn = _._2 < len,
        bodyFn = {
          case (acc, i, j) =>
            val slice = input read i
            tf.cond(
              filter_fn(slice),
              () => (acc write (j, slice), i + 1, j + 1),
              () => (acc, i + 1, j)
            )
        },
        loopVariables = (TensorArray.create[A](0, dynamicSize = true), 0, 0)
      )
      kept.stack()
    }

    /** Applies `flatMap_fn` to every slice `tensor(i)` along dimension 0, then
      * concatenates the dimension-0 slices of all returned tensors, in order, into one
      * stacked result.
      *
      * This is built as two nested while-loops. The outer loop walks the input slices.
      * The inner loop copies the slices of each returned tensor into a single,
      * dynamically sized output TensorArray.
      *
      * @param flatMap_fn graph-building function; each result is unstacked along its
      *                   own dimension 0, so it is expected to have rank >= 1
      * @return the stacked concatenation of all inner slices
      */
    def flatMap[B: TF]( flatMap_fn: Output[A] => Output[B] ): Output[B] =
    {
      val outerLen: Output[Int] = (tf shape tensor) apply 0
      val input = TensorArray.create[A](outerLen) unstack tensor

      // Outer loop state: (output so far, next input index, next output index).
      val (flattened, _, _) = tf.whileLoop[(TensorArray[B], Output[Int], Output[Int]), (Shape, Shape, Shape)](
        predicateFn = _._2 < outerLen,
        bodyFn = {
          case (acc, i, outPos) =>
            val innerTensor = flatMap_fn(input read i)
            val innerLen = (tf shape innerTensor) apply 0
            val inner = TensorArray.create[B](innerLen) unstack innerTensor

            // Inner loop state: (output so far, next output index, next inner index).
            // It continues writing into the outer accumulator at `outPos`.
            val (innerAcc, nextOutPos, _) = tf.whileLoop[(TensorArray[B], Output[Int], Output[Int]), (Shape, Shape, Shape)](
              predicateFn = _._3 < innerLen,
              bodyFn = {
                case (out, pos, k) => (out write (pos, inner read k), pos + 1, k + 1)
              },
              loopVariables = (acc, outPos, 0)
            )
            (innerAcc, i + 1, nextOutPos)
        },
        loopVariables = (TensorArray.create[B](0, dynamicSize = true), 0, 0)
      )
      flattened.stack()
    }
  }

  /** Demo entry point. It keeps the rows of a 3x3 matrix whose maximum exceeds 3,
    * flattens those rows, doubles every element and prints the result.
    *
    * Per the Scala for-comprehension rules, the comprehension below desugars to
    * `in.withFilter(row => row.max() > 3).flatMap(row => row.map(x => 2*x))`, which
    * resolves to the `OutputOps` methods defined in this object.
    * With the feed below, rows (4,5,6) and (7,8,9) pass the filter, so the printed
    * tensor should be [8, 10, 12, 14, 16, 18].
    */
  def main( args: Array[String] ): Unit =
  {
    val in = tf.placeholder[Double]()
    val out = for( row <- in;
                   if row.max() > 3;
                   x <- row )
      yield 2*x

    val sess = Session()
    try {
      val a = Tensor[Double](
        Array(1,2,3),
        Array(4,5,6),
        Array(7,8,9)
      )
      val Seq(b) = sess.run( FeedMap( in -> a ), Seq(out) )
      println( b.summarize() )
    } finally {
      // Release the native session even when graph construction or execution fails.
      sess.close()
    }
  }
}
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.