Skip to content

Instantly share code, notes, and snippets.

@chinproisbestpro
Last active April 1, 2021 02:20
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save chinproisbestpro/b56f3544a5356558a9929aa2ccd9f755 to your computer and use it in GitHub Desktop.
Save chinproisbestpro/b56f3544a5356558a9929aa2ccd9f755 to your computer and use it in GitHub Desktop.
// Build the feature matrix and label matrix, one row per SuggestedWinnerV2Model example,
// then wrap them in a DataSet and shuffle it.
int inputSize = summonerSelectionInputModel.size();
// Allocate directly as DOUBLE instead of allocating the default dtype and copying via castTo().
INDArray input = Nd4j.zeros(DataType.DOUBLE, inputSize, 621);
INDArray labels = Nd4j.zeros(DataType.DOUBLE, inputSize, 2);
for (int i = 0; i < inputSize; i++) {
    SuggestedWinnerV2Model summonerSelect = summonerSelectionInputModel.get(i);

    // A list of 621 0's and sparse 1's; it must fill exactly one row of `input`.
    List<Integer> inputVector = summonerSelect.convertInputsToVector(totalCards);
    if (inputVector.size() != input.columns()) {
        // Fail loudly: a short vector would otherwise be silently zero-padded.
        throw new IllegalStateException("Example " + i + ": expected " + input.columns()
                + " input values but got " + inputVector.size());
    }
    for (int j = 0; j < inputVector.size(); j++) {
        input.putScalar(new int[]{i, j}, inputVector.get(j));
    }

    // A list of 2 0's or 1; it must fill exactly one row of `labels`.
    List<Integer> labelVector = summonerSelect.convertOutputsToVector();
    if (labelVector.size() != labels.columns()) {
        throw new IllegalStateException("Example " + i + ": expected " + labels.columns()
                + " label values but got " + labelVector.size());
    }
    for (int j = 0; j < labelVector.size(); j++) {
        labels.putScalar(new int[]{i, j}, labelVector.get(j));
    }
}
DataSet ds = new DataSet(input, labels);
// Seed the shuffle with the same SEED the network uses, so the train/test split is reproducible.
ds.shuffle(SEED);
// Split the examples: 0.65 is the fraction kept for training; the rest forms the test split.
SplitTestAndTrain testAndTrain = ds.splitTestAndTrain(0.65);
DataSet trainingData = testAndTrain.getTrain();
DataSet testData = testAndTrain.getTest();

// Fit the standardization statistics on the training split only, then apply that same
// transform to both splits (training first, then test), so test data does not influence preprocessing.
DataNormalization normalizer = new NormalizerStandardize();
normalizer.fit(trainingData);
for (DataSet split : new DataSet[] {trainingData, testData}) {
    normalizer.transform(split);
}
// NOTE(review): deckMln is not used in this snippet; presumably the trained model is saved here later.
File deckMln = new File(getWinnerV2FileName(manaCap));

// Topology: 621 inputs -> 623 -> 300 -> 2-way softmax output with multi-class cross-entropy loss.
// The two hidden layers set no activation of their own, so they inherit the builder-level default.
// NOTE(review): confirm that default (sigmoid in DL4J) is what is intended for the hidden layers.
DenseLayer firstHidden = new DenseLayer.Builder()
        .nIn(621)
        .nOut(623)
        .build();
DenseLayer secondHidden = new DenseLayer.Builder()
        .nIn(623)
        .nOut(300)
        .build();
OutputLayer softmaxOutput = new OutputLayer.Builder(new LossMCXENT())
        .nIn(300)
        .nOut(2)
        .activation(Activation.SOFTMAX)
        .build();

// Global settings: fixed seed, double precision, Nesterov-momentum updater, zero bias init.
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .seed(SEED)
        .dataType(DataType.DOUBLE)
        .updater(new Nesterovs(LEARNING_RATE, MOMENTUM))
        .biasInit(0)
        .miniBatch(true)
        .list()
        .layer(firstHidden)
        .layer(secondHidden)
        .layer(softmaxOutput)
        .build();
// Instantiate the network from the configuration and initialize its parameters.
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
// Each epoch hands the entire (normalized) training split to fit() in a single call.
for (int epoch = 0; epoch < EPOCHS; ++epoch) {
    net.fit(trainingData);
}
// PRINT OUT EVAL
// Run the trained network on the held-out features and compare its outputs with the true labels.
INDArray output = net.output(testData.getFeatures()).castTo(DataType.DOUBLE);
// Named testLabels because `labels` already holds the full label matrix declared earlier in
// this scope; redeclaring `labels` here is a compile error.
INDArray testLabels = testData.getLabels().castTo(DataType.DOUBLE);
Evaluation eval = new Evaluation();
eval.eval(testLabels, output);
System.out.println(eval.stats());
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment