Skip to content

Instantly share code, notes, and snippets.

// Class-wide logger bound to SemanticSegmentation.
private static final Logger log = LoggerFactory.getLogger(SemanticSegmentation.class);
// Presumably the input image size in pixels and the number of image channels
// (1 would mean grayscale) — TODO confirm against the data pipeline.
private static int height = 96;
private static int width = 96;
private static int channels = 1;
// Presumably the number of examples per mini-batch — confirm where the iterator is built.
private static int batchSize = 22;
// Seed handed to rng below. Keep this declared before rng: static initializers run in
// textual order, so rng would otherwise be seeded with the default value 0.
private static long seed = 1234;
private static Random rng = new Random(seed);
// Presumably the number of training passes — confirm in the training loop.
private static int epochs = 1;
// NOTE(review): absolute, machine-specific Windows path; the code will only find its data
// on a machine that has this exact directory. Consider making it configurable.
public static final String DATA_PATH= "C:/Users/bismi/Documents/dl4j/brain-tumor-segmentation/src/main/resources";
package ma.enset.brain_tumor_segmentation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Cnn3DLossLayer;
public class SemanticSegmentation {
// Class-wide logger bound to SemanticSegmentation.
private static final Logger log = LoggerFactory.getLogger(SemanticSegmentation.class);
// Presumably the input image size in pixels and the number of image channels
// (1 would mean grayscale) — TODO confirm against the data pipeline.
private static int height = 256;
private static int width = 256;
private static int channels = 1;
// Presumably the number of examples per mini-batch — confirm where the iterator is built.
private static int batchSize = 6;
// Seed handed to rng below. Keep this declared before rng: static initializers run in
// textual order, so rng would otherwise be seeded with the default value 0.
private static long seed = 1234;
private static Random rng = new Random(seed);
// Presumably the number of training passes — confirm in the training loop.
private static int epochs = 100;
21:05:24.437 [main] INFO ma.enset.brain_tumor_segmentation.MyPathLabelGenerator - C:/Users/bismi/Documents/dl4j/brain-tumor-segmentation/src/main/resources/testI/rois/1.png
21:05:24.557 [main] INFO ma.enset.brain_tumor_segmentation.SemanticSegmentationLoad2 -
========================Evaluation Metrics========================
# of classes: 2
Accuracy: 0,9990
Precision: 0,9318
Recall: 1,0000
F1 Score: 0,9647
Precision, recall & F1: reported for positive class (class 1 - "1") only
21:03:21.537 [main] INFO ma.enset.brain_tumor_segmentation.MyPathLabelGenerator - C:/Users/bismi/Documents/dl4j/brain-tumor-segmentation/src/main/resources/testI/rois/1.png
21:03:21.540 [main] DEBUG org.nd4j.jita.handler.impl.CudaZeroHandler - Creating bucketID: 2
21:03:21.677 [main] INFO ma.enset.brain_tumor_segmentation.SemanticSegmentationLoad2 -
========================Evaluation Metrics========================
# of classes: 2
Accuracy: 0,1186
Precision: 0,0149
Recall: 1,0000
F1 Score: 0,0294
package ma.enset.brain_tumor_segmentation;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.samediff.SDLayerParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffOutputLayer;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Map;
package ma.enset.brain_tumor_segmentation;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.samediff.SDLayerParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffOutputLayer;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Map;
// Shared network settings: fixed seed, SGD optimization driven by an Adam updater with
// learning rate 0.0001, weights drawn from NormalDistribution(0.0, 0.01), biases starting at 0.
// cacheMode and workspaceMode come from the surrounding scope; the same workspace mode is
// applied to both training and inference.
NeuralNetConfiguration.Builder baseConfig = new NeuralNetConfiguration.Builder()
        .seed(seed)
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .updater(new Adam(0.0001))
        .weightInit(new NormalDistribution(0.0, 0.01))
        .biasInit(0)
        .miniBatch(true)
        .cacheMode(cacheMode)
        .trainingWorkspaceMode(workspaceMode)
        .inferenceWorkspaceMode(workspaceMode);

// Switch from the base configuration to the graph builder used to add the network's vertices.
ComputationGraphConfiguration.GraphBuilder graph = baseConfig.graphBuilder();
public class MyPathLabelGenerator implements PathLabelGenerator{
// Logger bound to MyPathLabelGenerator.
protected static final Logger log = LoggerFactory.getLogger(MyPathLabelGenerator.class);
// Assigned from the constructor's path argument.
String labelsDir;
// Loader shared by all instances, built with the arguments (96, 96, 1) — presumably
// height, width and channel count; TODO confirm against NativeImageLoader's constructor.
private static NativeImageLoader imageLoader = new NativeImageLoader(96, 96, 1);
// Not assigned in the visible code; its use cannot be determined from this excerpt.
File file;
/**
 * Creates a label generator that remembers the given location.
 *
 * @param path value stored in {@code labelsDir} as-is (no validation is performed);
 *             presumably the directory that holds the label images
 */
public MyPathLabelGenerator(String path) {
    this.labelsDir = path;
}
// Walks the whole iterator, counting batches in k. Only the first batch (k == 0) is
// evaluated and its statistics logged.
while (Iter.hasNext()) {
    DataSet next = Iter.next();
    if (k == 0) {
        // Inference and the flattening reshapes are performed only for the batch that is
        // actually scored, so no model output is computed just to be discarded.
        // permute(0,2,3,1) reorders the axes, dup() materializes a copy, and the copy is
        // reshaped in 'c' order to a (height*width, 1) column.
        // NOTE(review): the reshape target has exactly height*width rows, so this presumably
        // expects a single one-channel example per DataSet — confirm against the iterator's
        // batch size.
        INDArray out2d = modelT.outputSingle(next.getFeatures()).permute(0,2,3,1).dup().reshape('c',height*width,1);
        INDArray labels2d = next.getLabels().permute(0,2,3,1).dup().reshape('c',height*width,1);
        e.eval(labels2d, out2d);
        log.info(e.stats());
    }
    k++;
}