@liweigu
Created March 27, 2020 01:42
package xxxxx;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
public class SumByRegression {
    public static void main(String[] args) {
        int batchSize = 100;
        int seed = 12345;
        Random rng = new Random(seed);
        DataSetIterator iterator = getTrainingData(batchSize, rng);

        // Network shape: 2 inputs (x, y), one hidden layer of 10 sigmoid units,
        // and a single linear output that regresses the sum.
        int numInput = 2;
        int numOutputs = 1;
        int nHidden = 10;
        double learningRate = 0.01;
        MultiLayerNetwork net = new MultiLayerNetwork(new NeuralNetConfiguration.Builder()
                .seed(seed)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .weightInit(WeightInit.XAVIER)
                .updater(new Nesterovs(learningRate, 0.9))
                .list()
                .layer(0, new DenseLayer.Builder()
                        .nIn(numInput).nOut(nHidden).activation(Activation.SIGMOID).build())
                .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                        .activation(Activation.IDENTITY) // identity activation for regression
                        .nIn(nHidden)
                        .nOut(numOutputs).build())
                .build());
        net.init();

        // Train for 500 full passes over the generated data set.
        int nEpochs = 500;
        for (int i = 0; i < nEpochs; i++) {
            iterator.reset();
            while (iterator.hasNext()) {
                DataSet ds = iterator.next();
                net.fit(ds);
            }
        }

        // Predict the sum for one unseen pair; the true value is 0.211111 + 0.3 = 0.511111,
        // so a well-trained network should print something close to that.
        double x = 0.211111;
        double y = 0.3;
        INDArray input = Nd4j.create(new double[] { x, y }, new int[] { 1, 2 });
        INDArray out = net.output(input, false);
        System.out.println("x + y = " + out);
    }
    // Generate training data: features x and y drawn uniformly from [0, 3), label z = x + y.
    @SuppressWarnings({ "rawtypes", "unchecked" })
    private static DataSetIterator getTrainingData(int batchSize, Random rand) {
        int nSamples = 1000;
        int MIN_RANGE = 0;
        int MAX_RANGE = 3;
        double[] sum = new double[nSamples];
        double[] input1 = new double[nSamples];
        double[] input2 = new double[nSamples];
        for (int i = 0; i < nSamples; i++) {
            // x
            input1[i] = MIN_RANGE + (MAX_RANGE - MIN_RANGE) * rand.nextDouble();
            // y
            input2[i] = MIN_RANGE + (MAX_RANGE - MIN_RANGE) * rand.nextDouble();
            // label
            sum[i] = input1[i] + input2[i];
        }
        // Stack the two feature columns side by side into an nSamples x 2 feature matrix.
        INDArray inputNDArray1 = Nd4j.create(input1, new int[] { nSamples, 1 });
        INDArray inputNDArray2 = Nd4j.create(input2, new int[] { nSamples, 1 });
        INDArray inputNDArray = Nd4j.hstack(inputNDArray1, inputNDArray2);
        INDArray outPut = Nd4j.create(sum, new int[] { nSamples, 1 });
        DataSet dataSet = new DataSet(inputNDArray, outPut);
        List<DataSet> listDs = dataSet.asList();
        // Shuffle with the seeded RNG so runs stay reproducible
        // (the original shuffled with an unseeded global Random despite fixing the seed above).
        Collections.shuffle(listDs, rand);
        return new ListDataSetIterator(listDs, batchSize);
    }
}
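
To see how well the network has actually learned addition, the trained model can be scored on fresh data with DL4J's built-in regression metrics. The following is a minimal sketch, not part of the original gist: it assumes MultiLayerNetwork.evaluateRegression and org.nd4j.evaluation.regression.RegressionEvaluation from a DL4J release of roughly this era, and it reuses getTrainingData with a different seed as a stand-in for a real held-out test set.

// Hypothetical evaluation snippet (would go inside main(), after the training loop).
// Needs one extra import alongside the ones above:
//   import org.nd4j.evaluation.regression.RegressionEvaluation;
DataSetIterator testIterator = getTrainingData(batchSize, new Random(54321));
RegressionEvaluation eval = net.evaluateRegression(testIterator);
System.out.println(eval.stats()); // prints MSE, MAE, RMSE, R^2 etc. per output column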