Last active
October 19, 2018 12:11
-
-
Save gfrison/cdb4d831488eedc551f9387639bb7b2b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.apache.commons.lang3.builder.ToStringBuilder; | |
import org.apache.commons.lang3.builder.ToStringStyle; | |
import org.deeplearning4j.nn.api.OptimizationAlgorithm; | |
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | |
import org.deeplearning4j.nn.conf.layers.DenseLayer; | |
import org.deeplearning4j.nn.gradient.Gradient; | |
import org.deeplearning4j.nn.graph.ComputationGraph; | |
import org.deeplearning4j.nn.weights.WeightInit; | |
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; | |
import org.deeplearning4j.optimize.listeners.ScoreIterationListener; | |
import org.nd4j.linalg.activations.Activation; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import org.nd4j.linalg.api.rng.distribution.impl.NormalDistribution; | |
import org.nd4j.linalg.factory.Nd4j; | |
import org.nd4j.linalg.learning.config.Adam; | |
import org.slf4j.Logger; | |
import java.util.Arrays; | |
import java.util.Random; | |
import java.util.function.Function; | |
import java.util.function.Supplier; | |
import asgard.math.ClippedNormal; | |
import static asgard.Utils.f; | |
import static java.lang.Math.E; | |
import static java.lang.Math.PI; | |
import static java.lang.Math.log; | |
import static java.lang.Math.min; | |
import static java.lang.Math.pow; | |
import static java.util.stream.IntStream.range; | |
import static org.slf4j.LoggerFactory.getLogger; | |
public class YA2C { | |
private static final Logger log = getLogger(YA2C.class); | |
static ComputationGraph net = new ComputationGraph(new NeuralNetConfiguration.Builder() | |
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | |
.weightInit(WeightInit.XAVIER) | |
.updater(new Adam(.005)) | |
// .updater(new Nesterovs(.005, .95)) | |
.graphBuilder() | |
.addInputs("input") | |
.addLayer("0", new DenseLayer.Builder() | |
.activation(Activation.RELU) | |
.nIn(1).nOut(16).build(), "input") | |
.addLayer("1", new DenseLayer.Builder() | |
.activation(Activation.RELU) | |
.nIn(16).nOut(16).build(), "0") | |
.addLayer("mu", new DenseLayer.Builder() | |
.activation(Activation.IDENTITY) | |
.nIn(16).nOut(1).build(), "1") | |
.addLayer("sigma", new DenseLayer.Builder() | |
.activation(Activation.SOFTPLUS) | |
.nIn(16).nOut(1).build(), "1") | |
.setOutputs("mu", "sigma") | |
// .setOutputs("mu") | |
.backprop(true).pretrain(false) | |
.build()); | |
static Random rnd = new Random(); | |
//X random generator min 0 max 5 | |
static Supplier<Double> xgen = () -> rnd.nextDouble() * 5; | |
//Y function. sinusoid min 0 max 4 | |
static Function<Double, Double> yGen = x -> 2 + 2 * Math.sin(2 * x); | |
static class Val { | |
private final double mu; | |
private final double sigma; | |
private final double output; | |
private final double reward; | |
Val(double mu, double sigma, double output, double reward) { | |
this.mu = mu; | |
this.sigma = sigma; | |
this.output = output; | |
this.reward = reward; | |
} | |
public double getMu() { | |
return mu; | |
} | |
public double getSigma() { | |
return sigma; | |
} | |
public double getOutput() { | |
return output; | |
} | |
public double getReward() { | |
return reward; | |
} | |
@Override | |
public String toString() { | |
return new ToStringBuilder(this, ToStringStyle.NO_CLASS_NAME_STYLE) | |
.append("mu", f(mu)) | |
.append("sigma", f(sigma)) | |
.append("output", f(output)) | |
.append("reward", f(reward)) | |
.toString(); | |
} | |
} | |
public static void main(String[] args) { | |
net.setListeners(new ScoreIterationListener(10)); | |
net.init(); | |
int minibatch = 100; | |
range(0, 1000).forEach(epoch -> { | |
double[] inputs = range(0, minibatch).mapToDouble(i -> xgen.get()).toArray(); | |
INDArray input = Nd4j.create(new int[]{minibatch, 1}).putColumn(1, Nd4j.create(inputs)); | |
INDArray[] outs = net.output(true, false, input); | |
//calculate rewards | |
Val[] rewards = range(0, minibatch) | |
.mapToObj(i -> { | |
var mu = outs[0].getDouble(i, 0); | |
var sigma = outs[1].getDouble(i, 0) + 1e-5; | |
//create clipped normal dist bounded to 0-4 | |
var output = new ClippedNormal(mu, sigma, 0, 4).sample(); | |
//get the real value for the given input | |
var real = yGen.apply(inputs[i]); | |
//rewarding close-to-real outputs | |
final double reward = min(log(1 / pow(real - output, 2)), 10); | |
return new Val(mu, sigma, output, reward); | |
}).toArray(Val[]::new); | |
log.info("avg reward: %.2f\n", Arrays.stream(rewards).mapToDouble(Val::getReward).average().getAsDouble()); | |
INDArray externalError = Nd4j.create(new int[]{minibatch, 1}) | |
.putColumn(0, Nd4j.create(Arrays.stream(rewards) | |
.mapToDouble(val -> { | |
var norm = new NormalDistribution(val.getMu(), val.getSigma()); | |
final double prob = norm.density(val.getOutput()); | |
var lg = -Math.log(prob); | |
//entropy for normal distribution | |
var entropy = Math.log(2 * PI * E * pow(val.getSigma(), 2)); | |
// var entropy = -0.5 * (Math.log(2. * PI * pow(val.getSigma(), 2)) + 1); | |
final double error = -.01 * (lg * val.getReward() - .001 * entropy); | |
log.info("error:{}, entropy:{}, prob:{}, lg:{}, {}", f(error), f(entropy), f(prob), f(lg), val); | |
return error; | |
}).toArray())); | |
//Calculate gradient with respect to an external reward | |
Gradient gradient = net.backpropGradient(externalError, externalError); //Calculate backprop gradient based on reward array | |
//Update the gradient: apply learning rate, momentum, etc | |
//This modifies the Gradient object in-place | |
range(0, 10).forEach(iteration -> { | |
net.getUpdater().update(gradient, iteration, epoch, minibatch, LayerWorkspaceMgr.noWorkspaces()); | |
//Get a row vector gradient array, and apply it to the parameters to update the model | |
INDArray updateVector = gradient.gradient(); | |
net.params().subi(updateVector); | |
}); | |
}); | |
for (int i = 0; i < 3; i++) { | |
// final INDArray input = Nd4j.create(new double[]{1, 2}, new int[]{1, 2}); | |
final Double x = xgen.get(); | |
var y = yGen.apply(x); | |
final INDArray input = Nd4j.create(new double[]{x}); | |
INDArray[] out = net.output(false, input); | |
log.info("x:{}, y:{}, predicted mu:{}, sigma:{}", x, y, out[0].getDouble(0), out[1].getDouble(0)); | |
// System.out.println("input:" + input + ", output:" + Arrays.toString(out) + ", correct:" + range(0, nIn).mapToDouble(y -> input.getDouble(0, y)).sum()); | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment