Skip to content

Instantly share code, notes, and snippets.

@agibsonccc
Created August 21, 2020 12:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save agibsonccc/24f9a418dcc5cdaeba8299ea3096e945 to your computer and use it in GitHub Desktop.
Save agibsonccc/24f9a418dcc5cdaeba8299ea3096e945 to your computer and use it in GitHub Desktop.
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.DropoutLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.File;
import java.util.List;
/**
 * Trains and evaluates a small feed-forward spam classifier with DL4J.
 *
 * <p>Two experiments share the same architecture and training loop:
 *
 * <ul>
 *   <li>{@link #run()} reads {@code spam.csv} from the classpath and splits it into train/test sets.
 *   <li>{@link #runImport()} trains on pre-split NumPy arrays ({@code X_train.npy}, {@code y_train.npy},
 *       {@code X_test.npy}, {@code y_test.npy}).
 * </ul>
 *
 * <p>Each experiment builds its own configuration and network. A network trained by one experiment is
 * therefore never re-used by the other.
 */
public class SpamClassifier {
    /** Number of input features per example; also the label column index in {@code spam.csv}. */
    private static final int NUM_FEATURES = 512;
    /** Number of output classes. */
    private static final int NUM_CLASSES = 2;
    /**
     * Batch size used when reading {@code spam.csv}. Only one batch is read, so at most this many
     * rows are used. NOTE(review): assumes the CSV has this many rows — confirm.
     */
    private static final int CSV_ROWS = 5171;
    /** Fraction of the CSV data used for training; the remainder becomes the test set. */
    private static final double TRAIN_FRACTION = 0.75;
    /** Mini-batch size used when fitting. */
    private static final int BATCH_SIZE = 100;
    /** Number of passes over the training data. */
    private static final int EPOCHS = 20;

    /**
     * Builds the network configuration.
     *
     * <p>Layers: dense 512 -> 600 (ReLU), dropout, dense 600 -> 300 (ReLU), dropout, then a softmax
     * output 300 -> 2 with multi-class cross-entropy loss.
     *
     * @return a new, unshared configuration
     */
    private static MultiLayerConfiguration buildConfiguration() {
        return new NeuralNetConfiguration.Builder()
                .seed(123)
                .updater(new Adam(0.0001))
                .list()
                .layer(0, new DenseLayer.Builder()
                        .nIn(NUM_FEATURES).nOut(600)
                        .weightInit(WeightInit.XAVIER)
                        .activation(Activation.RELU).build())
                // DL4J's dropOut(p) is the probability of *retaining* an activation (p=0.8 drops ~20%).
                // NOTE(review): confirm whether DropoutLayer applies the SIGMOID activation at all.
                .layer(1, new DropoutLayer.Builder().dropOut(0.8)
                        .activation(Activation.SIGMOID).build())
                .layer(2, new DenseLayer.Builder()
                        .nIn(600).nOut(300)
                        .weightInit(WeightInit.XAVIER)
                        .activation(Activation.RELU).build())
                .layer(3, new DropoutLayer.Builder().dropOut(0.8)
                        .activation(Activation.SIGMOID).build())
                .layer(4, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .nIn(300).nOut(NUM_CLASSES)
                        .weightInit(WeightInit.XAVIER)
                        .activation(Activation.SOFTMAX).build())
                .build();
    }

    /**
     * Creates a freshly initialised network with a score listener attached.
     *
     * <p>A new configuration is built for each network. Training state such as iteration counts
     * therefore cannot carry over from a previous experiment.
     *
     * @return an initialised, untrained network
     */
    private static MultiLayerNetwork newModel() {
        MultiLayerNetwork model = new MultiLayerNetwork(buildConfiguration());
        model.init();
        model.setListeners(new ScoreIterationListener(1));
        return model;
    }

    /**
     * Trains a new network on {@code train} and prints evaluation statistics on {@code test}.
     *
     * @param train training examples; split into mini-batches of {@link #BATCH_SIZE}
     * @param test held-out examples; labels are compared against the network's test-mode output
     */
    private static void trainAndEvaluate(DataSet train, DataSet test) {
        List<org.nd4j.linalg.dataset.DataSet> batches = train.batchBy(BATCH_SIZE);
        DataSetIterator trainIterator = new ListDataSetIterator<>(batches);

        MultiLayerNetwork model = newModel();
        for (int epoch = 0; epoch < EPOCHS; epoch++) {
            // Reset explicitly so every epoch sees all batches, whatever fit() does internally.
            trainIterator.reset();
            model.fit(trainIterator);
        }

        INDArray output = model.output(test.getFeatures(), Layer.TrainingMode.TEST);
        Evaluation eval = new Evaluation(NUM_CLASSES);
        eval.eval(test.getLabels(), output);
        System.out.println(eval.stats());
    }

    /**
     * Trains and evaluates a new network on the pre-split NumPy arrays in the working directory.
     *
     * <p>The Keras model {@code test-spam.hdf5} is imported first. A missing or unreadable file
     * therefore aborts before any training starts.
     *
     * @throws Exception if the Keras model or any of the {@code .npy} files cannot be read
     */
    public static void runImport() throws Exception {
        // NOTE(review): the imported Keras model is never evaluated. The DL4J network below is
        // trained from scratch on the same arrays. Confirm whether a comparison was intended.
        MultiLayerNetwork kerasModel =
                KerasModelImport.importKerasSequentialModelAndWeights("test-spam.hdf5");

        INDArray xTrain = Nd4j.readNpy(new File("X_train.npy"));
        INDArray xTest = Nd4j.readNpy(new File("X_test.npy"));
        INDArray yTrain = Nd4j.readNpy(new File("y_train.npy"));
        INDArray yTest = Nd4j.readNpy(new File("y_test.npy"));

        // Assumes the y arrays are one-hot with NUM_CLASSES columns, as the softmax/MCXENT output
        // layer and Evaluation expect — TODO confirm.
        DataSet train = new org.nd4j.linalg.dataset.DataSet(xTrain, yTrain);
        DataSet test = new org.nd4j.linalg.dataset.DataSet(xTest, yTest);
        trainAndEvaluate(train, test);
    }

    /**
     * Trains and evaluates a new network on {@code spam.csv} from the classpath.
     *
     * <p>The data is read as a single batch, shuffled, and split into {@link #TRAIN_FRACTION}
     * train and the remainder test.
     *
     * @throws Exception if the CSV cannot be located or read
     */
    public static void run() throws Exception {
        DataSet allData;
        // Close the reader once the single batch has been materialised.
        try (RecordReader recordReader = new CSVRecordReader(0, ',')) {
            recordReader.initialize(new FileSplit(new ClassPathResource("spam.csv").getFile()));
            DataSetIterator iterator =
                    new RecordReaderDataSetIterator(recordReader, CSV_ROWS, NUM_FEATURES, NUM_CLASSES);
            allData = iterator.next();
        }

        // NOTE(review): no seed is passed to shuffle(), so the split likely differs between runs.
        allData.shuffle();
        SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(TRAIN_FRACTION);
        trainAndEvaluate(testAndTrain.getTrain(), testAndTrain.getTest());
    }

    /**
     * Runs the CSV experiment followed by the NumPy/Keras-import experiment.
     *
     * @param args unused
     * @throws Exception propagated from either experiment
     */
    public static void main(String[] args) throws Exception {
        // Sets ND4J's default data types. NOTE(review): the network's parameter dtype is configured
        // separately on NeuralNetConfiguration.Builder (dataType), which is left at its default here.
        // Confirm that is intended.
        Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE);
        run();
        runImport();
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment