Skip to content

Instantly share code, notes, and snippets.

@AlexDBlack
Created July 26, 2018 02:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save AlexDBlack/7f8ed83fc3660f545abc67983997e894 to your computer and use it in GitHub Desktop.
Save AlexDBlack/7f8ed83fc3660f545abc67983997e894 to your computer and use it in GitHub Desktop.
import org.datavec.image.loader.Java2DNativeImageLoader;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import static org.nd4j.linalg.dataset.api.preprocessor.VGG16ImagePreProcessor.VGG_MEAN_OFFSET_BGR;
/**
 * Benchmarks inference latency of Keras-imported ResNet50 and Xception models in DL4J.
 *
 * <p>For each model, the same image is loaded twice via {@link Java2DNativeImageLoader}: once
 * with {@code flipChannels=false} (reported as "RGB") and once with {@code flipChannels=true}
 * (reported as "BGR"). Each input is run {@code nIter} times. Per-call wall-clock durations
 * (milliseconds) are printed, followed by their mean and standard deviation excluding the first
 * {@link #WARMUP_ITERATIONS} calls.
 */
public class Main {
    /**
     * Per-channel offsets subtracted from both ResNet inputs along dimension 1, regardless of
     * channel flipping. The values are the conventional VGG/Caffe means, listed in B, G, R order
     * as the name suggests.
     */
    private static final INDArray VGG_MEAN_OFFSET_BGR_2 = Nd4j.create(new double[]{103.939, 116.779, 123.68});

    /** Number of leading timings excluded from mean/stdev so warm-up calls do not skew them. */
    private static final int WARMUP_ITERATIONS = 10;

    private static final String DEFAULT_IMAGE_PATH = "C:\\DL4J\\Git\\dl4j-test-resources\\target\\classes\\yolo\\VOC_TwoImage\\JPEGImages\\2008_003344.jpg";
    private static final String DEFAULT_XCEPTION_PATH = "C:\\DL4J\\Git\\dl4j-test-resources\\src\\main\\resources\\modelimport\\keras\\examples\\xception\\xception_tf_keras_2.h5";
    private static final String DEFAULT_RESNET_PATH = "C:\\DL4J\\Git\\dl4j-test-resources\\src\\main\\resources\\modelimport\\keras\\examples\\resnet\\resnet50_weights_tf_dim_ordering_tf_kernels.h5";

    /**
     * Runs the benchmark.
     *
     * @param args optional overrides: {@code args[0]} image path, {@code args[1]} Xception .h5
     *             path, {@code args[2]} ResNet50 .h5 path; missing entries fall back to the
     *             built-in defaults
     * @throws IOException if the image cannot be decoded by {@link ImageIO}
     * @throws Exception   propagated from model import or inference
     */
    public static void main(String[] args) throws Exception {
        String imagePath = args.length > 0 ? args[0] : DEFAULT_IMAGE_PATH;
        String modelPathXception = args.length > 1 ? args[1] : DEFAULT_XCEPTION_PATH;
        String modelPathResnet = args.length > 2 ? args[2] : DEFAULT_RESNET_PATH;
        int tileSize = 256;
        int nIter = 100;

        ComputationGraph restoredXception = KerasModelImport.importKerasModelAndWeights(
                new File(modelPathXception).getAbsolutePath(), new int[]{tileSize, tileSize, 3}, false);
        ComputationGraph restoredResnet = KerasModelImport.importKerasModelAndWeights(
                new File(modelPathResnet).getAbsolutePath(), new int[]{tileSize, tileSize, 3}, false);

        BufferedImage imgRGB = ImageIO.read(new File(imagePath));
        if (imgRGB == null) {
            // ImageIO.read returns null (rather than throwing) when no registered reader can decode the file.
            throw new IOException("Unsupported or unreadable image: " + imagePath);
        }

        DataNormalization scaler = new ImagePreProcessingScaler(-1.0, 1.0);
        Java2DNativeImageLoader loader = new Java2DNativeImageLoader(tileSize, tileSize, 3);

        // Load each model's input with and without channel flipping to check whether channel order
        // affects latency. Separate arrays are needed because preprocessing below is in place.
        INDArray imageArrayRGBresnet = loader.asMatrix(imgRGB, false);
        INDArray imageArrayRGBxception = loader.asMatrix(imgRGB, false);
        INDArray imageArrayRGBresnetFlip = loader.asMatrix(imgRGB, true);
        INDArray imageArrayRGBxceptionFlip = loader.asMatrix(imgRGB, true);

        // ResNet preprocessing: mean subtraction along dimension 1.
        subtractChannelMeans(imageArrayRGBresnet);
        subtractChannelMeans(imageArrayRGBresnetFlip);
        // Xception preprocessing: rescale into the [-1, 1] range configured on the scaler.
        scaler.transform(imageArrayRGBxception);
        scaler.transform(imageArrayRGBxceptionFlip);

        benchmark("RESNET", restoredResnet, imageArrayRGBresnet, imageArrayRGBresnetFlip, nIter);
        // Each model must receive its own preprocessed inputs (the Xception run previously reused the ResNet arrays).
        benchmark("XCEPTION", restoredXception, imageArrayRGBxception, imageArrayRGBxceptionFlip, nIter);
    }

    /** Subtracts {@link #VGG_MEAN_OFFSET_BGR_2} from {@code arr} in place, broadcasting along dimension 1. */
    private static void subtractChannelMeans(INDArray arr) {
        Nd4j.getExecutioner().execAndReturn(new BroadcastSubOp(arr.dup(), VGG_MEAN_OFFSET_BGR_2, arr, 1));
    }

    /**
     * Times {@code nIter} alternating inference calls on the two inputs and prints the results.
     *
     * @param label    header line printed before the results
     * @param model    network to run
     * @param inputRGB input loaded without channel flipping
     * @param inputBGR input loaded with channel flipping
     * @param nIter    number of timed calls per input
     */
    private static void benchmark(String label, ComputationGraph model, INDArray inputRGB, INDArray inputBGR, int nIter) {
        double[] durationRGB = new double[nIter];
        double[] durationBGR = new double[nIter];
        for (int i = 0; i < nIter; i++) {
            durationRGB[i] = timeInferenceMillis(model, inputRGB);
            durationBGR[i] = timeInferenceMillis(model, inputBGR);
        }
        System.out.println(label);
        System.out.println("Runtimes, RGB: " + Arrays.toString(durationRGB));
        System.out.println("Runtimes, BGR: " + Arrays.toString(durationBGR));
        printStats("RGB", durationRGB);
        printStats("BGR", durationBGR);
    }

    /** Returns the wall-clock time, in milliseconds, of one inference call that also reads the first output element. */
    private static double timeInferenceMillis(ComputationGraph model, INDArray input) {
        long startTime = System.nanoTime();
        model.outputSingle(false, input).getDouble(0);
        long endTime = System.nanoTime();
        return (double) (endTime - startTime) / 1000000;
    }

    /**
     * Prints mean and standard deviation of {@code durations}. The first {@link #WARMUP_ITERATIONS}
     * entries are skipped when there are more than that many; otherwise all entries are used.
     */
    private static void printStats(String channelLabel, double[] durations) {
        int from = durations.length > WARMUP_ITERATIONS ? WARMUP_ITERATIONS : 0;
        INDArray arr = Nd4j.create(Arrays.copyOfRange(durations, from, durations.length));
        System.out.println("Mean " + channelLabel + ": " + arr.meanNumber() + " (stdev " + arr.stdNumber() + ")");
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment