Skip to content

Instantly share code, notes, and snippets.

@AlexDBlack
Created November 24, 2018 04:17
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save AlexDBlack/339991f8da0a5553223ccaaa690c4507 to your computer and use it in GitHub Desktop.
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.dropout.IDropout;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.zoo.ZooModel;
import org.deeplearning4j.zoo.model.ResNet50;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
public class Debug6756 {

    /**
     * Repro / workaround demo for DL4J issue 6756 (1.0.0-beta3): transfer learning on a
     * pretrained ResNet50 — swap the stem convolution, append a dense layer after fc1000,
     * replace the output layer — then fit a single random example to confirm training runs.
     */
    public static void main(String[] args) throws Exception {
        final long randomSeed = 12345;
        final int imgHeight = 112;
        final int imgWidth = 112;
        final int imgChannels = 3;
        final int numClasses = 1000;

        // Load the pretrained ResNet50 from the model zoo with workspaces disabled.
        ZooModel zooModel = ResNet50.builder().workspaceMode(WorkspaceMode.NONE).build();
        ComputationGraph pretrained = (ComputationGraph) zooModel.initPretrained();
        System.out.println(pretrained.summary());

        // Fine-tuning defaults applied to every layer carried over from the pretrained graph.
        FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder()
                .activation(Activation.LEAKYRELU)
                .weightInit(WeightInit.RELU)
                .updater(new Nesterovs(5e-5))
                .dropOut(0.5)
                .seed(randomSeed)
                .build();

        // Rebuild the graph: replace "conv1", add "fc2048" after "fc1000",
        // and swap the original "output" vertex for a fresh softmax output layer.
        ComputationGraph graph = new TransferLearning.GraphBuilder(pretrained)
                .fineTuneConfiguration(fineTuneConf)
                .setInputTypes(InputType.convolutional(imgHeight, imgWidth, imgChannels))
                .removeVertexKeepConnections("conv1")
                .addLayer("conv1",
                        new ConvolutionLayer.Builder(new int[]{3, 3})
                                .nIn(imgChannels)
                                .nOut(64)
                                .activation(Activation.RELU)
                                .build(),
                        "input_1")
                .addLayer("fc2048",
                        new DenseLayer.Builder()
                                .activation(Activation.TANH)
                                .nIn(1000)
                                .nOut(2048)
                                .build(),
                        "fc1000")
                .removeVertexAndConnections("output")
                .addLayer("newOutput",
                        new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                                .activation(Activation.SOFTMAX)
                                .nIn(2048)
                                .nOut(numClasses)
                                .build(),
                        "fc2048")
                .setOutputs("newOutput")
                .build();
        System.out.println(graph.summary());

        /////////////////////////////////////////////////////
        // WORKAROUND FOR 1.0.0-beta3: re-clone (or clear) each layer's dropout config
        // so that layer configs do not share a single IDropout instance.
        for (Layer layer : graph.getLayers()) {
            org.deeplearning4j.nn.conf.layers.Layer layerConf =
                    (org.deeplearning4j.nn.conf.layers.Layer) layer.getConfig();
            if (layerConf instanceof org.deeplearning4j.nn.conf.layers.BaseLayer) {
                IDropout dropout = layerConf.getIDropout();
                layerConf.setIDropout(dropout == null ? null : dropout.clone());
            } else {
                layerConf.setIDropout(null);
            }
        }
        /////////////////////////////////////////////////////

        // One random NCHW example with a one-hot label; fitting it exercises the fix.
        INDArray features = Nd4j.rand(new int[]{1, imgChannels, imgHeight, imgWidth});
        INDArray labels = Nd4j.create(1, numClasses);
        labels.putScalar(0, 0, 1.0);
        graph.fit(new DataSet(features, labels));
    }
}
@hamaadshah
Copy link

Hi Alex: Has this bug in 1.0.0-beta3 been fixed or should I switch to 1.0.0-beta? Best wishes, Hamaad.

@hamaadshah
Copy link

P.S., Yes, I can confirm that it works without the above bug-fix requirement in 1.0.0-beta. Many thanks, Hamaad.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment