Created
January 3, 2018 15:25
-
-
Save reuschling/bf82797dcd1e5b343392f2e05cc657eb to your computer and use it in GitHub Desktop.
Original DL4J UCISequenceClassificationExample with additional code to perform a classification. The classification is verified by reconstructing the confusion matrix of the Evaluation class. Doing so reveals two errors: 1. normalizer.transform has to be invoked twice. 2. The Evaluation class shows a wrong number of features (sum is bigger than test cas…
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package dfki.sds.servicefactory.deeplearning; | |
import java.io.File; | |
import java.io.FileInputStream; | |
import java.io.IOException; | |
import java.io.InputStream; | |
import java.net.URL; | |
import java.util.ArrayList; | |
import java.util.Collection; | |
import java.util.Collections; | |
import java.util.List; | |
import java.util.Random; | |
import org.apache.commons.io.FileUtils; | |
import org.apache.commons.io.IOUtils; | |
import org.datavec.api.records.reader.SequenceRecordReader; | |
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader; | |
import org.datavec.api.split.FileSplit; | |
import org.datavec.api.split.NumberedFileInputSplit; | |
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator; | |
import org.deeplearning4j.eval.Evaluation; | |
import org.deeplearning4j.nn.api.OptimizationAlgorithm; | |
import org.deeplearning4j.nn.conf.GradientNormalization; | |
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; | |
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | |
import org.deeplearning4j.nn.conf.Updater; | |
import org.deeplearning4j.nn.conf.WorkspaceMode; | |
import org.deeplearning4j.nn.conf.layers.GravesLSTM; | |
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer; | |
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | |
import org.deeplearning4j.nn.weights.WeightInit; | |
import org.deeplearning4j.optimize.listeners.ScoreIterationListener; | |
import org.nd4j.linalg.activations.Activation; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; | |
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization; | |
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; | |
import org.nd4j.linalg.factory.Nd4j; | |
import org.nd4j.linalg.indexing.NDArrayIndex; | |
import org.nd4j.linalg.lossfunctions.LossFunctions; | |
import org.nd4j.linalg.primitives.Pair; | |
import org.nd4j.linalg.util.ArrayUtil; | |
import org.slf4j.Logger; | |
import org.slf4j.LoggerFactory; | |
/** | |
* Sequence Classification Example Using a LSTM Recurrent Neural Network | |
* | |
* This example learns how to classify univariate time series as belonging to one of six categories. Categories are: Normal, Cyclic, Increasing trend, Decreasing trend, | |
* Upward shift, Downward shift | |
* | |
* Data is the UCI Synthetic Control Chart Time Series Data Set Details: https://archive.ics.uci.edu/ml/datasets/Synthetic+Control+Chart+Time+Series Data: | |
* https://archive.ics.uci.edu/ml/machine-learning-databases/synthetic_control-mld/synthetic_control.data Image: | |
* https://archive.ics.uci.edu/ml/machine-learning-databases/synthetic_control-mld/data.jpeg | |
* | |
* This example proceeds as follows: 1. Download and prepare the data (in downloadUCIData() method) (a) Split the 600 sequences into train set of size 450, and test set | |
* of size 150 (b) Write the data into a format suitable for loading using the CSVSequenceRecordReader for sequence classification This format: one time series per file, | |
* and a separate file for the labels. For example, train/features/0.csv is the features using with the labels file train/labels/0.csv Because the data is a univariate | |
* time series, we only have one column in the CSV files. Normally, each column would contain multiple values - one time step per row. Furthermore, because we have only | |
* one label for each time series, the labels CSV files contain only a single value | |
* | |
* 2. Load the training data using CSVSequenceRecordReader (to load/parse the CSV files) and SequenceRecordReaderDataSetIterator (to convert it to DataSet objects, ready | |
* to train) For more details on this step, see: http://deeplearning4j.org/usingrnns#data | |
* | |
* 3. Normalize the data. The raw data contain values that are too large for effective training, and need to be normalized. Normalization is conducted using | |
* NormalizerStandardize, based on statistics (mean, st.dev) collected on the training data only. Note that both the training data and test data are normalized in the | |
* same way. | |
* | |
* 4. Configure the network The data set here is very small, so we can't afford to use a large network with many parameters. We are using one small LSTM layer and one RNN | |
* output layer | |
* | |
* 5. Train the network for 40 epochs At each epoch, evaluate and print the accuracy and f1 on the test set | |
* | |
* @author Alex Black | |
*/ | |
public class UCISequenceClassificationExample | |
{ | |
// 'baseDir': Base directory for the data. Change this if you want to save the data somewhere else | |
private static File baseDir = new File("src/main/resources/uci/"); | |
private static File baseTestDir = new File(baseDir, "test"); | |
private static File baseTrainDir = new File(baseDir, "train"); | |
private static File featuresDirTest = new File(baseTestDir, "features"); | |
private static File featuresDirTrain = new File(baseTrainDir, "features"); | |
private static File labelsDirTest = new File(baseTestDir, "labels"); | |
private static File labelsDirTrain = new File(baseTrainDir, "labels"); | |
private static final Logger log = LoggerFactory.getLogger(UCISequenceClassificationExample.class); | |
/** | |
* Adds an Object multiple times to a given collection. This is nice if you want to initially fill e.g. a List with a specific value | |
* | |
* @return collection2fill | |
*/ | |
@SuppressWarnings("javadoc") | |
static private <T> Collection<T> addMultipleTimes(T object2add, int times2add, Collection<T> collection2fill) | |
{ | |
for (int i = 0; i < times2add; i++) | |
collection2fill.add(object2add); | |
return collection2fill; | |
} | |
static private double[][] convertToSimpleDouble2D(Collection<? extends Collection<? extends Object>> col) | |
{ | |
double[][] simple = new double[col.size()][col.iterator().next().size()]; | |
int i = 0; | |
for (Collection<? extends Object> subCol : col) | |
{ | |
int z = 0; | |
for (Object val : subCol) | |
{ | |
if(val instanceof Number) | |
simple[i][z++] = ((Number) val).doubleValue(); | |
else if("false".equals(val.toString())) | |
simple[i][z++] = 0d; | |
else if("true".equals(val.toString())) | |
simple[i][z++] = 1d; | |
else | |
simple[i][z++] = Double.valueOf(val.toString()); | |
} | |
i++; | |
} | |
return simple; | |
} | |
/** | |
* Creates an Array object by using a varargs value | |
* | |
* @param values the values for the new collection object | |
* | |
* @return the new, filled object | |
*/ | |
@SafeVarargs | |
static private <T> T[] createArray(T... values) | |
{ | |
return values; | |
} | |
/** | |
* Returns the content of a file as string. This method uses UTF-8 encoding. | |
* | |
* @param strPath the file path | |
* | |
* @return the file content as string | |
* | |
* @throws Exception | |
*/ | |
static private String file2String(String strPath) throws Exception | |
{ | |
return file2String(strPath, "UTF-8"); | |
} | |
/** | |
* Returns the content of a file as string. The character encoding can be specified. Possible values are e.g. <br> | |
* US-ASCII Seven-bit ASCII, a.k.a. ISO646-US, a.k.a. the Basic Latin block of the Unicode character set <br> | |
* ISO-8859-1 ISO Latin Alphabet No. 1, a.k.a. ISO-LATIN-1 <br> | |
* UTF-8 Eight-bit UCS Transformation Format <br> | |
* UTF-16BE Sixteen-bit UCS Transformation Format, big-endian byte order <br> | |
* UTF-16LE Sixteen-bit UCS Transformation Format, little-endian byte order <br> | |
* UTF-16 Sixteen-bit UCS Transformation Format, byte order identified by an optional byte-order mark <br> | |
* | |
* @param strPath the file path | |
* @param strCharEncoding the character encoding as String | |
* @return the file content as string | |
* @throws IOException | |
*/ | |
static private String file2String(String strPath, String strCharEncoding) throws IOException | |
{ | |
File file = new File(strPath).getAbsoluteFile(); | |
byte[] bytes = getBytesFromFile(file); | |
String strContent = new String(bytes, strCharEncoding); | |
return strContent; | |
} | |
// Returns the contents of the file in a byte array. | |
private static byte[] getBytesFromFile(File file) throws IOException | |
{ | |
InputStream is = new FileInputStream(file); | |
// Get the size of the file | |
long length = file.length(); | |
// You cannot create an array using a long type. | |
// It needs to be an int type. | |
// Before converting to an int type, check | |
// to ensure that file is not larger than Integer.MAX_VALUE. | |
if(length > Integer.MAX_VALUE) | |
{ | |
// File is too large | |
} | |
// Create the byte array to hold the data | |
byte[] bytes = new byte[(int) length]; | |
// Read in the bytes | |
int offset = 0; | |
int numRead = 0; | |
while (offset < bytes.length && (numRead = is.read(bytes, offset, bytes.length - offset)) >= 0) | |
{ | |
offset += numRead; | |
} | |
// Close the input stream and return bytes | |
is.close(); | |
// Ensure all the bytes have been read in | |
if(offset < bytes.length) | |
{ | |
throw new IOException("Could not completely read file " + file.getName()); | |
} | |
return bytes; | |
} | |
public static void main(String[] args) throws Exception | |
{ | |
downloadUCIData(); | |
// ----- Load the training data ----- | |
// Note that we have 450 training files for features: train/features/0.csv through train/features/449.csv | |
SequenceRecordReader trainFeatures = new CSVSequenceRecordReader(); | |
trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTrain.getAbsolutePath() + "/%d.csv", 0, 449)); | |
SequenceRecordReader trainLabels = new CSVSequenceRecordReader(); | |
trainLabels.initialize(new NumberedFileInputSplit(labelsDirTrain.getAbsolutePath() + "/%d.csv", 0, 449)); | |
int miniBatchSize = 10; | |
int numLabelClasses = 6; | |
DataSetIterator trainData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels, miniBatchSize, numLabelClasses, false, | |
SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END); | |
// Normalize the training data | |
DataNormalization normalizer = new NormalizerStandardize(); | |
normalizer.fit(trainData); // Collect training data statistics | |
trainData.reset(); | |
// Use previously collected statistics to normalize on-the-fly. Each DataSet returned by 'trainData' iterator will be normalized | |
trainData.setPreProcessor(normalizer); | |
// ----- Load the test data ----- | |
// Same process as for the training data. | |
SequenceRecordReader testFeatures = new CSVSequenceRecordReader(); | |
testFeatures.initialize(new NumberedFileInputSplit(featuresDirTest.getAbsolutePath() + "/%d.csv", 0, 149)); | |
SequenceRecordReader testLabels = new CSVSequenceRecordReader(); | |
testLabels.initialize(new NumberedFileInputSplit(labelsDirTest.getAbsolutePath() + "/%d.csv", 0, 149)); | |
DataSetIterator testData = new SequenceRecordReaderDataSetIterator(testFeatures, testLabels, miniBatchSize, numLabelClasses, false, | |
SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END); | |
testData.setPreProcessor(normalizer); // Note that we are using the exact same normalization process as the training data | |
// @formatter:off | |
// ----- Configure the network ----- | |
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(123) // Random number generator seed for improved repeatability. Optional. | |
//XXX modified from me, setting the workspace | |
.trainingWorkspaceMode(WorkspaceMode.SINGLE)//.inferenceWorkspaceMode(WorkspaceMode.SINGLE) | |
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1) | |
.weightInit(WeightInit.XAVIER) | |
.updater(Updater.NESTEROVS) | |
.learningRate(0.005) | |
.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue) // Not always required, but helps with this data set | |
.gradientNormalizationThreshold(0.5) | |
.list() | |
.layer(0, new GravesLSTM.Builder().activation(Activation.TANH).nIn(1).nOut(10).build()) | |
.layer(1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(numLabelClasses).build()) | |
.pretrain(false).backprop(true).build(); | |
// @formatter:on | |
MultiLayerNetwork net = new MultiLayerNetwork(conf); | |
net.init(); | |
net.setListeners(new ScoreIterationListener(20)); // Print the score (loss function value) every 20 iterations | |
// ----- Train the network, evaluating the test set performance at each epoch ----- | |
int nEpochs = 5; | |
String str = "Test set evaluation at epoch %d: Accuracy = %.2f, F1 = %.2f"; | |
for (int i = 0; i < nEpochs; i++) | |
{ | |
net.fit(trainData); | |
// Evaluate on the test set: | |
Evaluation evaluation = net.evaluate(testData); | |
log.info(String.format(str, i, evaluation.accuracy(), evaluation.f1()) | |
+ ", Confusion matrix from Evaluation class (shows wrong number of cases: 151 vs 150 (correct)\n" + evaluation.confusionToString()); | |
testData.reset(); | |
trainData.reset(); | |
} | |
log.info("----- Example Complete -----"); | |
classify(net, normalizer); | |
} | |
public static int maxValueIndex(INDArray softmaxOutput) | |
{ | |
double dMax = Double.NEGATIVE_INFINITY; | |
int iMaxIndex = 0; | |
for (int i = 0; i < softmaxOutput.size(0); i++) | |
{ | |
if(softmaxOutput.getDouble(i) > dMax) | |
{ | |
dMax = softmaxOutput.getDouble(i); | |
iMaxIndex = i; | |
} | |
} | |
return iMaxIndex; | |
} | |
public static void classify(MultiLayerNetwork net, DataNormalization normalizer) throws IOException, Exception | |
{ | |
ArrayList<List<Integer>> lConfusionMatrix = new ArrayList<>(); | |
int iNumLabelClasses = 6; | |
for (int z = 0; z < iNumLabelClasses; z++) | |
lConfusionMatrix.add((ArrayList<Integer>) addMultipleTimes(0, iNumLabelClasses, new ArrayList<Integer>(150))); | |
// test: 149, train: 449 | |
for (int i = 0; i <= 149; i++) | |
{ | |
String strFeaturePath = featuresDirTest.getAbsolutePath() + "/"; | |
String strFeatureFileName = i + ".csv"; | |
String strLabelPath = labelsDirTest.getAbsolutePath() + "/"; | |
INDArray feature4classification = createClassificationFeature(strFeaturePath, strFeatureFileName); | |
normalizer.transform(feature4classification); | |
//TODO normalizer.transform HAS TO BE DOUBLE invocated in order to get the same confusion matrix as the Evaluation class | |
normalizer.transform(feature4classification); | |
INDArray output = net.output(feature4classification); | |
INDArray lastTimeStepEstimation = output.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(feature4classification.size(2) - 1)); | |
int iEstimatedClassValueIndex = maxValueIndex(lastTimeStepEstimation); | |
normalizer.revertLabels(output); | |
normalizer.revertLabels(output); | |
String strOriginIndex = file2String(strLabelPath + strFeatureFileName); | |
Integer iOriginIndex = Integer.valueOf(strOriginIndex); | |
Integer iFormer = lConfusionMatrix.get(iOriginIndex).get(iEstimatedClassValueIndex); | |
lConfusionMatrix.get(iOriginIndex).set(iEstimatedClassValueIndex, ++iFormer); | |
} | |
System.out.println("reproduced confusion matrix, shows right number of cases (150):"); | |
for (List<Integer> row : lConfusionMatrix) | |
System.out.println(row); | |
} | |
// This method downloads the data, and converts the "one time series per line" format into a suitable | |
// CSV sequence format that DataVec (CsvSequenceRecordReader) and DL4J can read. | |
private static void downloadUCIData() throws Exception | |
{ | |
if(baseDir.exists()) return; // Data already exists, don't download it again | |
String url = "https://archive.ics.uci.edu/ml/machine-learning-databases/synthetic_control-mld/synthetic_control.data"; | |
String data = IOUtils.toString(new URL(url)); | |
String[] lines = data.split("\n"); | |
// Create directories | |
baseDir.mkdir(); | |
baseTrainDir.mkdir(); | |
featuresDirTrain.mkdir(); | |
labelsDirTrain.mkdir(); | |
baseTestDir.mkdir(); | |
featuresDirTest.mkdir(); | |
labelsDirTest.mkdir(); | |
int lineCount = 0; | |
List<Pair<String, Integer>> contentAndLabels = new ArrayList<>(); | |
for (String line : lines) | |
{ | |
String transposed = line.replaceAll(" +", "\n"); | |
// Labels: first 100 examples (lines) are label 0, second 100 examples are label 1, and so on | |
contentAndLabels.add(new Pair<>(transposed, lineCount++ / 100)); | |
} | |
// Randomize and do a train/test split: | |
Collections.shuffle(contentAndLabels, new Random(12345)); | |
int nTrain = 450; // 75% train, 25% test | |
int trainCount = 0; | |
int testCount = 0; | |
for (Pair<String, Integer> p : contentAndLabels) | |
{ | |
// Write output in a format we can read, in the appropriate locations | |
File outPathFeatures; | |
File outPathLabels; | |
if(trainCount < nTrain) | |
{ | |
outPathFeatures = new File(featuresDirTrain, trainCount + ".csv"); | |
outPathLabels = new File(labelsDirTrain, trainCount + ".csv"); | |
trainCount++; | |
} | |
else | |
{ | |
outPathFeatures = new File(featuresDirTest, testCount + ".csv"); | |
outPathLabels = new File(labelsDirTest, testCount + ".csv"); | |
testCount++; | |
} | |
FileUtils.writeStringToFile(outPathFeatures, p.getFirst()); | |
FileUtils.writeStringToFile(outPathLabels, p.getSecond().toString()); | |
} | |
} | |
static public INDArray createClassificationFeature(String strPath4FeatureFiles, String... strFeatureFileNames) throws IOException, InterruptedException | |
{ | |
SequenceRecordReader testFeatures = new CSVSequenceRecordReader(0); | |
FileSplit testFeaturesFileSplit = new FileSplit(new File(strPath4FeatureFiles), createArray(strFeatureFileNames)); | |
testFeatures.initialize(testFeaturesFileSplit); | |
double[][] simpleDouble2D = convertToSimpleDouble2D(testFeatures.nextSequence().getSequenceRecord()); | |
double[][][] simpleDouble3D = new double[1][simpleDouble2D.length][simpleDouble2D[0].length]; | |
for (int i = 0; i < simpleDouble2D.length; i++) | |
for (int z = 0; z < simpleDouble2D[0].length; z++) | |
{ | |
simpleDouble3D[0][i][z] = simpleDouble2D[i][z]; | |
} | |
int[] iaDimensions = new int[] { 1, simpleDouble2D[0].length, simpleDouble2D.length }; | |
INDArray indTestFeatures = Nd4j.create(ArrayUtil.flattenDoubleArray(simpleDouble3D), iaDimensions, 'f'); | |
return indTestFeatures; | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment