Skip to content

Instantly share code, notes, and snippets.

@SchmaR
Last active January 14, 2019 18:05
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save SchmaR/9d691a518a182543a5ba970c8d24dbac to your computer and use it in GitHub Desktop.
Save SchmaR/9d691a518a182543a5ba970c8d24dbac to your computer and use it in GitHub Desktop.
/**
 * Reproduction case: SameDiff euclidean distance where one operand is first added to itself.
 *
 * <p>A 4x4 array is filled row by row, then split column-wise: the left half becomes the
 * {@code "x"} variable and the right half the {@code "y"} variable. {@code y} is added to itself
 * before {@code euclideanDistance} is taken along dimension 0 under the name {@code "euclidean"}.
 * The graph is executed forward for that output and then backward.
 *
 * <p>There are no assertions; the test only passes if execution does not throw. The method name
 * suggests this variant is expected to fail, in contrast to {@code sameDiffExperimentDistances},
 * which omits the add — TODO confirm against the upstream issue this gist was written for.
 */
@Test
public void sameDiffExperimentDistancesFail() {
    final float[][] rowValues = {
        {0, 2, -2, 0},
        {0, 1, -1, 0},
        {0, -1, 1, 0},
        {0, -2, 2, 0}
    };
    INDArray embeddings = Nd4j.create(4, 4);
    for (int row = 0; row < rowValues.length; row++) {
        embeddings.putRow(row, Nd4j.create(rowValues[row]));
    }

    final long width = embeddings.size(1);
    final long half = width / 2L;
    // Columns [0, half) form the left operand, [half, width) the right operand.
    INDArray left = embeddings.get(NDArrayIndex.all(), NDArrayIndex.interval(0, half));
    INDArray right = embeddings.get(NDArrayIndex.all(), NDArrayIndex.interval(half, width));
    log.info(right.shapeInfoToString());

    SameDiff graph = SameDiff.create();
    graph.enableDebugMode();
    SDVariable leftVar = graph.var("x", left);
    SDVariable rightVar = graph.var("y", right);
    // The extra add is the only difference from sameDiffExperimentDistances.
    SDVariable rightDoubled = rightVar.add(rightVar);
    graph.euclideanDistance("euclidean", rightDoubled, leftVar, 0);
    StructurePreservingEmbeddingLossTest.log.info(graph.summary());

    graph.exec(Collections.emptyMap(), Lists.newArrayList("euclidean"));
    graph.execBackwards(Collections.emptyMap());
}
/**
 * Baseline case: SameDiff euclidean distance between the two column halves of a 4x4 array.
 *
 * <p>The array is filled row by row, then split column-wise: the left half becomes the
 * {@code "x"} variable and the right half the {@code "y"} variable. {@code euclideanDistance} is
 * taken along dimension 0 under the name {@code "euclidean"}. The graph is executed forward for
 * that output and then backward.
 *
 * <p>There are no assertions; the test only passes if execution does not throw. Unlike
 * {@code sameDiffExperimentDistancesFail}, {@code y} is used directly without first being added
 * to itself.
 */
@Test
public void sameDiffExperimentDistances() {
    final float[][] rowValues = {
        {0, 2, -2, 0},
        {0, 1, -1, 0},
        {0, -1, 1, 0},
        {0, -2, 2, 0}
    };
    INDArray embeddings = Nd4j.create(4, 4);
    for (int row = 0; row < rowValues.length; row++) {
        embeddings.putRow(row, Nd4j.create(rowValues[row]));
    }

    final long width = embeddings.size(1);
    final long half = width / 2L;
    // Columns [0, half) form the left operand, [half, width) the right operand.
    INDArray left = embeddings.get(NDArrayIndex.all(), NDArrayIndex.interval(0, half));
    INDArray right = embeddings.get(NDArrayIndex.all(), NDArrayIndex.interval(half, width));
    log.info(right.shapeInfoToString());

    SameDiff graph = SameDiff.create();
    graph.enableDebugMode();
    SDVariable leftVar = graph.var("x", left);
    SDVariable rightVar = graph.var("y", right);
    // No self-add here; that extra op is what the "...Fail" variant exercises.
    graph.euclideanDistance("euclidean", rightVar, leftVar, 0);
    StructurePreservingEmbeddingLossTest.log.info(graph.summary());

    graph.exec(Collections.emptyMap(), Lists.newArrayList("euclidean"));
    graph.execBackwards(Collections.emptyMap());
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment