-
-
Save montardon/1dd63ad960dc7a4c201b5fdc6045c96b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/*******************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
package org.deeplearning4j.examples.modelimport.keras.basic; | |
import org.datavec.image.loader.NativeImageLoader; | |
import org.deeplearning4j.api.storage.StatsStorage; | |
import org.deeplearning4j.nn.api.Layer; | |
import org.deeplearning4j.nn.api.Model; | |
import org.deeplearning4j.nn.conf.layers.CnnLossLayer; | |
import org.deeplearning4j.nn.graph.ComputationGraph; | |
import org.deeplearning4j.nn.graph.vertex.GraphVertex; | |
import org.deeplearning4j.nn.transferlearning.TransferLearning; | |
import org.deeplearning4j.optimize.listeners.PerformanceListener; | |
import org.deeplearning4j.optimize.listeners.ScoreIterationListener; | |
import org.deeplearning4j.ui.api.UIServer; | |
import org.deeplearning4j.ui.stats.StatsListener; | |
import org.deeplearning4j.ui.storage.InMemoryStatsStorage; | |
import org.deeplearning4j.util.ModelSerializer; | |
import org.deeplearning4j.zoo.PretrainedType; | |
import org.deeplearning4j.zoo.model.UNet; | |
import org.nd4j.linalg.activations.Activation; | |
import org.nd4j.linalg.api.buffer.DataType; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import org.nd4j.linalg.factory.Nd4j; | |
import org.nd4j.linalg.lossfunctions.LossFunctions; | |
import org.slf4j.Logger; | |
import org.slf4j.LoggerFactory; | |
import java.io.File; | |
import java.io.IOException; | |
/**
 * Loads a UNet model from the DL4J model zoo and trains it on the
 * membrane images of the ISBI 2015 challenge
 * (https://github.com/zhixuhao/unet.git)
 * to obtain a per-pixel segmentation of the input image.
 */
public class TrainUnetWithMembraneImages { | |
// Logger | |
private static final Logger log = LoggerFactory.getLogger(TrainUnetWithMembraneImages.class); | |
// Input image size 512x512x3 for this Unet model. It outputs 512x512x1 data. | |
private static final int WIDTH = 512; | |
private static final int HEIGHT = 512; | |
private static final int CHANNELS = 3; | |
public static void main(String[] args) { | |
String inputDataDirectory = ""; | |
if (args.length>0) { | |
inputDataDirectory = args[0]; | |
} else { | |
usage(); | |
} | |
// Define cache location for downloaded models and data if necessary | |
// By default cache is put in $HOME/.deeplearning4j repository | |
// DL4JResources.setBaseDirectory(new File("/home/local_user/cache")); | |
// Load model from DL4J Zoo (model and weights) | |
UNet unetModel = UNet.builder().build(); | |
//ComputationGraph unetGraph = new ComputationGraph(unetModel.graphBuilder().build()); | |
ComputationGraph model = null; | |
try { | |
// read saved model | |
// model = ModelSerializer.restoreComputationGraph(new File("modelDL4J_unet")); | |
// or instantiate zoo model | |
model = (ComputationGraph) unetModel.initPretrained(PretrainedType.SEGMENT); | |
} catch (IOException e) { | |
e.printStackTrace(); | |
return; | |
} | |
// Print model vertices (debug) | |
for (GraphVertex vertex: model.getVertices()) { | |
System.out.println(vertex.getVertexName()); | |
if (vertex.getVertexName().equalsIgnoreCase("input_1")) { | |
Layer layer = vertex.getLayer(); | |
if (layer != null) { | |
System.out.println(layer.toString()); | |
} | |
} | |
} | |
// from Gitter deeplearning4j forum | |
// ComputationGraph net = ...initPretrained(...) | |
// net = new TransferLearning.GraphBuilder(net) | |
// .removeVertexAndConnections("activation_23") | |
// .addLayer("output", new CnnLossLayer.Builder().activation(Activation.SOFTMAX) | |
// .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "conv22") | |
// .build(); | |
model = new TransferLearning.GraphBuilder(model) | |
.removeVertexAndConnections("activation_23") | |
.addLayer("output", new CnnLossLayer.Builder().activation(Activation.SOFTMAX) | |
.lossFunction(LossFunctions.LossFunction.MCXENT).build(), "conv2d_23") | |
.setOutputs("output") | |
.build(); | |
long startTime = System.currentTimeMillis(); | |
// Define input images and labels | |
// training and images and label /unet/data/membrane/train/img and | |
// training and images and label /unet/data/membrane/train/label | |
int numEpochs = 10; | |
int numSamples = 30; | |
try { | |
File[] images = new File[numSamples]; | |
File[] labels = new File[numSamples]; | |
for (int i=0; i < numSamples; i++) { | |
images[i] = new File(inputDataDirectory+"/membrane/train/image/"+i+".png"); | |
labels[i] = new File(inputDataDirectory+"/membrane/train/label/"+i+".png"); | |
} | |
UIServer uiServer = UIServer.getInstance(); | |
//Configure where the network information (gradients, score vs. time etc) is to be stored. Here: store in memory. | |
StatsStorage statsStorage = new InMemoryStatsStorage(); //Alternative: new FileStatsStorage(File), for saving and loading later | |
//Attach the StatsStorage instance to the UI: this allows the contents of the StatsStorage to be visualized | |
uiServer.attach(statsStorage); | |
//Then add the StatsListener to collect this information from the network, as it trains | |
model.setListeners(new PerformanceListener(30),new ScoreIterationListener(30),new StatsListener(statsStorage)); | |
for (int epochs = 0; epochs < numEpochs; epochs++) { | |
log.warn("Epoch "+epochs); | |
for (int i=0; i <numSamples; i++) { | |
INDArray input = transformImageToBatch(images[i]); | |
INDArray answer = transformLabelToBatch(labels[i]); | |
model.fit(new INDArray[]{input}, new INDArray[]{answer}, null, null); | |
} | |
} | |
} catch (IOException e) { | |
e.printStackTrace(); | |
} | |
System.out.println("Elapsed time="+(System.currentTimeMillis()-startTime)/1000.f+"s"); | |
try { | |
ModelSerializer.writeModel((Model)model,new File("modelDL4J_unet"),true); | |
} catch (IOException e) { | |
e.printStackTrace(); | |
} | |
System.out.println("Train done."); | |
System.out.println("Let's call exit()"); | |
System.exit(0); | |
} | |
public static void usage() { | |
log.warn("Please call this program with a image filename as argument"); | |
} | |
public static INDArray transformImageToBatch(File... imageFiles) throws IOException { | |
INDArray imageBatch = Nd4j.create(DataType.FLOAT, imageFiles.length, CHANNELS ,HEIGHT,WIDTH); | |
int index = 0; | |
NativeImageLoader loader = new NativeImageLoader(WIDTH,HEIGHT,CHANNELS); | |
for (File imageFile: imageFiles) { | |
INDArray arr = loader.asMatrix(imageFile); | |
// log.warn(arr.shapeInfoToString()); | |
arr = arr.divi(255.0f); | |
imageBatch.putRow(index++, arr); | |
} | |
return imageBatch; | |
} | |
public static INDArray transformLabelToBatch(File... imageFiles) throws IOException { | |
INDArray imageBatch = Nd4j.create(DataType.FLOAT, imageFiles.length, 1 ,HEIGHT,WIDTH); | |
int index = 0; | |
NativeImageLoader nativeImageLoader = new NativeImageLoader(HEIGHT,WIDTH,1); | |
for (File imageFile: imageFiles) { | |
INDArray arr = nativeImageLoader.asMatrix(imageFile); | |
// log.warn(arr.shapeInfoToString()); | |
arr = arr.divi(255.0f); | |
imageBatch.putRow(index++, arr); | |
} | |
return imageBatch; | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment