-
-
Save orausch/f78035f42940e5f614c04f32cd53a271 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.HashMap; | |
import java.util.Map; | |
import org.nd4j.autodiff.listeners.At; | |
import org.nd4j.autodiff.listeners.Listener; | |
import org.nd4j.autodiff.listeners.ListenerResponse; | |
import org.nd4j.autodiff.listeners.ListenerVariables; | |
import org.nd4j.autodiff.listeners.Loss; | |
import org.nd4j.autodiff.listeners.Operation; | |
import org.nd4j.autodiff.listeners.records.LossCurve; | |
import org.nd4j.autodiff.samediff.SDVariable; | |
import org.nd4j.autodiff.samediff.SameDiff; | |
import org.nd4j.autodiff.samediff.internal.SameDiffOp; | |
import org.nd4j.autodiff.samediff.internal.Variable; | |
import org.nd4j.linalg.api.buffer.DataType; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import org.nd4j.linalg.dataset.api.MultiDataSet; | |
import org.nd4j.linalg.factory.Nd4j; | |
public class Test { | |
public static void main(String[] args) { | |
SameDiff sd = SameDiff.create(); | |
sd.addListeners(new VarListener()); | |
SDVariable in = sd.placeHolder("input", DataType.FLOAT, 2, 1); | |
SDVariable weights = sd.var("weights", 2, 2); | |
SDVariable loss = weights.mmul(in).sum("loss"); | |
sd.setLossVariables(loss); | |
Map<String, INDArray> placeholders = new HashMap<>(); | |
placeholders.put("input", Nd4j.ones(2, 1)); | |
sd.execBackwards(placeholders); | |
// this hack works | |
System.out.println("From grad: " + sd.getFunction("grad").getArrForVarName("loss")); | |
} | |
} | |
class VarListener implements Listener { | |
@Override | |
public ListenerVariables requiredVariables(SameDiff sd) { | |
return ListenerVariables.builder().trainingVariables("loss").build(); | |
} | |
@Override | |
public boolean isActive(Operation operation) { | |
return true; | |
} | |
@Override | |
public void activationAvailable(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, String varName, | |
INDArray activation) { | |
System.out.println(varName + activation.toString()); | |
} | |
@Override | |
public void epochStart(SameDiff sd, At at) { | |
} | |
@Override | |
public ListenerResponse epochEnd(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis) { | |
return null; | |
} | |
@Override | |
public ListenerResponse validationDone(SameDiff sd, At at, long validationTimeMillis) { | |
return null; | |
} | |
@Override | |
public void iterationStart(SameDiff sd, At at, MultiDataSet data, long etlTimeMs) { | |
} | |
@Override | |
public void iterationDone(SameDiff sd, At at, MultiDataSet dataSet, Loss loss) { | |
} | |
@Override | |
public void operationStart(SameDiff sd, Operation op) { | |
} | |
@Override | |
public void operationEnd(SameDiff sd, Operation op) { | |
} | |
@Override | |
public void preOpExecution(SameDiff sd, At at, SameDiffOp op) { | |
} | |
@Override | |
public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) { | |
} | |
@Override | |
public void preUpdate(SameDiff sd, At at, Variable v, INDArray update) { | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment