Skip to content

Instantly share code, notes, and snippets.

@nikitakit
Created July 8, 2016 01:35
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nikitakit/d3ec270aee9d930267cec3efa844d5aa to your computer and use it in GitHub Desktop.
Opening TensorFlow models in Scala
sess = tf.InteractiveSession()
# [graph creation here]

# Write the session's graph as a binary-serialized GraphDef protobuf to
# MODEL_DIR/model.pb, so the Scala side can read the bytes back and parse
# them into a GraphDef.
# NOTE(review): a GraphDef describes graph structure; trained variable values
# are presumably not included unless the graph is frozen first -- confirm.
with open(os.path.join(MODEL_DIR, "model.pb"), "wb") as f:
    f.write(sess.graph_def.SerializeToString())
package epic.tensorflow
import org.bytedeco.javacpp.{tensorflow, BytePointer}
import org.bytedeco.javacpp.tensorflow.GraphDef
/**
 * Created by kitaev on 1/22/16.
 *
 * Loads a binary-serialized TensorFlow `GraphDef` from `filename` and appends
 * its nodes to the shared session in [[TensorflowSession]] via `Extend`.
 *
 * All of the work happens in the constructor. The intermediate values (parsed
 * graph, raw bytes, native pointer, parse result, extend status) remain
 * available as public fields.
 *
 * @param filename path to a file holding a serialized `GraphDef`
 * @throws java.io.IOException if the file cannot be read
 * @throws Exception if the bytes do not parse as a `GraphDef`, or if the
 *                   session rejects the graph (the message is then the
 *                   TensorFlow status text)
 */
class TensorflowModel(filename: String) {
  /** Target protobuf that the serialized bytes are parsed into. */
  val graphdef: GraphDef = new GraphDef()

  /** Entire file contents, read in a single call. */
  val graphdef_data: Array[Byte] = {
    val path = java.nio.file.Paths.get(filename)
    java.nio.file.Files.readAllBytes(path)
  }

  /** JavaCPP pointer built from `graphdef_data`, handed to the native parser. */
  val graphdef_data_ptr: BytePointer = {
    val heapBuffer = java.nio.ByteBuffer.wrap(graphdef_data)
    new BytePointer(heapBuffer)
  }

  /** Whether the native protobuf parser accepted the bytes. */
  val import_ok: Boolean = tensorflow.ParseProtoUnlimited(graphdef, graphdef_data_ptr)
  if (!import_ok) throw new Exception(s"Graphdef import failed from: $filename")

  // TODO(nikita): figure out how versioning is supposed to work
  graphdef.set_version(0)

  /** Status returned by the shared session when asked to add this graph. */
  val s = TensorflowSession.sess.Extend(graphdef)
  if (!s.ok()) throw new Exception(s"${s.ToString().getString()}")
}
package epic.tensorflow
import org.bytedeco.javacpp.tensorflow
import org.bytedeco.javacpp.tensorflow._
/**
 * Created by kitaev on 1/22/16.
 *
 * Holds the single TensorFlow session shared by the whole process.
 *
 * The session is created over an empty `GraphDef` when this object is first
 * initialized. [[TensorflowModel]] instances later add their nodes to it with
 * `sess.Extend`.
 *
 * Initialization throws (surfacing to the caller wrapped in an
 * `ExceptionInInitializerError`) if TensorFlow reports that session creation
 * failed.
 */
object TensorflowSession {
  /** The shared session instance. */
  val sess = new tensorflow.Session(new SessionOptions())

  // Empty graph, used only to put the session into a created state; real
  // graphs are appended afterwards through `Extend`.
  private val graphdef = new tensorflow.GraphDef()

  // `Create` reports failure through its returned Status rather than by
  // throwing. Check it here so a failure is not silently ignored and only
  // discovered later as an unrelated-looking `Extend` error.
  private val createStatus = sess.Create(graphdef)
  if (!createStatus.ok()) {
    throw new Exception(
      s"TensorFlow session creation failed: ${createStatus.ToString().getString()}")
  }

  // [project-specific code redacted]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment