Created
February 21, 2017 08:07
-
-
Save tomthetrainer/c1ee7fe0908f129ee1eb65dee882f411 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.deeplearning4j.VGGwebDemo; | |
import org.nd4j.linalg.factory.Nd4j; | |
import org.deeplearning4j.nn.modelimport.keras.trainedmodels.Utils.ImageNetLabels; | |
import org.datavec.image.loader.NativeImageLoader; | |
import org.deeplearning4j.nn.graph.ComputationGraph; | |
import org.deeplearning4j.nn.modelimport.keras.trainedmodels.TrainedModels; | |
import org.deeplearning4j.util.ModelSerializer; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization; | |
import org.nd4j.linalg.dataset.api.preprocessor.VGG16ImagePreProcessor; | |
import javax.servlet.MultipartConfigElement; | |
import java.io.File; | |
import java.io.InputStream; | |
import java.nio.file.Files; | |
import java.nio.file.Path; | |
import java.nio.file.StandardCopyOption; | |
import static spark.Spark.options; | |
import static spark.Spark.get; | |
import static spark.Spark.post; | |
import static spark.Spark.staticFiles; | |
import java.util.Collections; | |
import java.util.Map; | |
import java.util.TreeMap; | |
import org.nd4j.shade.jackson.databind.ObjectMapper; | |
/** | |
* Created by tomhanlon on 1/25/17. | |
*/ | |
public class VGG16SparkJavaWebApp { | |
public static void main(String[] args) throws Exception { | |
File locationToSave = new File("vgg16.zip"); | |
ComputationGraph vgg16 = ModelSerializer.restoreComputationGraph(locationToSave); | |
// make upload directory | |
File uploadDir = new File("upload"); | |
uploadDir.mkdir(); // create the upload directory if it doesn't exist | |
// form | |
String form = "<form method='post' action='getPredictions' enctype='multipart/form-data'>\n" + | |
" <input type='file' name='uploaded_file'>\n" + | |
" <button>Upload picture</button>\n" + | |
"</form>"; | |
staticFiles.location("/Users/tomhanlon/SkyMind/webcontent");// Static files | |
//CorsFilter.apply(); | |
//options("/", (req, res) -> { | |
//Appease something | |
// }); | |
options("/*", (req, res) -> "Hello World"); | |
get("/hello", (req, res) -> "Hello World"); | |
get("VGGpredict", (req, res) -> form); | |
//post("getPredictions",(req, res) -> "GET RESULTS"); | |
post("/getPredictions", (req, res) -> { | |
Path tempFile = Files.createTempFile(uploadDir.toPath(), "", ""); | |
req.attribute("org.eclipse.jetty.multipartConfig", new MultipartConfigElement("/temp")); | |
try (InputStream input = req.raw().getPart("uploaded_file").getInputStream()) { // getPart needs to use same "name" as input field in form | |
Files.copy(input, tempFile, StandardCopyOption.REPLACE_EXISTING); | |
} | |
//logInfo(req, tempFile); | |
//return "<h1>You uploaded this image:<h1><img src='" + tempFile.getFileName() + "'>"; | |
File file = tempFile.toFile(); | |
//File file = new File(path); | |
NativeImageLoader loader = new NativeImageLoader(224, 224, 3); | |
INDArray image = loader.asMatrix(file); | |
file.delete(); | |
DataNormalization scaler = new VGG16ImagePreProcessor(); | |
scaler.transform(image); | |
//System.out.print(image); | |
INDArray[] output = vgg16.output(false,image); | |
// sort to get top 5 | |
INDArray[] sorted = Nd4j.sortWithIndices(output[0], 1, false); | |
// sorted map for results | |
//Map<Float, String> map = new TreeMap<Float, String>(Collections.reverseOrder()); | |
//VGGResults vggResults = new VGGResults(label,pred); | |
VGGResults[] vggResultsArray = new VGGResults[5]; | |
// Get top 5 | |
for (int i = 0; i < 5; i++) { | |
// Get prediction percent | |
Float prediction = sorted[1].getFloat(i) * 100; | |
// extract label for prediction | |
String Label = ImageNetLabels.getLabel(sorted[0].getInt(i)); | |
// put both in Result array | |
vggResultsArray[i] = new VGGResults(Label, prediction); | |
} | |
// Jackson obect mapper | |
// ##### I AM HERE ##### | |
ObjectMapper mapper = new ObjectMapper(); | |
String predictions = mapper.writeValueAsString(vggResultsArray); | |
//String predictions = mapper.writeValueAsString(map); | |
String predictionmunge = "{" + | |
"\"data\":" + predictions + "}"; | |
// return "<h4> '" + predictions + "' </h4>" + | |
// "Would you like to try another" + | |
// form; | |
return predictionmunge ; | |
//return "<h1>Your image is: '" + tempFile.getName(1).toString() + "' </h1>"; | |
}); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment