Created
March 27, 2020 01:42
-
-
Save liweigu/bfcbd2a6e612e02ad7b24fee3cfac235 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package xxxxx; | |
import java.util.Collections; | |
import java.util.List; | |
import java.util.Random; | |
import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator; | |
import org.deeplearning4j.nn.api.OptimizationAlgorithm; | |
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | |
import org.deeplearning4j.nn.conf.layers.DenseLayer; | |
import org.deeplearning4j.nn.conf.layers.OutputLayer; | |
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | |
import org.deeplearning4j.nn.weights.WeightInit; | |
import org.nd4j.linalg.activations.Activation; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import org.nd4j.linalg.dataset.DataSet; | |
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; | |
import org.nd4j.linalg.factory.Nd4j; | |
import org.nd4j.linalg.learning.config.Nesterovs; | |
import org.nd4j.linalg.lossfunctions.LossFunctions; | |
public class SumByRegression { | |
public static void main(String[] args) { | |
int batchSize = 100; | |
int seed = 12345; | |
Random rng = new Random(seed); | |
DataSetIterator iterator = getTrainingData(batchSize, rng); | |
int numInput = 2; | |
int numOutputs = 1; | |
int nHidden = 10; | |
double learningRate = 0.01; | |
MultiLayerNetwork net = new MultiLayerNetwork(new NeuralNetConfiguration.Builder() | |
.seed(seed) | |
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | |
.weightInit(WeightInit.XAVIER) | |
.updater(new Nesterovs(learningRate, 0.9)).list() | |
.layer(0, new DenseLayer.Builder() | |
.nIn(numInput).nOut(nHidden).activation(Activation.SIGMOID).build()) | |
.layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE) | |
.activation(Activation.IDENTITY) | |
.nIn(nHidden) | |
.nOut(numOutputs).build()) | |
.build()); | |
net.init(); | |
int nEpochs = 500; | |
for (int i = 0; i < nEpochs; i++) { | |
iterator.reset(); | |
while (iterator.hasNext()) { | |
DataSet ds = iterator.next(); | |
net.fit(ds); | |
} | |
} | |
double x = 0.211111; | |
double y = 0.3; | |
INDArray input = Nd4j.create(new double[] { x, y }, new int[] { 1, 2 }); | |
INDArray out = net.output(input, false); | |
System.out.println("x + y = " + out); | |
} | |
// generate data: x , y for feature, z (=x+y) for label. | |
@SuppressWarnings({ "rawtypes", "unchecked" }) | |
private static DataSetIterator getTrainingData(int batchSize, Random rand) { | |
int nSamples = 1000; | |
int MIN_RANGE = 0; | |
int MAX_RANGE = 3; | |
double[] sum = new double[nSamples]; | |
double[] input1 = new double[nSamples]; | |
double[] input2 = new double[nSamples]; | |
for (int i = 0; i < nSamples; i++) { | |
// x | |
input1[i] = MIN_RANGE + (MAX_RANGE - MIN_RANGE) * rand.nextDouble(); | |
// y | |
input2[i] = MIN_RANGE + (MAX_RANGE - MIN_RANGE) * rand.nextDouble(); | |
// label | |
sum[i] = input1[i] + input2[i]; | |
} | |
INDArray inputNDArray1 = Nd4j.create(input1, new int[] { nSamples, 1 }); | |
INDArray inputNDArray2 = Nd4j.create(input2, new int[] { nSamples, 1 }); | |
INDArray inputNDArray = Nd4j.hstack(inputNDArray1, inputNDArray2); | |
INDArray outPut = Nd4j.create(sum, new int[] { nSamples, 1 }); | |
DataSet dataSet = new DataSet(inputNDArray, outPut); | |
List<DataSet> listDs = dataSet.asList(); | |
Collections.shuffle(listDs); | |
return new ListDataSetIterator(listDs, batchSize); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment