Created November 13, 2019 03:38
-
-
Save AlexDBlack/68403bf29ea4c4b1e91651af9ef68f6f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.deeplearning4j; | |
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; | |
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | |
import org.deeplearning4j.nn.conf.graph.ElementWiseVertex; | |
import org.deeplearning4j.nn.conf.graph.MergeVertex; | |
import org.deeplearning4j.nn.conf.graph.ReshapeVertex; | |
import org.deeplearning4j.nn.conf.layers.DenseLayer; | |
import org.deeplearning4j.nn.conf.layers.EmbeddingLayer; | |
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaLayer; | |
import org.deeplearning4j.nn.graph.ComputationGraph; | |
import org.junit.Test; | |
import org.nd4j.autodiff.samediff.SDVariable; | |
import org.nd4j.autodiff.samediff.SameDiff; | |
import org.nd4j.linalg.api.buffer.DataType; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import org.nd4j.linalg.factory.Nd4j; | |
import org.nd4j.linalg.learning.config.Sgd; | |
public class Debug8382 { | |
@Test | |
public void test() { | |
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() | |
.updater(new Sgd(0.01)) | |
.graphBuilder() | |
.addInputs("input1", "input2", "input3") //can use any label for this | |
.addLayer("e_1", new EmbeddingLayer.Builder().nIn(10).nOut(5).build(), "input1") | |
.addLayer("e_2", new EmbeddingLayer.Builder().nIn(10).nOut(5).build(), "input2") | |
.addLayer("d_1", new DenseLayer.Builder().nIn(1).nOut(5).build(), "input3") | |
.addVertex("e_1_reshape", new ReshapeVertex(-1, 1, 5), "e_1") | |
.addVertex("e_2_reshape", new ReshapeVertex(-1, 1, 5), "e_2") | |
.addVertex("d_1_reshape", new ReshapeVertex(-1, 1, 5), "d_1") | |
.addVertex("stacking_1", new MergeVertex(), "e_1_reshape", "e_2_reshape", "d_1_reshape") | |
.addLayer("a_plus_b", new tensors_sum(), "stacking_1") | |
.addLayer("a_plus_b_square", new tensors_square(), "a_plus_b") | |
.addLayer("a_square_b_square", new tensors_square(), "stacking_1") | |
.addLayer("a_sq_plus_b_sq", new tensors_sum(), "a_square_b_square") | |
.addVertex("2ab", new ElementWiseVertex(ElementWiseVertex.Op.Subtract), "a_plus_b_square", "a_sq_plus_b_sq") | |
.addLayer("ab", new tensors_by_2(), "2ab") | |
.addLayer("ab_sum", new tensors_sum(), "ab") | |
.addVertex("2d_out", new ReshapeVertex(-1, 1), "ab_sum") | |
.setOutputs("2d_out") | |
.build(); | |
ComputationGraph cg = new ComputationGraph(conf); | |
cg.init(); | |
INDArray x = Nd4j.ones(DataType.FLOAT, 1, 1); | |
INDArray out = cg.outputSingle(x, x, x); | |
System.out.println(out); | |
System.out.println(out.shapeInfoToString()); | |
} | |
public static class tensors_sum extends SameDiffLambdaLayer { | |
@Override | |
public SDVariable defineLayer(SameDiff sd, SDVariable x) { | |
return sd.sum(x, false, 1); | |
} | |
} | |
public static class tensors_square extends SameDiffLambdaLayer { | |
@Override | |
public SDVariable defineLayer(SameDiff sd, SDVariable x) { | |
return sd.math.square(x); | |
} | |
} | |
public static class tensors_by_2 extends SameDiffLambdaLayer { | |
@Override | |
public SDVariable defineLayer(SameDiff sd, SDVariable x) { | |
return x.mul(0.5); | |
} | |
} | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.