@fmorbini
Created September 4, 2018 16:44
import java.util.ArrayList;
import java.util.List;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.dropout.GaussianNoise;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;
public class Test {

    public static void main(String[] args) throws Exception {
        int hiddenLayerSize = 5;
        final String LSTM_LAYER = "lstm";

        // Two sequence inputs ("vectors" and "ontology", 5 features each) are merged
        // into a single 10-feature sequence, fed through a bidirectional LSTM, and
        // split into two softmax output heads ("intentOut" and "neOut").
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .seed(1000)
                .gradientNormalization(GradientNormalization.RenormalizeL2PerParamType)
                .l2(1e-5)
                .dropOut(0.8) // retain probability: 80% of activations are kept
                .updater(new Adam(3)) // Adam with learning rate 3
                .weightInit(WeightInit.XAVIER)
                .graphBuilder()
                .addInputs("vectors", "ontology")
                .addVertex("merge", new MergeVertex(), "vectors", "ontology")
                // .addLayer("test", new DropoutLayer.Builder(1).build(), "merge")
                .addLayer(LSTM_LAYER,
                        new Bidirectional(Bidirectional.Mode.CONCAT, new LSTM.Builder()
                                .nIn(10).nOut(hiddenLayerSize)
                                .activation(Activation.TANH)
                                .dropOut(new GaussianNoise(0.05))
                                .build()),
                        "merge")
                .addLayer("intentOut",
                        new RnnOutputLayer.Builder().activation(Activation.SOFTMAX)
                                .lossFunction(LossFunctions.LossFunction.MCXENT)
                                .nIn(hiddenLayerSize * 2) // CONCAT mode doubles the LSTM output size
                                .nOut(6).build(),
                        LSTM_LAYER)
                .addLayer("neOut",
                        new RnnOutputLayer.Builder().activation(Activation.SOFTMAX)
                                .lossFunction(LossFunctions.LossFunction.MCXENT)
                                .nIn(hiddenLayerSize * 2)
                                .nOut(4).build(),
                        LSTM_LAYER)
                .setOutputs("intentOut", "neOut")
                .build();

        ComputationGraph net = new ComputationGraph(conf);
        net.init();

        List<TrainingListener> listeners = new ArrayList<>();
        listeners.add(new ScoreIterationListener(1));
        net.setListeners(listeners.toArray(new TrainingListener[0]));

        // Dummy data with shape [miniBatchSize, featureSize, timeSeriesLength]
        INDArray[] features = new INDArray[2];
        features[0] = Nd4j.create(1, 5, 5);
        features[1] = Nd4j.create(1, 5, 5);
        INDArray[] labels = new INDArray[2];
        labels[0] = Nd4j.create(1, 6, 5); // 6 intent classes
        labels[1] = Nd4j.create(1, 4, 5); // 4 NE classes
        MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(features, labels);
        net.fit(mds);
    }
}
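
For inference after training, ComputationGraph.output(...) runs a forward pass and returns one INDArray per output layer, in the order given to setOutputs(...). A minimal sketch reusing the dummy feature arrays above (the variable names intentScores and neScores are illustrative, not part of the gist):

// Forward pass through both output heads.
INDArray[] out = net.output(features[0], features[1]);
INDArray intentScores = out[0]; // shape [1, 6, 5]: per-time-step softmax over the 6 intent classes
INDArray neScores = out[1];     // shape [1, 4, 5]: per-time-step softmax over the 4 NE classes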