Skip to content

Instantly share code, notes, and snippets.

@tomthetrainer
Created November 21, 2017 22:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tomthetrainer/7509ab754b58e0a360b872ba9518aaad to your computer and use it in GitHub Desktop.
Save tomthetrainer/7509ab754b58e0a360b872ba9518aaad to your computer and use it in GitHub Desktop.
package ai.skymind.training.demos;
import org.apache.log4j.BasicConfigurator;
import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.trainedmodels.TrainedModels;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.VGG16ImagePreProcessor;
import javax.servlet.MultipartConfigElement;
import java.io.File;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
import static spark.Spark.*;
//import org.nd4j.linalg.dataset.api.preprocessor.
/**
 * Minimal Spark Java web application that classifies an uploaded image with a
 * pre-trained VGG16 network restored from a serialized DL4J model file.
 *
 * <p>Created by tomhanlon on 1/25/17.
 *
 * <p>The model file that you need is available here:
 * https://github.com/tomthetrainer/KerasWorkshop/releases/download/v0.10/vgg16.zip
 * (see also https://github.com/tomthetrainer/KerasWorkshop/releases).
 * It is read from {@code /tmp/vgg16.zip} unless another path is passed as the
 * first command-line argument.
 *
 * <p>Demonstration instructions:
 * <ol>
 *   <li>Start the application; loading the model takes at least 4 minutes.</li>
 *   <li>When loaded you will see jetty activity in the log.</li>
 *   <li>Point a browser at http://localhost:4567/VGGpredict</li>
 *   <li>Load an image into the form.</li>
 * </ol>
 */
public class VGG16SparkJavaWebApp {

    /** Model location used when no path is supplied on the command line. */
    private static final String DEFAULT_MODEL_PATH = "/tmp/vgg16.zip";

    /** Filesystem directory with optional static web content. */
    private static final String STATIC_CONTENT_DIR = "/Users/tomhanlon/SkyMind/webcontent";

    /** Name of the multipart form field; the form and {@code getPart} must agree on it. */
    private static final String UPLOAD_FIELD = "uploaded_file";

    /**
     * Restores the network, registers the HTTP routes and leaves the embedded server running.
     *
     * @param args optional; {@code args[0]} overrides the default model path
     * @throws Exception if the model file is missing or cannot be restored, or if the
     *     upload directory cannot be created
     */
    public static void main(String[] args) throws Exception {
        BasicConfigurator.configure();

        // Load Neural Network from serialized format
        String modelPath = (args != null && args.length > 0) ? args[0] : DEFAULT_MODEL_PATH;
        File savedNetwork = new File(modelPath);
        if (!savedNetwork.isFile()) {
            // Fail with an actionable message instead of a bare deserialization error.
            throw new java.io.FileNotFoundException(
                    "Model file not found: " + savedNetwork.getAbsolutePath()
                            + " (download it from https://github.com/tomthetrainer/KerasWorkshop/releases)");
        }
        ComputationGraph vgg16 = ModelSerializer.restoreComputationGraph(savedNetwork);

        // Directory that temporarily holds uploaded images; createDirectories is a no-op
        // when it already exists and throws when it cannot be created.
        File uploadDir = new File("upload");
        Files.createDirectories(uploadDir.toPath());

        // form to allow user to choose image to upload
        String form = "<form method='post' action='getPredictions' enctype='multipart/form-data'>\n" +
                " <input type='file' name='" + UPLOAD_FIELD + "'>\n" +
                " <button>Upload picture</button>\n" +
                "</form>";

        // Static files must be configured before the first route is mapped. The directory is a
        // filesystem path, so it is registered as an external location (location() resolves
        // classpath folders), and only when it is actually present on this machine.
        if (new File(STATIC_CONTENT_DIR).isDirectory()) {
            staticFiles.externalLocation(STATIC_CONTENT_DIR);
        }

        get("/hello", (req, res) -> "Hello World");
        get("/VGGpredict", (req, res) -> form);

        post("/getPredictions", (req, res) -> {
            // Jetty reads its multipart settings from this request attribute, so it has to be
            // set before the first getPart call. The location is a directory that exists.
            req.attribute("org.eclipse.jetty.multipartConfig",
                    new MultipartConfigElement(System.getProperty("java.io.tmpdir")));

            // getPart needs to use the same "name" as the input field in the form
            if (req.raw().getPart(UPLOAD_FIELD) == null) {
                res.status(400);
                return "<h1>No image was uploaded</h1>" + form;
            }

            Path tempFile = Files.createTempFile(uploadDir.toPath(), "", "");
            try {
                long bytesCopied;
                try (InputStream input = req.raw().getPart(UPLOAD_FIELD).getInputStream()) {
                    bytesCopied = Files.copy(input, tempFile, StandardCopyOption.REPLACE_EXISTING);
                }
                if (bytesCopied == 0) {
                    // Submitting the form without choosing a file yields an empty part.
                    res.status(400);
                    return "<h1>No image was uploaded</h1>" + form;
                }

                // Load the image at the 224x224x3 size passed to the loader
                NativeImageLoader loader = new NativeImageLoader(224, 224, 3);
                INDArray image = loader.asMatrix(tempFile.toFile());

                // Scale image in same manner as network was trained on
                DataNormalization scaler = new VGG16ImagePreProcessor();
                scaler.transform(image);

                // NOTE(review): requests may be handled on several threads; confirm that
                // ComputationGraph.output is safe to call concurrently in this DL4J version.
                INDArray[] output = vgg16.output(false, image);
                String predictions = TrainedModels.VGG16.decodePredictions(output[0]);

                return "<h1> '" + predictions + "' </h1>" +
                        "Would you like to try another" +
                        form;
            } finally {
                // Remove the upload even when loading or inference fails.
                Files.deleteIfExists(tempFile);
            }
        });
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment