Created
July 26, 2018 02:16
-
-
Save AlexDBlack/7f8ed83fc3660f545abc67983997e894 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.datavec.image.loader.Java2DNativeImageLoader; | |
import org.deeplearning4j.nn.graph.ComputationGraph; | |
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport; | |
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; | |
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp; | |
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization; | |
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler; | |
import org.nd4j.linalg.factory.Nd4j; | |
import javax.imageio.ImageIO; | |
import java.awt.image.BufferedImage; | |
import java.io.File; | |
import java.io.IOException; | |
import java.util.Arrays; | |
import static org.nd4j.linalg.dataset.api.preprocessor.VGG16ImagePreProcessor.VGG_MEAN_OFFSET_BGR; | |
public class Main { | |
private static final INDArray VGG_MEAN_OFFSET_BGR_2 = Nd4j.create(new double[]{103.939, 116.779, 123.68}); | |
public static void main(String[] args) throws Exception { | |
String imagePath = "C:\\DL4J\\Git\\dl4j-test-resources\\target\\classes\\yolo\\VOC_TwoImage\\JPEGImages\\2008_003344.jpg"; | |
String modelPathXception = "C:\\DL4J\\Git\\dl4j-test-resources\\src\\main\\resources\\modelimport\\keras\\examples\\xception\\xception_tf_keras_2.h5"; | |
String modelPathResnet = "C:\\DL4J\\Git\\dl4j-test-resources\\src\\main\\resources\\modelimport\\keras\\examples\\resnet\\resnet50_weights_tf_dim_ordering_tf_kernels.h5"; | |
String errorMessage; | |
ComputationGraph restoredXception; | |
ComputationGraph restoredResnet; | |
File modelFileXception = new File(modelPathXception); | |
File modelFileResnet = new File(modelPathResnet); | |
int tileSize = 256; | |
restoredXception = KerasModelImport.importKerasModelAndWeights(modelFileXception.getAbsolutePath(), new int[]{tileSize, tileSize, 3}, false); | |
restoredResnet = KerasModelImport.importKerasModelAndWeights(modelFileResnet.getAbsolutePath(), new int[]{tileSize, tileSize, 3}, false); | |
BufferedImage imgRGB = ImageIO.read(new File(imagePath)); | |
DataNormalization scaler = new ImagePreProcessingScaler(-1.0, 1.0); | |
Java2DNativeImageLoader loader = new Java2DNativeImageLoader(tileSize, tileSize, 3); | |
//Load as RGB and BGR to check if there is any channel order difference | |
//INDArray imageArrayBGR = loader.asMatrix(imgRGB, true); | |
INDArray imageArrayRGBresnet = loader.asMatrix(imgRGB, false); | |
INDArray imageArrayRGBxception = loader.asMatrix(imgRGB, false); | |
INDArray imageArrayRGBresnetFlip = loader.asMatrix(imgRGB, true); | |
INDArray imageArrayRGBxceptionFlip = loader.asMatrix(imgRGB, true); | |
Nd4j.getExecutioner().execAndReturn(new BroadcastSubOp(imageArrayRGBresnet.dup(), VGG_MEAN_OFFSET_BGR_2, imageArrayRGBresnet, 1)); | |
scaler.transform(imageArrayRGBxception); | |
Nd4j.getExecutioner().execAndReturn(new BroadcastSubOp(imageArrayRGBresnetFlip.dup(), VGG_MEAN_OFFSET_BGR_2, imageArrayRGBresnetFlip, 1)); | |
scaler.transform(imageArrayRGBxceptionFlip); | |
int nIter = 100; | |
double[] durationRGB = new double[nIter]; | |
double[] durationBGR = new double[nIter]; | |
for( int i=0; i<nIter; i++ ) { | |
long startTime = System.nanoTime(); | |
double outputProbRGB = restoredResnet.outputSingle(false, imageArrayRGBresnet).getDouble(0); | |
long endTime = System.nanoTime(); | |
durationRGB[i] = (double) (endTime - startTime) / 1000000; | |
startTime = System.nanoTime(); | |
double outputProbBGR = restoredResnet.outputSingle(false, imageArrayRGBresnetFlip).getDouble(0); | |
endTime = System.nanoTime(); | |
durationBGR[i] = (double) (endTime - startTime) / 1000000; | |
} | |
System.out.println("RESNET"); | |
System.out.println("Runtimes, RGB: " + Arrays.toString(durationRGB)); | |
System.out.println("Runtimes, BGR: " + Arrays.toString(durationBGR)); | |
INDArray arrRGB = Nd4j.create(Arrays.copyOfRange(durationRGB, 10, durationRGB.length)); | |
INDArray arrBGR = Nd4j.create(Arrays.copyOfRange(durationBGR, 10, durationBGR.length)); | |
System.out.println("Mean RGB: " + arrRGB.meanNumber() + " (stdev " + arrRGB.stdNumber() + ")"); | |
System.out.println("Mean BGR: " + arrBGR.meanNumber() + " (stdev " + arrBGR.stdNumber() + ")"); | |
durationRGB = new double[nIter]; | |
durationBGR = new double[nIter]; | |
for( int i=0; i<nIter; i++ ) { | |
long startTime = System.nanoTime(); | |
double outputProbRGB = restoredXception.outputSingle(false, imageArrayRGBresnet).getDouble(0); | |
long endTime = System.nanoTime(); | |
durationRGB[i] = (double) (endTime - startTime) / 1000000; | |
startTime = System.nanoTime(); | |
double outputProbBGR = restoredXception.outputSingle(false, imageArrayRGBresnetFlip).getDouble(0); | |
endTime = System.nanoTime(); | |
durationBGR[i] = (double) (endTime - startTime) / 1000000; | |
} | |
System.out.println("XCEPTION"); | |
System.out.println("Runtimes, RGB: " + Arrays.toString(durationRGB)); | |
System.out.println("Runtimes, BGR: " + Arrays.toString(durationBGR)); | |
arrRGB = Nd4j.create(Arrays.copyOfRange(durationRGB, 10, durationRGB.length)); | |
arrBGR = Nd4j.create(Arrays.copyOfRange(durationBGR, 10, durationBGR.length)); | |
System.out.println("Mean RGB: " + arrRGB.meanNumber() + " (stdev " + arrRGB.stdNumber() + ")"); | |
System.out.println("Mean BGR: " + arrBGR.meanNumber() + " (stdev " + arrBGR.stdNumber() + ")"); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment