import org.apache.commons.lang3.builder.ToStringBuilder;
import org.apache.commons.lang3.builder.ToStringStyle;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.impl.NormalDistribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.slf4j.Logger;
import java.util.Arrays;
import java.util.Random;
import java.util.function.Function;
import java.util.function.Supplier;
import asgard.math.ClippedNormal;
import static asgard.Utils.f;
import static java.lang.Math.E;
import static java.lang.Math.PI;
import static java.lang.Math.log;
import static java.lang.Math.min;
import static java.lang.Math.pow;
import static java.util.stream.IntStream.range;
import static org.slf4j.LoggerFactory.getLogger;
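
/**
 * Toy continuous-action policy gradient ("yet another A2C") in DL4J.
 *
 * A two-headed network maps a scalar x to the mean ("mu") and standard
 * deviation ("sigma") of a Gaussian policy. An action is sampled from that
 * distribution, rewarded by how close it lands to y(x) = 2 + 2*sin(2x), and
 * the network is trained REINFORCE-style by backpropagating an external
 * error built from the action's log-likelihood, its reward, and an entropy
 * bonus.
 */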
public class YA2C {
    private static final Logger log = getLogger(YA2C.class);

    static ComputationGraph net = new ComputationGraph(new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .weightInit(WeightInit.XAVIER)
            .updater(new Adam(.005))
            // .updater(new Nesterovs(.005, .95))
            .graphBuilder()
            .addInputs("input")
            .addLayer("0", new DenseLayer.Builder()
                    .activation(Activation.RELU)
                    .nIn(1).nOut(16).build(), "input")
            .addLayer("1", new DenseLayer.Builder()
                    .activation(Activation.RELU)
                    .nIn(16).nOut(16).build(), "0")
            // policy head: unbounded mean of the Gaussian policy
            .addLayer("mu", new DenseLayer.Builder()
                    .activation(Activation.IDENTITY)
                    .nIn(16).nOut(1).build(), "1")
            // policy head: softplus keeps the standard deviation positive
            .addLayer("sigma", new DenseLayer.Builder()
                    .activation(Activation.SOFTPLUS)
                    .nIn(16).nOut(1).build(), "1")
            .setOutputs("mu", "sigma")
            // .setOutputs("mu")
            .backprop(true).pretrain(false)
            .build());
    static Random rnd = new Random();
    // x generator: uniform random values in [0, 5)
    static Supplier<Double> xgen = () -> rnd.nextDouble() * 5;
    // y function: sinusoid bounded to [0, 4]
    static Function<Double, Double> yGen = x -> 2 + 2 * Math.sin(2 * x);
    static class Val {
        private final double mu;
        private final double sigma;
        private final double output;
        private final double reward;

        Val(double mu, double sigma, double output, double reward) {
            this.mu = mu;
            this.sigma = sigma;
            this.output = output;
            this.reward = reward;
        }

        public double getMu() {
            return mu;
        }

        public double getSigma() {
            return sigma;
        }

        public double getOutput() {
            return output;
        }

        public double getReward() {
            return reward;
        }

        @Override
        public String toString() {
            return new ToStringBuilder(this, ToStringStyle.NO_CLASS_NAME_STYLE)
                    .append("mu", f(mu))
                    .append("sigma", f(sigma))
                    .append("output", f(output))
                    .append("reward", f(reward))
                    .toString();
        }
    }
    public static void main(String[] args) {
        net.setListeners(new ScoreIterationListener(10));
        net.init();
        int minibatch = 100;
        range(0, 1000).forEach(epoch -> {
            double[] inputs = range(0, minibatch).mapToDouble(i -> xgen.get()).toArray();
            // column vector of inputs, shape [minibatch, 1] (column index 0 is the only column)
            INDArray input = Nd4j.create(new int[]{minibatch, 1}).putColumn(0, Nd4j.create(inputs));
            INDArray[] outs = net.output(true, false, input);
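
            // Reward shaping, as a worked equation: with true value y and sampled
            // action a, reward = min(ln(1 / (y - a)^2), 10) = min(-2 * ln|y - a|, 10),
            // so the reward grows without bound as a -> y and is capped at 10;
            // e.g. |y - a| = 0.1 gives about 4.6, |y - a| = 1 gives 0.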
            // calculate rewards for each sampled action
            Val[] rewards = range(0, minibatch)
                    .mapToObj(i -> {
                        var mu = outs[0].getDouble(i, 0);
                        var sigma = outs[1].getDouble(i, 0) + 1e-5; // floor to avoid a degenerate distribution
                        // sample an action from a clipped normal distribution bounded to [0, 4]
                        var output = new ClippedNormal(mu, sigma, 0, 4).sample();
                        // get the real value for the given input
                        var real = yGen.apply(inputs[i]);
                        // reward close-to-real outputs
                        final double reward = min(log(1 / pow(real - output, 2)), 10);
                        return new Val(mu, sigma, output, reward);
                    }).toArray(Val[]::new);
log.info("avg reward: %.2f\n", Arrays.stream(rewards).mapToDouble(Val::getReward).average().getAsDouble());
            INDArray externalError = Nd4j.create(new int[]{minibatch, 1})
                    .putColumn(0, Nd4j.create(Arrays.stream(rewards)
                            .mapToDouble(val -> {
                                var norm = new NormalDistribution(val.getMu(), val.getSigma());
                                final double prob = norm.density(val.getOutput());
                                var lg = -Math.log(prob); // negative log-likelihood of the sampled action
                                // entropy of a normal distribution: H = 0.5 * ln(2 * pi * e * sigma^2)
                                var entropy = 0.5 * Math.log(2 * PI * E * pow(val.getSigma(), 2));
                                // equivalently: var entropy = 0.5 * (Math.log(2. * PI * pow(val.getSigma(), 2)) + 1);
                                final double error = -.01 * (lg * val.getReward() - .001 * entropy);
                                log.info("error:{}, entropy:{}, prob:{}, lg:{}, {}", f(error), f(entropy), f(prob), f(lg), val);
                                return error;
                            }).toArray()));
            // Calculate the backprop gradient with respect to the external error;
            // backpropGradient takes one epsilon array per output, so the same
            // error signal is applied to both the "mu" and "sigma" heads.
            Gradient gradient = net.backpropGradient(externalError, externalError);
            range(0, 10).forEach(iteration -> {
                // Update the gradient: apply learning rate, momentum, etc.
                // This modifies the Gradient object in-place.
                net.getUpdater().update(gradient, iteration, epoch, minibatch, LayerWorkspaceMgr.noWorkspaces());
                // Get a row vector gradient array and apply it to the parameters to update the model
                INDArray updateVector = gradient.gradient();
                net.params().subi(updateVector);
            });
        });
for (int i = 0; i < 3; i++) {
// final INDArray input = Nd4j.create(new double[]{1, 2}, new int[]{1, 2});
final Double x = xgen.get();
var y = yGen.apply(x);
final INDArray input = Nd4j.create(new double[]{x});
INDArray[] out = net.output(false, input);
log.info("x:{}, y:{}, predicted mu:{}, sigma:{}", x, y, out[0].getDouble(0), out[1].getDouble(0));
// System.out.println("input:" + input + ", output:" + Arrays.toString(out) + ", correct:" + range(0, nIn).mapToDouble(y -> input.getDouble(0, y)).sum());
}
}
}
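
// ---------------------------------------------------------------------------
// The two asgard helpers used above (asgard.math.ClippedNormal and
// asgard.Utils.f) are project-local and not included in this gist. Below is a
// minimal, hypothetical sketch of both, inferred only from how they are called
// here: ClippedNormal(mu, sigma, lower, upper).sample() draws a Gaussian
// sample clamped to [lower, upper] (the real class could instead use a
// truncated normal), and f(double) formats a double for logging. If you paste
// these into this file, drop the corresponding asgard imports first.
// ---------------------------------------------------------------------------
class ClippedNormal {
    private static final java.util.Random RND = new java.util.Random();
    private final double mu, sigma, lower, upper;

    ClippedNormal(double mu, double sigma, double lower, double upper) {
        this.mu = mu;
        this.sigma = sigma;
        this.lower = lower;
        this.upper = upper;
    }

    /** Draw one Gaussian sample and clamp it into [lower, upper]. */
    double sample() {
        double s = mu + sigma * RND.nextGaussian();
        return Math.max(lower, Math.min(upper, s));
    }
}

class Utils {
    /** Format a double compactly for logging, e.g. f(0.12345) -> "0.123". */
    static String f(double d) {
        return String.format("%.3f", d);
    }
}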