package de.datexis.loss;

import org.jetbrains.annotations.NotNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

/**
 * Custom loss function that delegates its score and gradient computation
 * to a lazily built SameDiff graph.
 */
public class TestLoss implements ILossFunction {

    protected static final Logger log = LoggerFactory.getLogger(TestLoss.class);

    private boolean isInitialized = false;
    private SameDiff graph;

    public TestLoss() {
    }

    @Override
    public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) {
        INDArray scoreArr = scoreArray(labels, preOutput, activationFn, mask);
        double score = scoreArr.sumNumber().doubleValue();
        if (average) {
            score /= scoreArr.size(0);
        }
        return score;
    }

    private void applyMask(INDArray mask, INDArray scoreArr) {
        if (mask != null) {
            scoreArr.muliColumnVector(mask);
        }
    }

    @Override
    public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        INDArray scoreArray = scoreArray(labels, preOutput, activationFn, mask);
        return scoreArray.sum(1);
    }

    public INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        // apply the activation function to the raw layer output
        INDArray scoreArr = activationFn.getActivation(preOutput, true);
        long halfSize = scoreArr.size(1) / 2L;
        // split the activations into two equal halves along the feature dimension
        INDArray x = scoreArr.get(NDArrayIndex.all(), NDArrayIndex.interval(0, halfSize));
        INDArray y = scoreArr.get(NDArrayIndex.all(), NDArrayIndex.interval(halfSize, scoreArr.size(1)));
        log.info("init graph forward pass");
        Map<String, INDArray> variables = initGraph(x, y);
        log.info("execute graph forward pass");
        Map<String, INDArray> scores = graph.execAll(variables);
        scoreArr = scores.get("scores");
        log.info("finished executing graph forward pass");
        // multiply with masks, always
        applyMask(mask, scoreArr);
        return scoreArr.broadcast(preOutput.shape()[0], 2);
    }

    @NotNull
    private Map<String, INDArray> initGraph(INDArray x, INDArray y) {
        // build the graph lazily on first use; the placeholder shapes are taken
        // from the first mini-batch that arrives here
        if (!isInitialized) {
            log.info("building graph");
            graph = buildSameDiffGraph(x.shape(), y.shape());
            isInitialized = true;
        }
        Map<String, INDArray> variables = new HashMap<>(8);
        variables.put("x", x);
        variables.put("y", y);
        return variables;
    }

    @NotNull
    private SameDiff buildSameDiffGraph(long[] xInShape, long[] yInShape) {
        SameDiff graph = SameDiff.create();
        SDVariable x = graph.placeHolder("x", DataType.FLOAT, xInShape);
        SDVariable y = graph.placeHolder("y", DataType.FLOAT, yInShape);
        // register the sum of both halves in the graph under the name "scores";
        // execAll(...) looks the result up by this name
        x.add("scores", y);
        return graph;
    }

    @Override
    public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        INDArray output = activationFn.getActivation(preOutput.dup(), true);
        long halfSize = output.size(1) / 2L;
        // split the activations into two equal halves along the feature dimension
        INDArray x = output.get(NDArrayIndex.all(), NDArrayIndex.interval(0, halfSize));
        INDArray y = output.get(NDArrayIndex.all(), NDArrayIndex.interval(halfSize, output.size(1)));
        log.info("init graph backward pass");
        Map<String, INDArray> variables = initGraph(x, y);
        log.info("execute graph backward pass");
        graph.createGradFunction();
        graph.execBackwards(variables,
                Arrays.asList(graph.getVariable("x").getGradient().getVarName(),
                        graph.getVariable("y").getGradient().getVarName()));
        SameDiff gradFn = graph.getFunction("grad");
        INDArray dlDx = gradFn.getGradForVariable("x").getArr();
        INDArray dlDy = gradFn.getGradForVariable("y").getArr();
        // reassemble the gradient in the original column layout and backprop it
        // through the activation function; everything below remains the same
        INDArray derivative = Nd4j.concat(1, dlDx, dlDy);
        output = activationFn.backprop(preOutput.dup(), derivative).getFirst();
        // multiply with masks, always
        applyMask(mask, output);
        return output;
    }

    @Override
    public Pair<Double, INDArray> computeGradientAndScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) {
        return new Pair<>(
                computeScore(labels, preOutput, activationFn, mask, average),
                computeGradient(labels, preOutput, activationFn, mask));
    }

    @Override
    public String name() {
        return this.getClass().getSimpleName();
    }
}
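
For context, a minimal usage sketch of how an ILossFunction like this one can be wired into a DL4J network, assuming the standard builder API. The layer sizes are illustrative assumptions only; IDENTITY is used on the output layer because TestLoss applies the activation function itself inside scoreArray and computeGradient.

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;

public class TestLossExample {
    public static void main(String[] args) {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .list()
                // hidden and output sizes are illustrative assumptions
                .layer(new DenseLayer.Builder().nIn(16).nOut(8).activation(Activation.TANH).build())
                .layer(new OutputLayer.Builder()
                        .lossFunction(new TestLoss())    // plug in the custom loss
                        .activation(Activation.IDENTITY) // TestLoss applies the activation itself
                        .nIn(8).nOut(8)
                        .build())
                .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
    }
}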