Skip to content

Instantly share code, notes, and snippets.

@orausch
Created October 28, 2019 15:28
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save orausch/f78035f42940e5f614c04f32cd53a271 to your computer and use it in GitHub Desktop.
Save orausch/f78035f42940e5f614c04f32cd53a271 to your computer and use it in GitHub Desktop.
import java.util.HashMap;
import java.util.Map;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.listeners.ListenerResponse;
import org.nd4j.autodiff.listeners.ListenerVariables;
import org.nd4j.autodiff.listeners.Loss;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.listeners.records.LossCurve;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
/**
 * Minimal SameDiff reproduction: builds {@code loss = sum(weights x input)}, attaches a
 * {@link VarListener}, runs a backward pass, and prints the "loss" array read back from the
 * graph's "grad" function.
 */
public class Test {

    /**
     * Entry point; command-line arguments are ignored.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        SameDiff graph = SameDiff.create();
        graph.addListeners(new VarListener());

        // Graph: a [2,1] float placeholder multiplied by a [2,2] variable, then summed into "loss".
        SDVariable input = graph.placeHolder("input", DataType.FLOAT, 2, 1);
        SDVariable weights = graph.var("weights", 2, 2);
        SDVariable product = weights.mmul(input);
        SDVariable loss = product.sum("loss");
        graph.setLossVariables(loss);

        // Feed a [2,1] array of ones for the placeholder.
        Map<String, INDArray> feed = new HashMap<>();
        feed.put("input", Nd4j.ones(2, 1));
        graph.execBackwards(feed);

        // NOTE(review): the original comment only said "this hack works". This reads "loss" from the
        // internal "grad" function after the backward pass, presumably as a workaround for the
        // value not being reported elsewhere (e.g. to the listener) — confirm.
        INDArray lossFromGrad = graph.getFunction("grad").getArrForVarName("loss");
        System.out.println("From grad: " + lossFromGrad);
    }
}
/**
 * Debugging listener that prints each activation SameDiff hands to it.
 *
 * <p>It requests the {@code "loss"} variable through {@code ListenerVariables.builder().trainingVariables(...)}
 * and reports itself active for every {@link Operation}. All other callbacks are no-ops. The two
 * callbacks that must return a {@link ListenerResponse} return {@link ListenerResponse#CONTINUE}
 * rather than {@code null}, so the listener never signals the caller to stop.
 */
class VarListener implements Listener {

    /**
     * Declares the variables this listener wants to observe.
     *
     * @param sd the graph the listener is attached to (unused)
     * @return a {@link ListenerVariables} that lists {@code "loss"} as a training variable
     */
    @Override
    public ListenerVariables requiredVariables(SameDiff sd) {
        return ListenerVariables.builder().trainingVariables("loss").build();
    }

    /**
     * Reports this listener as active for every operation.
     *
     * @param operation the operation being run (ignored)
     * @return always {@code true}
     */
    @Override
    public boolean isActive(Operation operation) {
        return true;
    }

    /**
     * Prints the variable name and its activation to standard output.
     *
     * <p>A {@code " = "} separator keeps the name from running into the array's text. String
     * concatenation (instead of an explicit {@code toString()}) prints {@code "null"} rather than
     * throwing if no array is supplied.
     */
    @Override
    public void activationAvailable(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, String varName,
            INDArray activation) {
        System.out.println(varName + " = " + activation);
    }

    /** No-op. */
    @Override
    public void epochStart(SameDiff sd, At at) {
    }

    /**
     * No-op that lets training proceed.
     *
     * @return {@link ListenerResponse#CONTINUE}. An explicit value instead of {@code null} means
     *     callers need not special-case a missing response.
     */
    @Override
    public ListenerResponse epochEnd(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis) {
        return ListenerResponse.CONTINUE;
    }

    /**
     * No-op that lets training proceed.
     *
     * @return {@link ListenerResponse#CONTINUE} (see {@link #epochEnd})
     */
    @Override
    public ListenerResponse validationDone(SameDiff sd, At at, long validationTimeMillis) {
        return ListenerResponse.CONTINUE;
    }

    /** No-op. */
    @Override
    public void iterationStart(SameDiff sd, At at, MultiDataSet data, long etlTimeMs) {
    }

    /** No-op. */
    @Override
    public void iterationDone(SameDiff sd, At at, MultiDataSet dataSet, Loss loss) {
    }

    /** No-op. */
    @Override
    public void operationStart(SameDiff sd, Operation op) {
    }

    /** No-op. */
    @Override
    public void operationEnd(SameDiff sd, Operation op) {
    }

    /** No-op. */
    @Override
    public void preOpExecution(SameDiff sd, At at, SameDiffOp op) {
    }

    /** No-op. */
    @Override
    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
    }

    /** No-op. */
    @Override
    public void preUpdate(SameDiff sd, At at, Variable v, INDArray update) {
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment