Skip to content

Instantly share code, notes, and snippets.

@AbdelmajidB
Last active May 14, 2019 01:02
Show Gist options
  • Save AbdelmajidB/7551c7e1f33f21893647840bef708946 to your computer and use it in GitHub Desktop.
Save AbdelmajidB/7551c7e1f33f21893647840bef708946 to your computer and use it in GitHub Desktop.
/**
 * Assembles the layer graph of a 3D U-Net style network.
 *
 * <p>The graph consists of a contracting path (three max-pooling steps, with dropout after
 * the "conv3-2" and "conv4-2" layers), an expanding path (three upsampling steps, each merged
 * with the matching contracting-path activation through a {@code MergeVertex}), and a
 * single-channel head feeding a {@code Cnn3DLossLayer} with sigmoid activation and XENT loss.
 *
 * <p>This method does not declare the graph input: the first layer reads from a vertex named
 * {@code "input"}, so the caller has to add that input (and any input type information) on
 * the returned builder before building the configuration.
 *
 * @return a graph builder holding all layers and vertices, with {@code "output"} set as the
 *         only output
 */
public ComputationGraphConfiguration.GraphBuilder unetBuilder() {
    ComputationGraphConfiguration.GraphBuilder graph = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(updater)
            .weightInit(weightInit)
            .l2(5e-5)
            .miniBatch(true)
            .cacheMode(cacheMode)
            .trainingWorkspaceMode(workspaceMode)
            .inferenceWorkspaceMode(workspaceMode)
            .graphBuilder();

    graph
            // contracting path, level 1
            .addLayer("conv1-1", unetConv3d(3, 32, Activation.RELU), "input")
            .addLayer("conv1-2", unetConv3d(3, 64, Activation.RELU), "conv1-1")
            .addLayer("pool1", unetMaxPool3d(), "conv1-2")
            // contracting path, level 2
            .addLayer("conv2-1", unetConv3d(3, 64, Activation.RELU), "pool1")
            .addLayer("conv2-2", unetConv3d(3, 128, Activation.RELU), "conv2-1")
            .addLayer("pool2", unetMaxPool3d(), "conv2-2")
            // contracting path, level 3
            .addLayer("conv3-1", unetConv3d(3, 128, Activation.RELU), "pool2")
            .addLayer("conv3-2", unetConv3d(3, 256, Activation.RELU), "conv3-1")
            .addLayer("drop3", new DropoutLayer.Builder(0.5).build(), "conv3-2")
            .addLayer("pool3", unetMaxPool3d(), "drop3")
            // bottom of the network
            .addLayer("conv4-1", unetConv3d(3, 256, Activation.RELU), "pool3")
            .addLayer("conv4-2", unetConv3d(3, 512, Activation.RELU), "conv4-1")
            .addLayer("drop4", new DropoutLayer.Builder(0.5).build(), "conv4-2")
            // up5: upsample and merge with the level-3 activation ("drop3")
            .addLayer("up5-1", new Upsampling3D.Builder(2).build(), "drop4")
            .addLayer("up5-2", unetConv3d(2, 512, Activation.RELU), "up5-1")
            .addVertex("merge5", new MergeVertex(), "drop3", "up5-2")
            .addLayer("conv5-1", unetConv3d(3, 256, Activation.RELU), "merge5")
            .addLayer("conv5-2", unetConv3d(3, 256, Activation.RELU), "conv5-1")
            // up6: upsample and merge with the level-2 activation ("conv2-2")
            .addLayer("up6-1", new Upsampling3D.Builder(2).build(), "conv5-2")
            .addLayer("up6-2", unetConv3d(2, 256, Activation.RELU), "up6-1")
            .addVertex("merge6", new MergeVertex(), "conv2-2", "up6-2")
            .addLayer("conv6-1", unetConv3d(3, 128, Activation.RELU), "merge6")
            .addLayer("conv6-2", unetConv3d(3, 128, Activation.RELU), "conv6-1")
            // up7: upsample and merge with the level-1 activation ("conv1-2")
            .addLayer("up7-1", new Upsampling3D.Builder(2).build(), "conv6-2")
            .addLayer("up7-2", unetConv3d(2, 128, Activation.RELU), "up7-1")
            .addVertex("merge7", new MergeVertex(), "conv1-2", "up7-2")
            .addLayer("conv7-1", unetConv3d(3, 64, Activation.RELU), "merge7")
            .addLayer("conv7-2", unetConv3d(3, 64, Activation.RELU), "conv7-1")
            // "conv7-3" goes through the shared helper too, so it states NCDHW explicitly
            // like every other convolution instead of relying on the library default.
            .addLayer("conv7-3", unetConv3d(3, 2, Activation.RELU), "conv7-2")
            // head: one raw output channel; the sigmoid is applied by the loss layer
            .addLayer("conv8", unetConv3d(3, 1, Activation.IDENTITY), "conv7-3")
            .addLayer("output", new Cnn3DLossLayer.Builder(Convolution3D.DataFormat.NCDHW)
                    .lossFunction(LossFunctions.LossFunction.XENT)
                    .activation(Activation.SIGMOID).build(), "conv8")
            .setOutputs("output");

    return graph;
}

/**
 * Creates one cubic 3D convolution with the settings shared by every convolution in this
 * network: stride 1 on all axes, NCDHW data format, {@code ConvolutionMode.Same} and the
 * configured cuDNN algo mode.
 *
 * @param kernelSize edge length of the cubic kernel
 * @param nOut       number of output channels
 * @param activation activation function of the layer
 * @return the configured convolution layer
 */
private Convolution3D unetConv3d(int kernelSize, int nOut, Activation activation) {
    return new Convolution3D.Builder(kernelSize, kernelSize, kernelSize)
            .stride(1, 1, 1)
            .nOut(nOut)
            .dataFormat(Convolution3D.DataFormat.NCDHW)
            .convolutionMode(ConvolutionMode.Same)
            .cudnnAlgoMode(cudnnAlgoMode)
            .activation(activation)
            .build();
}

/**
 * Creates the 3D max-pooling layer with a 2x2x2 kernel used between contracting levels.
 *
 * @return the configured pooling layer
 */
private Subsampling3DLayer unetMaxPool3d() {
    return new Subsampling3DLayer.Builder(Subsampling3DLayer.PoolingType.MAX)
            .kernelSize(2, 2, 2)
            .build();
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment