Gist: lukasjapan/2af23822e3bd5c20d64870c1cc38aca2
using-deeplearning4j-to-distinguish-between-cats-and-dogs.3.kt
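import org.datavec.image.loader.NativeImageLoader
import org.deeplearning4j.nn.graph.ComputationGraph
import org.deeplearning4j.util.ModelSerializer
import org.deeplearning4j.zoo.PretrainedType
import org.deeplearning4j.zoo.model.VGG16
import org.nd4j.linalg.dataset.api.preprocessor.VGG16ImagePreProcessor
import java.io.FileInputStream

// load the VGG16 model with pretrained ImageNet weights from the DL4J model zoo
val vgg16model = VGG16().initPretrained(PretrainedType.IMAGENET) as ComputationGraph
// restore the cat/dog model that was trained on VGG16 outputs and saved to a file
val catsdogsModel = ModelSerializer.restoreMultiLayerNetwork(javaClass.getResource("/catdogmodel.dl4j").openStream())
// load the image as a 224x224 RGB matrix, apply VGG16's mean-subtraction preprocessing, close the stream
val input = FileInputStream("input.png").use { image ->
    NativeImageLoader(224, 224, 3).asMatrix(image).also { VGG16ImagePreProcessor().transform(it) }
}
// run VGG16 to get its 1000 ImageNet class probabilities
val vgg16labels = vgg16model.outputSingle(input)
// feed those probabilities into the cat/dog model to get the prediction values
val output = catsdogsModel.output(vgg16labels)
val cat = output.getDouble(0)
val dog = output.getDouble(1)
// predict whichever class scores higher
if (cat > dog) println("cat") else println("dog")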
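The final comparison can be generalized into a small helper that returns both the winning label and its score. This is a minimal sketch in plain Kotlin (no DL4J dependency); `predictLabel` is a hypothetical name, and it assumes, like the snippet, that index 0 is "cat" and index 1 is "dog".

```kotlin
// Hypothetical helper: pick the label with the highest score.
// Assumes scores[i] corresponds to labels[i] (default: 0 = "cat", 1 = "dog").
fun predictLabel(
    scores: DoubleArray,
    labels: List<String> = listOf("cat", "dog")
): Pair<String, Double> {
    require(scores.size == labels.size) { "expected ${labels.size} scores, got ${scores.size}" }
    // index of the largest score; on a tie the first such index wins
    val best = scores.indices.maxByOrNull { scores[it] } ?: error("no scores given")
    return labels[best] to scores[best]
}
```

With the model output above it could be called as `predictLabel(doubleArrayOf(output.getDouble(0), output.getDouble(1)))`. Note that on an exact tie this helper returns "cat", whereas the original `if` prints "dog".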