Skip to content

Instantly share code, notes, and snippets.

@geekprogramming
Last active June 2, 2016 06:46
Show Gist options
  • Save geekprogramming/e2e8b3f279086dec21deb0be8cfc77fb to your computer and use it in GitHub Desktop.
Save geekprogramming/e2e8b3f279086dec21deb0be8cfc77fb to your computer and use it in GitHub Desktop.
// Converts an image into a 2-D array of pixel values (used by predict)
/**
 * Copies the pixels of {@code image} into a new {@code int[height][width]} array indexed as
 * {@code [y][x]}.
 *
 * <p>If the image raster reports a single data element per pixel, each entry is the raw sample
 * of band 0 at that pixel. Otherwise each entry is the packed ARGB value returned by
 * {@link BufferedImage#getRGB(int, int)}.
 *
 * <p>NOTE(review): packed-int rasters such as {@code TYPE_INT_RGB} also report one data element
 * per pixel, so for those images only band 0 (the red component) is returned, not the packed
 * colour. Confirm this is intended for the images fed through here.
 *
 * @param image the source image
 * @return a freshly allocated array of pixel values, one row per image row
 */
public static int[][] toIntArrayArray(BufferedImage image) {
    final int width = image.getWidth();
    final int height = image.getHeight();
    final int[][] pixels = new int[height][width];
    final WritableRaster raster = image.getRaster();
    final boolean oneElementPerPixel = raster.getNumDataElements() == 1;

    for (int y = 0; y < height; y++) {
        final int[] row = pixels[y];
        for (int x = 0; x < width; x++) {
            row[x] = oneElementPerPixel ? raster.getSample(x, y, 0) : image.getRGB(x, y);
        }
    }
    return pixels;
}
// Preprocessing steps applied to images before training, and applied again before prediction
/**
 * Applies the preprocessing pipeline with the default target size of 23 pixels wide by
 * 43 pixels tall.
 *
 * <p>The same steps must be applied before training and before prediction, so callers on both
 * paths should go through this method.
 *
 * @param img the source image
 * @return the preprocessed image, resized to 23x43 (width x height)
 */
public static BufferedImage preprocess(BufferedImage img) {
    return preprocess(img, 23, 43);
}

/**
 * Applies the preprocessing pipeline and resizes the result to {@code width x height}.
 *
 * <p>Steps: {@code ImageOpenCvUtils.toMat}, then {@code ImageOpenCvUtils.toGray}, then
 * {@code ImageOpenCvUtils.maximizeContrastMat(..., 1)}, then {@code opencv_imgproc.resize}
 * with its default interpolation, then conversion back to a {@link BufferedImage}.
 *
 * @param img the source image
 * @param width target width in pixels; must be positive
 * @param height target height in pixels; must be positive
 * @return the preprocessed image of the requested size
 * @throws IllegalArgumentException if {@code width} or {@code height} is not positive
 */
public static BufferedImage preprocess(BufferedImage img, int width, int height) {
    if (width <= 0 || height <= 0) {
        throw new IllegalArgumentException(
                "target size must be positive, got " + width + "x" + height);
    }
    opencv_core.Mat original = ImageOpenCvUtils.toMat(img);
    opencv_core.Mat grayScale = ImageOpenCvUtils.toGray(original);
    opencv_core.Mat maximizeContrast = ImageOpenCvUtils.maximizeContrastMat(grayScale, 1);
    // resize() (re)allocates dst to dsize with the source type, so pre-allocating it at the
    // source size would only waste a native buffer.
    opencv_core.Mat output = new opencv_core.Mat();
    // opencv Size takes (width, height).
    opencv_imgproc.resize(maximizeContrast, output, new opencv_core.Size(width, height));
    return ImageOpenCvUtils.toBufferedImage(output);
}
/**
 * Runs {@code model} on a single image and returns the argmax of its output.
 *
 * <p>The image is passed through {@link #preprocess(BufferedImage)}, converted with
 * {@link #toIntArrayArray(BufferedImage)}, flattened with {@code ArrayUtil.flatten}, and fed to
 * the network without any rescaling of pixel values. NOTE(review): this presumably matches how
 * the training data was prepared; confirm.
 *
 * @param img the image to classify
 * @param model the trained network
 * @return the index of the largest output along dimension 1 (presumably the predicted class),
 *     as a double
 */
public static double predict(BufferedImage img, MultiLayerNetwork model) {
    BufferedImage processed = preprocess(img);
    int[][] pixels = toIntArrayArray(processed);
    INDArray input = ArrayUtil.toNDArray(ArrayUtil.flatten(pixels));
    INDArray scores = model.output(input);
    return Nd4j.argMax(scores, 1).getDouble(0);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment