Skip to content

Instantly share code, notes, and snippets.

@montardon
Last active July 19, 2019 14:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save montardon/1dd63ad960dc7a4c201b5fdc6045c96b to your computer and use it in GitHub Desktop.
Save montardon/1dd63ad960dc7a4c201b5fdc6045c96b to your computer and use it in GitHub Desktop.
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.examples.modelimport.keras.basic;
import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.layers.CnnLossLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.GraphVertex;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.optimize.listeners.PerformanceListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.deeplearning4j.util.ModelSerializer;
import org.deeplearning4j.zoo.PretrainedType;
import org.deeplearning4j.zoo.model.UNet;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.IOException;
/**
Load a Unet zoo model and try to train it on
https://github.com/zhixuhao/unet.git
membranes images (ISBI 2015 challenge)
to get segmentation of input image
*/
public class TrainUnetWithMembraneImages {
// Logger
private static final Logger log = LoggerFactory.getLogger(TrainUnetWithMembraneImages.class);
// Input image size 512x512x3 for this Unet model. It outputs 512x512x1 data.
private static final int WIDTH = 512;
private static final int HEIGHT = 512;
private static final int CHANNELS = 3;
public static void main(String[] args) {
String inputDataDirectory = "";
if (args.length>0) {
inputDataDirectory = args[0];
} else {
usage();
}
// Define cache location for downloaded models and data if necessary
// By default cache is put in $HOME/.deeplearning4j repository
// DL4JResources.setBaseDirectory(new File("/home/local_user/cache"));
// Load model from DL4J Zoo (model and weights)
UNet unetModel = UNet.builder().build();
//ComputationGraph unetGraph = new ComputationGraph(unetModel.graphBuilder().build());
ComputationGraph model = null;
try {
// read saved model
// model = ModelSerializer.restoreComputationGraph(new File("modelDL4J_unet"));
// or instantiate zoo model
model = (ComputationGraph) unetModel.initPretrained(PretrainedType.SEGMENT);
} catch (IOException e) {
e.printStackTrace();
return;
}
// Print model vertices (debug)
for (GraphVertex vertex: model.getVertices()) {
System.out.println(vertex.getVertexName());
if (vertex.getVertexName().equalsIgnoreCase("input_1")) {
Layer layer = vertex.getLayer();
if (layer != null) {
System.out.println(layer.toString());
}
}
}
// from Gitter deeplearning4j forum
// ComputationGraph net = ...initPretrained(...)
// net = new TransferLearning.GraphBuilder(net)
// .removeVertexAndConnections("activation_23")
// .addLayer("output", new CnnLossLayer.Builder().activation(Activation.SOFTMAX)
// .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "conv22")
// .build();
model = new TransferLearning.GraphBuilder(model)
.removeVertexAndConnections("activation_23")
.addLayer("output", new CnnLossLayer.Builder().activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build(), "conv2d_23")
.setOutputs("output")
.build();
long startTime = System.currentTimeMillis();
// Define input images and labels
// training and images and label /unet/data/membrane/train/img and
// training and images and label /unet/data/membrane/train/label
int numEpochs = 10;
int numSamples = 30;
try {
File[] images = new File[numSamples];
File[] labels = new File[numSamples];
for (int i=0; i < numSamples; i++) {
images[i] = new File(inputDataDirectory+"/membrane/train/image/"+i+".png");
labels[i] = new File(inputDataDirectory+"/membrane/train/label/"+i+".png");
}
UIServer uiServer = UIServer.getInstance();
//Configure where the network information (gradients, score vs. time etc) is to be stored. Here: store in memory.
StatsStorage statsStorage = new InMemoryStatsStorage(); //Alternative: new FileStatsStorage(File), for saving and loading later
//Attach the StatsStorage instance to the UI: this allows the contents of the StatsStorage to be visualized
uiServer.attach(statsStorage);
//Then add the StatsListener to collect this information from the network, as it trains
model.setListeners(new PerformanceListener(30),new ScoreIterationListener(30),new StatsListener(statsStorage));
for (int epochs = 0; epochs < numEpochs; epochs++) {
log.warn("Epoch "+epochs);
for (int i=0; i <numSamples; i++) {
INDArray input = transformImageToBatch(images[i]);
INDArray answer = transformLabelToBatch(labels[i]);
model.fit(new INDArray[]{input}, new INDArray[]{answer}, null, null);
}
}
} catch (IOException e) {
e.printStackTrace();
}
System.out.println("Elapsed time="+(System.currentTimeMillis()-startTime)/1000.f+"s");
try {
ModelSerializer.writeModel((Model)model,new File("modelDL4J_unet"),true);
} catch (IOException e) {
e.printStackTrace();
}
System.out.println("Train done.");
System.out.println("Let's call exit()");
System.exit(0);
}
public static void usage() {
log.warn("Please call this program with a image filename as argument");
}
public static INDArray transformImageToBatch(File... imageFiles) throws IOException {
INDArray imageBatch = Nd4j.create(DataType.FLOAT, imageFiles.length, CHANNELS ,HEIGHT,WIDTH);
int index = 0;
NativeImageLoader loader = new NativeImageLoader(WIDTH,HEIGHT,CHANNELS);
for (File imageFile: imageFiles) {
INDArray arr = loader.asMatrix(imageFile);
// log.warn(arr.shapeInfoToString());
arr = arr.divi(255.0f);
imageBatch.putRow(index++, arr);
}
return imageBatch;
}
public static INDArray transformLabelToBatch(File... imageFiles) throws IOException {
INDArray imageBatch = Nd4j.create(DataType.FLOAT, imageFiles.length, 1 ,HEIGHT,WIDTH);
int index = 0;
NativeImageLoader nativeImageLoader = new NativeImageLoader(HEIGHT,WIDTH,1);
for (File imageFile: imageFiles) {
INDArray arr = nativeImageLoader.asMatrix(imageFile);
// log.warn(arr.shapeInfoToString());
arr = arr.divi(255.0f);
imageBatch.putRow(index++, arr);
}
return imageBatch;
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment