@piyo7
Last active February 20, 2017 08:08
Calling the TensorFlow Java API from Scala. ref: http://qiita.com/piyo7/items/d897d7156d87d29cad19
name := "tensorflow-scala"
scalaVersion := "2.12.1"
$ tree
.
├── build.sbt
├── jni
│   └── libtensorflow_jni.dylib
├── lib
│   └── libtensorflow-1.0.0-PREVIEW1.jar
└── src
    └── main
        └── scala
            └── Main.scala
$ sbt run -Djava.library.path=./jni
...
4, 10, 18
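The run command above passes `java.library.path` on the command line each time. As a possible alternative, the same setting could live in `build.sbt` by running the program in a forked JVM. This is only a sketch and not part of the original gist; it assumes the sbt 0.13-style `in` syntax (newer sbt versions write `run / fork` and `run / javaOptions` instead) and that the native library stays in `./jni`. The jar under `lib/` needs no extra setting, since sbt treats that directory as unmanaged dependencies by default.

```scala
// build.sbt (sketch, assumes sbt 0.13-style syntax)
name := "tensorflow-scala"

scalaVersion := "2.12.1"

// Run Main in a separate JVM so that javaOptions below take effect.
fork in run := true

// Point the forked JVM at the directory holding libtensorflow_jni.dylib.
javaOptions in run += "-Djava.library.path=./jni"
```

With these settings a plain `sbt run` should be enough to load the JNI library.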
import org.tensorflow._

object Main extends App {
  val graph = new Graph()

  // Constant node "a" holding the INT32 vector [1, 2, 3].
  val a = graph.opBuilder("Const", "a").
    setAttr("dtype", DataType.INT32).
    setAttr("value", Tensor.create(Array(1, 2, 3))).
    build().
    output(0)

  // Constant node "b" holding the INT32 vector [4, 5, 6].
  val b = graph.opBuilder("Const", "b").
    setAttr("dtype", DataType.INT32).
    setAttr("value", Tensor.create(Array(4, 5, 6))).
    build().
    output(0)

  // Element-wise product node "c" = a * b.
  val c = graph.opBuilder("Mul", "c").
    addInput(a).
    addInput(b).
    build().
    output(0)

  // Evaluate "c" in a session and copy the result into a Scala array.
  val session = new Session(graph)
  val out = new Array[Int](3)
  session.runner().fetch("c").run().get(0).copyTo(out)

  println(out.mkString(", ")) // 4, 10, 18
}