-
-
Save SchmaR/c8577e9e546e63b5263f43a27ba3f7a1 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package de.datexis.loss; | |
import org.jetbrains.annotations.NotNull; | |
import org.nd4j.autodiff.samediff.SDVariable; | |
import org.nd4j.autodiff.samediff.SameDiff; | |
import org.nd4j.autodiff.samediff.VariableType; | |
import org.nd4j.linalg.activations.IActivation; | |
import org.nd4j.linalg.api.buffer.DataType; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import org.nd4j.linalg.factory.Nd4j; | |
import org.nd4j.linalg.indexing.NDArrayIndex; | |
import org.nd4j.linalg.lossfunctions.ILossFunction; | |
import org.nd4j.linalg.primitives.Pair; | |
import org.slf4j.Logger; | |
import org.slf4j.LoggerFactory; | |
import java.util.Arrays; | |
import java.util.HashMap; | |
import java.util.Map; | |
public class TestLoss implements ILossFunction { | |
protected static final Logger log = LoggerFactory.getLogger(TestLoss.class); | |
private boolean isInitialized = false; | |
private SameDiff graph; | |
public TestLoss() { | |
} | |
@Override | |
public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) { | |
INDArray scoreArr = scoreArray(labels, preOutput, activationFn, mask); | |
double score = scoreArr.sumNumber().doubleValue(); | |
if (average) { | |
score /= scoreArr.size(0); | |
} | |
return score; | |
} | |
private void applyMask(INDArray mask, INDArray scoreArr) { | |
if (mask != null) { | |
scoreArr.muliColumnVector(mask); | |
} | |
} | |
@Override | |
public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) { | |
INDArray scoreArray = scoreArray(labels, preOutput, activationFn, mask); | |
return scoreArray.sum(1); | |
} | |
public INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) { | |
INDArray scoreArr = activationFn.getActivation(preOutput, true); | |
long halfSize = scoreArr.size(1) / 2L; | |
// Split vectors | |
INDArray x = scoreArr.get(NDArrayIndex.all(), NDArrayIndex.interval(0, halfSize)); | |
INDArray y = scoreArr.get(NDArrayIndex.all(), NDArrayIndex.interval(halfSize, scoreArr.size(1))); | |
log.info("init graph forward pass"); | |
Map<String, INDArray> variables = initGraph(x, y); | |
log.info("execute graph forward pass"); | |
Map<String, INDArray> scores = graph.execAll(variables); | |
scoreArr = scores.get("scores"); | |
log.info("finished executing graph forward pass"); | |
//multiply with masks, always | |
applyMask(mask, scoreArr); | |
return scoreArr.broadcast(preOutput.shape()[0], 2); | |
} | |
@NotNull | |
private Map<String, INDArray> initGraph(INDArray x, INDArray y) { | |
if (!isInitialized) { | |
log.info("building graph"); | |
graph = buildSameDiffGraph(x.shape(), y.shape()); | |
isInitialized = true; | |
} | |
Map<String, INDArray> variables = new HashMap<>(8); | |
variables.put("x", x); | |
variables.put("y", y); | |
return variables; | |
} | |
@NotNull | |
private SameDiff buildSameDiffGraph(long[] xInShape, long[] yInShape) { | |
SameDiff graph = SameDiff.create(); | |
SDVariable x = graph.placeHolder("x", DataType.FLOAT, xInShape); | |
SDVariable y = graph.placeHolder("y", DataType.FLOAT, yInShape); | |
x.add("scores", y); | |
return graph; | |
} | |
@Override | |
public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) { | |
INDArray output = activationFn.getActivation(preOutput.dup(), true); | |
long halfSize = output.size(1) / 2L; | |
// Split vectors | |
INDArray x = output.get(NDArrayIndex.all(), NDArrayIndex.interval(0, halfSize)); | |
INDArray y = output.get(NDArrayIndex.all(), NDArrayIndex.interval(halfSize, output.size(1))); | |
log.info("init graph backward pass"); | |
Map<String, INDArray> variables = initGraph(x, y); | |
log.info("execute graph backward pass"); | |
graph.createGradFunction(); | |
graph.execBackwards(variables, | |
Arrays.asList(graph.getVariable("x").getGradient().getVarName(), | |
graph.getVariable("y").getGradient().getVarName())); | |
SameDiff gradFn = graph.getFunction("grad"); | |
INDArray dlDx = gradFn.getGradForVariable("x").getArr(); | |
INDArray dlDy = gradFn.getGradForVariable("y").getArr(); | |
//Everything below remains the same | |
INDArray derivative = Nd4j.concat(1, dlDx, dlDy); | |
output = activationFn.backprop(preOutput.dup(), derivative).getFirst(); | |
//multiply with masks, always | |
applyMask(mask, output); | |
return output; | |
} | |
@Override | |
public Pair<Double, INDArray> computeGradientAndScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) { | |
return new Pair<>( | |
computeScore(labels, preOutput, activationFn, mask, average), | |
computeGradient(labels, preOutput, activationFn, mask)); | |
} | |
@Override | |
public String name() { | |
return this.getClass().getSimpleName(); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment