Skip to content

Instantly share code, notes, and snippets.

@AlexDBlack
Created November 13, 2019 03:38
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save AlexDBlack/68403bf29ea4c4b1e91651af9ef68f6f to your computer and use it in GitHub Desktop.
package org.deeplearning4j;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.ElementWiseVertex;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.graph.ReshapeVertex;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.EmbeddingLayer;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.junit.Test;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Sgd;
/**
 * Reproduction case (DL4J issue 8382): builds a small {@link ComputationGraph}
 * whose lambda layers compute an element-wise product via the algebraic identity
 * {@code ab = ((a + b)^2 - (a^2 + b^2)) / 2}, then runs one forward pass on
 * all-ones inputs and prints the output and its shape info.
 */
public class Debug8382 {

    @Test
    public void test() {
        // Three scalar-ish inputs are embedded/densified to width 5, reshaped to
        // rank 3 so MergeVertex stacks them along dimension 1, then combined via
        // the SameDiff lambda layers defined below.
        ComputationGraphConfiguration graphConfig = new NeuralNetConfiguration.Builder()
                .updater(new Sgd(0.01))
                .graphBuilder()
                .addInputs("input1", "input2", "input3") //can use any label for this
                .addLayer("e_1", new EmbeddingLayer.Builder().nIn(10).nOut(5).build(), "input1")
                .addLayer("e_2", new EmbeddingLayer.Builder().nIn(10).nOut(5).build(), "input2")
                .addLayer("d_1", new DenseLayer.Builder().nIn(1).nOut(5).build(), "input3")
                // Reshape each [minibatch, 5] activation to [minibatch, 1, 5] so the
                // merge below concatenates along dim 1 rather than dim 0.
                .addVertex("e_1_reshape", new ReshapeVertex(-1, 1, 5), "e_1")
                .addVertex("e_2_reshape", new ReshapeVertex(-1, 1, 5), "e_2")
                .addVertex("d_1_reshape", new ReshapeVertex(-1, 1, 5), "d_1")
                .addVertex("stacking_1", new MergeVertex(), "e_1_reshape", "e_2_reshape", "d_1_reshape")
                // (a + b)^2
                .addLayer("a_plus_b", new tensors_sum(), "stacking_1")
                .addLayer("a_plus_b_square", new tensors_square(), "a_plus_b")
                // a^2 + b^2
                .addLayer("a_square_b_square", new tensors_square(), "stacking_1")
                .addLayer("a_sq_plus_b_sq", new tensors_sum(), "a_square_b_square")
                // (a + b)^2 - (a^2 + b^2) == 2ab, then halve and reduce to 2d output
                .addVertex("2ab", new ElementWiseVertex(ElementWiseVertex.Op.Subtract), "a_plus_b_square", "a_sq_plus_b_sq")
                .addLayer("ab", new tensors_by_2(), "2ab")
                .addLayer("ab_sum", new tensors_sum(), "ab")
                .addVertex("2d_out", new ReshapeVertex(-1, 1), "ab_sum")
                .setOutputs("2d_out")
                .build();

        ComputationGraph graph = new ComputationGraph(graphConfig);
        graph.init();

        // Single forward pass: the same all-ones [1, 1] array feeds every input.
        INDArray ones = Nd4j.ones(DataType.FLOAT, 1, 1);
        INDArray result = graph.outputSingle(ones, ones, ones);
        System.out.println(result);
        System.out.println(result.shapeInfoToString());
    }

    /** Lambda layer: sum over dimension 1, dropping that dimension (keepDims=false). */
    public static class tensors_sum extends SameDiffLambdaLayer {
        @Override
        public SDVariable defineLayer(SameDiff sd, SDVariable input) {
            return sd.sum(input, false, 1);
        }
    }

    /** Lambda layer: element-wise square of the input. */
    public static class tensors_square extends SameDiffLambdaLayer {
        @Override
        public SDVariable defineLayer(SameDiff sd, SDVariable input) {
            return sd.math.square(input);
        }
    }

    /** Lambda layer: halve the input element-wise (multiply by 0.5). */
    public static class tensors_by_2 extends SameDiffLambdaLayer {
        @Override
        public SDVariable defineLayer(SameDiff sd, SDVariable input) {
            return input.mul(0.5);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment