Skip to content

Instantly share code, notes, and snippets.

@pavanky
Created February 14, 2020 00:03
Show Gist options
  • Save pavanky/ea6e71e3e7e52c013db844b715723be0 to your computer and use it in GitHub Desktop.
Save pavanky/ea6e71e3e7e52c013db844b715723be0 to your computer and use it in GitHub Desktop.
package com.twitter.sample.lib
import java.io.File
import org.apache.commons.io.{FileUtils, IOUtils}
import org.tensorflow.TensorFlow
import com.twitter.logging.Logger
import java.nio.file.{Files, Paths}
object TfLoader {
  // Used for tests
  // NOTE(review): nothing in this object reads or updates this counter; confirm tests still need it.
  private[lib] var hasBeenLoaded = 0
  private[this] val log = Logger.get()
  private val loader = this.getClass.getClassLoader
  // Directory the bundled native libraries are extracted into before being loaded.
  private val libpath = System.getProperty("user.dir") + "/deepbird_libs"
  private val fullOsName = System.getProperty("os.name")
  // Any OS other than Linux is mapped to "darwin" (the only other native folder used here).
  private val osName = if (fullOsName == "Linux") "linux" else "darwin"

  /**
   * Extracts a native library from the classpath into `libpath` and loads it.
   *
   * Loading the dynamic libraries twice on Linux segfaults, so they must only be loaded once;
   * the `lazy val`s below (`loadMkl`, `loadTf`, `load`) provide that once-only guarantee.
   *
   * @param libname             file name of the library inside `resourcePath`
   * @param resourcePath        classpath directory prefix (including trailing `/`) holding the library
   * @param useTensorflowLoader load via `TensorFlow.loadLibrary` instead of `System.load`
   * @throws java.io.FileNotFoundException if the library is not present on the classpath
   */
  def loadLibrary(
    libname: String,
    resourcePath: String,
    useTensorflowLoader: Boolean = false
  ): Unit = {
    val resource = s"$resourcePath$libname"
    val input = loader.getResourceAsStream(resource)
    // Fail with a descriptive error (instead of an NPE inside IOUtils.copy) and before an
    // empty output file gets created on disk.
    if (input == null) {
      throw new java.io.FileNotFoundException(
        s"Native library resource not found on classpath: $resource")
    }
    val extractedPath = s"$libpath/$libname"
    try {
      // openOutputStream creates missing parent directories of the target file.
      val output = FileUtils.openOutputStream(new File(extractedPath))
      try IOUtils.copy(input, output)
      finally output.close()
    } finally {
      input.close()
    }
    // Both streams are closed here, so the extracted file is complete before it is loaded.
    if (useTensorflowLoader)
      TensorFlow.loadLibrary(extractedPath)
    else
      System.load(extractedPath)
  }

  /**
   * Loads the MKL runtime libraries (Linux only), at most once.
   *
   * The libraries are loaded only if every one of them is present on the classpath;
   * otherwise loading is skipped and logged. On non-Linux systems this does nothing.
   */
  lazy val loadMkl: Unit = {
    val tfArtifactResourcePath = s"org/tensorflow/native/$osName-x86_64/"
    val mklResources = Seq(
      ("libiomp5.so", tfArtifactResourcePath),
      ("libmklml_intel.so", tfArtifactResourcePath))
    if (fullOsName == "Linux") {
      val allPresent = mklResources.forall { case (libname, resourcePath) =>
        loader.getResources(s"$resourcePath$libname").hasMoreElements()
      }
      if (allPresent) {
        mklResources.foreach { case (libname, resourcePath) =>
          loadLibrary(libname, resourcePath)
        }
        log.info("Successfully loaded MKL libraries")
      } else {
        log.info("Skipping loading MKL libraries")
      }
    }
  }

  // libtensorflow_framework now needs to be loaded after loadMkl but before loadTfcontrib
  /** Loads the platform-specific `libtensorflow_framework` library, at most once. */
  lazy val loadTf: Unit = {
    val libraryFile =
      if (osName == "darwin") "libtensorflow_framework.dylib"
      else "libtensorflow_framework.so"
    loadLibrary(libraryFile, s"org/tensorflow/native/$osName-x86_64/")
  }

  /** Loads MKL (when available) followed by the TensorFlow framework library, in that order. */
  lazy val load: Unit = {
    loadMkl
    loadTf
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment