Created
November 24, 2018 04:17
-
-
Save AlexDBlack/339991f8da0a5553223ccaaa690c4507 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.deeplearning4j.nn.api.Layer; | |
import org.deeplearning4j.nn.conf.WorkspaceMode; | |
import org.deeplearning4j.nn.conf.dropout.IDropout; | |
import org.deeplearning4j.nn.conf.inputs.InputType; | |
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; | |
import org.deeplearning4j.nn.conf.layers.DenseLayer; | |
import org.deeplearning4j.nn.conf.layers.OutputLayer; | |
import org.deeplearning4j.nn.graph.ComputationGraph; | |
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration; | |
import org.deeplearning4j.nn.transferlearning.TransferLearning; | |
import org.deeplearning4j.nn.weights.WeightInit; | |
import org.deeplearning4j.zoo.ZooModel; | |
import org.deeplearning4j.zoo.model.ResNet50; | |
import org.nd4j.linalg.activations.Activation; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import org.nd4j.linalg.dataset.DataSet; | |
import org.nd4j.linalg.factory.Nd4j; | |
import org.nd4j.linalg.learning.config.Nesterovs; | |
import org.nd4j.linalg.lossfunctions.LossFunctions; | |
public class Debug6756 { | |
public static void main(String[] args) throws Exception { | |
long seed = 12345; | |
int height = 112; | |
int width = 112; | |
int channels = 3; | |
int numLabels = 1000; | |
ZooModel zooModel = ResNet50.builder().workspaceMode(WorkspaceMode.NONE).build(); | |
ComputationGraph resnet50 = (ComputationGraph) zooModel.initPretrained(); | |
System.out.println(resnet50.summary()); | |
FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder() | |
.activation(Activation.LEAKYRELU) | |
.weightInit(WeightInit.RELU) | |
.updater(new Nesterovs(5e-5)) | |
.dropOut(0.5) | |
.seed(seed) | |
.build(); | |
ComputationGraph graph = new TransferLearning.GraphBuilder(resnet50) | |
.fineTuneConfiguration(fineTuneConf) | |
.setInputTypes(InputType.convolutional(height, width, channels)) | |
.removeVertexKeepConnections("conv1") | |
.addLayer("conv1", new ConvolutionLayer.Builder(new int[]{3, 3}) | |
.nIn(channels).nOut(64).activation( Activation.RELU).build(), "input_1") | |
.addLayer("fc2048",new DenseLayer.Builder().activation(Activation.TANH).nIn(1000).nOut(2048).build(),"fc1000") | |
.removeVertexAndConnections("output") | |
.addLayer("newOutput",new OutputLayer | |
.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) | |
.activation(Activation.SOFTMAX) | |
.nIn(2048) | |
.nOut(numLabels) | |
.build(),"fc2048") | |
.setOutputs("newOutput") | |
.build(); | |
System.out.println(graph.summary()); | |
///////////////////////////////////////////////////// | |
//WORKAROUND FOR 1.0.0-beta3 | |
for(Layer l : graph.getLayers()){ | |
org.deeplearning4j.nn.conf.layers.Layer conf = (org.deeplearning4j.nn.conf.layers.Layer)l.getConfig(); | |
if(!(l.getConfig() instanceof org.deeplearning4j.nn.conf.layers.BaseLayer)){ | |
conf.setIDropout(null); | |
} else { | |
IDropout d = conf.getIDropout(); | |
conf.setIDropout(d == null ? null : d.clone()); | |
} | |
} | |
///////////////////////////////////////////////////// | |
INDArray in = Nd4j.rand(new int[]{1, channels, height, width}); | |
INDArray labels = Nd4j.create(1, numLabels); | |
labels.putScalar(0, 0, 1.0); | |
DataSet ds = new DataSet(in, labels); | |
graph.fit(ds); | |
} | |
} |
P.S. Yes, I can confirm that it works in 1.0.0-beta without needing the above bug-fix workaround. Many thanks, Hamaad.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi Alex: Has this bug in 1.0.0-beta3 been fixed or should I switch to 1.0.0-beta? Best wishes, Hamaad.