Original dl4j UCISequenceClassificationExample with additional code to perform a classification. The classification is verified by reconstructing the Evaluation class confusion matrix. During this, there are two errors: 1. normalizer.transform has to be invoked twice. 2. The Evaluation class shows the wrong number of cases (the confusion matrix sum is bigger than the number of test cases).
package dfki.sds.servicefactory.deeplearning;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.NumberedFileInputSplit;
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Sequence Classification Example Using an LSTM Recurrent Neural Network
*
* This example learns how to classify univariate time series as belonging to one of six categories. Categories are: Normal, Cyclic, Increasing trend, Decreasing trend,
* Upward shift, Downward shift
*
* Data is the UCI Synthetic Control Chart Time Series Data Set Details: https://archive.ics.uci.edu/ml/datasets/Synthetic+Control+Chart+Time+Series Data:
* https://archive.ics.uci.edu/ml/machine-learning-databases/synthetic_control-mld/synthetic_control.data Image:
* https://archive.ics.uci.edu/ml/machine-learning-databases/synthetic_control-mld/data.jpeg
*
* This example proceeds as follows: 1. Download and prepare the data (in the downloadUCIData() method) (a) Split the 600 sequences into a train set of size 450 and a test set
* of size 150 (b) Write the data into a format suitable for loading with the CSVSequenceRecordReader for sequence classification. This format: one time series per file,
* and a separate file for the labels. For example, train/features/0.csv contains the features corresponding to the labels file train/labels/0.csv. Because the data is a univariate
* time series, we have only one column in the CSV files. Normally, each column would contain multiple values - one time step per row. Furthermore, because we have only
* one label for each time series, the labels CSV files contain only a single value
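*
* For example (illustrative values, not taken from the actual download), train/features/0.csv might contain one value per line, one line per time step:
* 28.7812
* 34.4632
* 31.3381
* ...
* and the corresponding labels file train/labels/0.csv would then contain just the single line:
* 0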
*
* 2. Load the training data using CSVSequenceRecordReader (to load/parse the CSV files) and SequenceRecordReaderDataSetIterator (to convert it to DataSet objects, ready
* to train) For more details on this step, see: http://deeplearning4j.org/usingrnns#data
*
* 3. Normalize the data. The raw data contain values that are too large for effective training, and need to be normalized. Normalization is conducted using
* NormalizerStandardize, based on statistics (mean, st.dev) collected on the training data only. Note that both the training data and test data are normalized in the
* same way.
*
* 4. Configure the network. The data set here is very small, so we can't afford to use a large network with many parameters. We are using one small LSTM layer and one RNN
* output layer
*
* 5. Train the network for several epochs (5 in this modified example; the original uses 40). At each epoch, evaluate and print the accuracy and F1 on the test set
*
* @author Alex Black
*/
public class UCISequenceClassificationExample
{
// 'baseDir': Base directory for the data. Change this if you want to save the data somewhere else
private static File baseDir = new File("src/main/resources/uci/");
private static File baseTestDir = new File(baseDir, "test");
private static File baseTrainDir = new File(baseDir, "train");
private static File featuresDirTest = new File(baseTestDir, "features");
private static File featuresDirTrain = new File(baseTrainDir, "features");
private static File labelsDirTest = new File(baseTestDir, "labels");
private static File labelsDirTrain = new File(baseTrainDir, "labels");
private static final Logger log = LoggerFactory.getLogger(UCISequenceClassificationExample.class);
/**
* Adds an object to a given collection multiple times. This is useful for pre-filling e.g. a List with a specific initial value
*
* @return collection2fill
*/
@SuppressWarnings("javadoc")
static private <T> Collection<T> addMultipleTimes(T object2add, int times2add, Collection<T> collection2fill)
{
for (int i = 0; i < times2add; i++)
collection2fill.add(object2add);
return collection2fill;
}
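/*
 * Note: the JDK can achieve the same pre-filling in one call; a minimal sketch:
 *
 * List<Integer> row = new ArrayList<>(Collections.nCopies(6, 0));
 */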
static private double[][] convertToSimpleDouble2D(Collection<? extends Collection<? extends Object>> col)
{
double[][] simple = new double[col.size()][col.iterator().next().size()];
int i = 0;
for (Collection<? extends Object> subCol : col)
{
int z = 0;
for (Object val : subCol)
{
if(val instanceof Number)
simple[i][z++] = ((Number) val).doubleValue();
else if("false".equals(val.toString()))
simple[i][z++] = 0d;
else if("true".equals(val.toString()))
simple[i][z++] = 1d;
else
simple[i][z++] = Double.valueOf(val.toString());
}
i++;
}
return simple;
}
/**
* Creates an array from the given varargs values
*
* @param values the values for the new array
*
* @return the new, filled array
*/
@SafeVarargs
static private <T> T[] createArray(T... values)
{
return values;
}
/**
* Returns the content of a file as string. This method uses UTF-8 encoding.
*
* @param strPath the file path
*
* @return the file content as string
*
* @throws Exception
*/
static private String file2String(String strPath) throws Exception
{
return file2String(strPath, "UTF-8");
}
/**
* Returns the content of a file as string. The character encoding can be specified. Possible values are e.g. <br>
* US-ASCII Seven-bit ASCII, a.k.a. ISO646-US, a.k.a. the Basic Latin block of the Unicode character set <br>
* ISO-8859-1 ISO Latin Alphabet No. 1, a.k.a. ISO-LATIN-1 <br>
* UTF-8 Eight-bit UCS Transformation Format <br>
* UTF-16BE Sixteen-bit UCS Transformation Format, big-endian byte order <br>
* UTF-16LE Sixteen-bit UCS Transformation Format, little-endian byte order <br>
* UTF-16 Sixteen-bit UCS Transformation Format, byte order identified by an optional byte-order mark <br>
*
* @param strPath the file path
* @param strCharEncoding the character encoding as String
* @return the file content as string
* @throws IOException
*/
static private String file2String(String strPath, String strCharEncoding) throws IOException
{
File file = new File(strPath).getAbsoluteFile();
byte[] bytes = getBytesFromFile(file);
String strContent = new String(bytes, strCharEncoding);
return strContent;
}
// Returns the contents of the file in a byte array.
private static byte[] getBytesFromFile(File file) throws IOException
{
// An array cannot be created with a long size, so refuse files larger than Integer.MAX_VALUE
long length = file.length();
if(length > Integer.MAX_VALUE)
throw new IOException("File " + file.getName() + " is too large to read into a byte array");
// Create the byte array to hold the data and read it in, closing the stream even on failure
byte[] bytes = new byte[(int) length];
try (InputStream is = new FileInputStream(file))
{
int offset = 0;
int numRead;
while (offset < bytes.length && (numRead = is.read(bytes, offset, bytes.length - offset)) >= 0)
{
offset += numRead;
}
// Ensure all the bytes have been read in
if(offset < bytes.length)
{
throw new IOException("Could not completely read file " + file.getName());
}
}
return bytes;
}
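/*
 * Note: on Java 7+ the whole method above can be replaced by a single NIO call; a sketch:
 *
 * byte[] bytes = java.nio.file.Files.readAllBytes(file.toPath());
 */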
public static void main(String[] args) throws Exception
{
downloadUCIData();
// ----- Load the training data -----
// Note that we have 450 training files for features: train/features/0.csv through train/features/449.csv
SequenceRecordReader trainFeatures = new CSVSequenceRecordReader();
trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTrain.getAbsolutePath() + "/%d.csv", 0, 449));
SequenceRecordReader trainLabels = new CSVSequenceRecordReader();
trainLabels.initialize(new NumberedFileInputSplit(labelsDirTrain.getAbsolutePath() + "/%d.csv", 0, 449));
int miniBatchSize = 10;
int numLabelClasses = 6;
DataSetIterator trainData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels, miniBatchSize, numLabelClasses, false,
SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
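// ALIGN_END aligns the single per-sequence label with the last time step of each sequence,
// since we have one label per whole time series rather than one label per time step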
// Normalize the training data
DataNormalization normalizer = new NormalizerStandardize();
normalizer.fit(trainData); // Collect training data statistics
trainData.reset();
// Use previously collected statistics to normalize on-the-fly. Each DataSet returned by 'trainData' iterator will be normalized
trainData.setPreProcessor(normalizer);
// ----- Load the test data -----
// Same process as for the training data.
SequenceRecordReader testFeatures = new CSVSequenceRecordReader();
testFeatures.initialize(new NumberedFileInputSplit(featuresDirTest.getAbsolutePath() + "/%d.csv", 0, 149));
SequenceRecordReader testLabels = new CSVSequenceRecordReader();
testLabels.initialize(new NumberedFileInputSplit(labelsDirTest.getAbsolutePath() + "/%d.csv", 0, 149));
DataSetIterator testData = new SequenceRecordReaderDataSetIterator(testFeatures, testLabels, miniBatchSize, numLabelClasses, false,
SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
testData.setPreProcessor(normalizer); // Note that we are using the exact same normalization process as the training data
// @formatter:off
// ----- Configure the network -----
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(123) // Random number generator seed for improved repeatability. Optional.
//XXX modified by me: setting the workspace mode
.trainingWorkspaceMode(WorkspaceMode.SINGLE)//.inferenceWorkspaceMode(WorkspaceMode.SINGLE)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
.weightInit(WeightInit.XAVIER)
.updater(Updater.NESTEROVS)
.learningRate(0.005)
.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue) // Not always required, but helps with this data set
.gradientNormalizationThreshold(0.5)
.list()
.layer(0, new GravesLSTM.Builder().activation(Activation.TANH).nIn(1).nOut(10).build())
.layer(1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(numLabelClasses).build())
.pretrain(false).backprop(true).build();
// @formatter:on
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
net.setListeners(new ScoreIterationListener(20)); // Print the score (loss function value) every 20 iterations
// ----- Train the network, evaluating the test set performance at each epoch -----
int nEpochs = 5;
String str = "Test set evaluation at epoch %d: Accuracy = %.2f, F1 = %.2f";
for (int i = 0; i < nEpochs; i++)
{
net.fit(trainData);
// Evaluate on the test set:
Evaluation evaluation = net.evaluate(testData);
log.info(String.format(str, i, evaluation.accuracy(), evaluation.f1())
+ ", Confusion matrix from Evaluation class (shows wrong number of cases: 151 vs 150 (correct)\n" + evaluation.confusionToString());
testData.reset();
trainData.reset();
}
log.info("----- Example Complete -----");
classify(net, normalizer);
}
public static int maxValueIndex(INDArray softmaxOutput)
{
double dMax = Double.NEGATIVE_INFINITY;
int iMaxIndex = 0;
for (int i = 0; i < softmaxOutput.size(0); i++)
{
if(softmaxOutput.getDouble(i) > dMax)
{
dMax = softmaxOutput.getDouble(i);
iMaxIndex = i;
}
}
return iMaxIndex;
}
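/*
 * Note: ND4J ships an argMax reduction that should yield the same index as the loop above;
 * a sketch, assuming the ND4J version in use exposes this signature:
 *
 * int iMaxIndex = Nd4j.argMax(softmaxOutput, 0).getInt(0);
 */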
public static void classify(MultiLayerNetwork net, DataNormalization normalizer) throws IOException, Exception
{
// Reproduce the Evaluation confusion matrix manually: rows are the true labels, columns the predicted ones
ArrayList<List<Integer>> lConfusionMatrix = new ArrayList<>();
int iNumLabelClasses = 6;
for (int z = 0; z < iNumLabelClasses; z++)
lConfusionMatrix.add((ArrayList<Integer>) addMultipleTimes(0, iNumLabelClasses, new ArrayList<Integer>(iNumLabelClasses)));
// test: 149, train: 449
for (int i = 0; i <= 149; i++)
{
String strFeaturePath = featuresDirTest.getAbsolutePath() + "/";
String strFeatureFileName = i + ".csv";
String strLabelPath = labelsDirTest.getAbsolutePath() + "/";
INDArray feature4classification = createClassificationFeature(strFeaturePath, strFeatureFileName);
normalizer.transform(feature4classification);
//TODO normalizer.transform has to be invoked TWICE in order to get the same confusion matrix as the Evaluation class - the root cause is unclear
normalizer.transform(feature4classification);
INDArray output = net.output(feature4classification);
INDArray lastTimeStepEstimation = output.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(feature4classification.size(2) - 1));
int iEstimatedClassValueIndex = maxValueIndex(lastTimeStepEstimation);
normalizer.revertLabels(output);
normalizer.revertLabels(output);
String strOriginIndex = file2String(strLabelPath + strFeatureFileName);
Integer iOriginIndex = Integer.valueOf(strOriginIndex);
Integer iFormer = lConfusionMatrix.get(iOriginIndex).get(iEstimatedClassValueIndex);
lConfusionMatrix.get(iOriginIndex).set(iEstimatedClassValueIndex, ++iFormer);
}
System.out.println("reproduced confusion matrix, shows right number of cases (150):");
for (List<Integer> row : lConfusionMatrix)
System.out.println(row);
}
// This method downloads the data, and converts the "one time series per line" format into a suitable
// CSV sequence format that DataVec (CsvSequenceRecordReader) and DL4J can read.
private static void downloadUCIData() throws Exception
{
if(baseDir.exists()) return; // Data already exists, don't download it again
String url = "https://archive.ics.uci.edu/ml/machine-learning-databases/synthetic_control-mld/synthetic_control.data";
String data = IOUtils.toString(new URL(url));
String[] lines = data.split("\n");
// Create directories
baseDir.mkdir();
baseTrainDir.mkdir();
featuresDirTrain.mkdir();
labelsDirTrain.mkdir();
baseTestDir.mkdir();
featuresDirTest.mkdir();
labelsDirTest.mkdir();
int lineCount = 0;
List<Pair<String, Integer>> contentAndLabels = new ArrayList<>();
for (String line : lines)
{
String transposed = line.replaceAll(" +", "\n");
// Labels: first 100 examples (lines) are label 0, second 100 examples are label 1, and so on
contentAndLabels.add(new Pair<>(transposed, lineCount++ / 100));
}
// Randomize and do a train/test split:
Collections.shuffle(contentAndLabels, new Random(12345));
int nTrain = 450; // 75% train, 25% test
int trainCount = 0;
int testCount = 0;
for (Pair<String, Integer> p : contentAndLabels)
{
// Write output in a format we can read, in the appropriate locations
File outPathFeatures;
File outPathLabels;
if(trainCount < nTrain)
{
outPathFeatures = new File(featuresDirTrain, trainCount + ".csv");
outPathLabels = new File(labelsDirTrain, trainCount + ".csv");
trainCount++;
}
else
{
outPathFeatures = new File(featuresDirTest, testCount + ".csv");
outPathLabels = new File(labelsDirTest, testCount + ".csv");
testCount++;
}
FileUtils.writeStringToFile(outPathFeatures, p.getFirst());
FileUtils.writeStringToFile(outPathLabels, p.getSecond().toString());
}
}
static public INDArray createClassificationFeature(String strPath4FeatureFiles, String... strFeatureFileNames) throws IOException, InterruptedException
{
SequenceRecordReader testFeatures = new CSVSequenceRecordReader(0);
FileSplit testFeaturesFileSplit = new FileSplit(new File(strPath4FeatureFiles), createArray(strFeatureFileNames));
testFeatures.initialize(testFeaturesFileSplit);
double[][] simpleDouble2D = convertToSimpleDouble2D(testFeatures.nextSequence().getSequenceRecord());
double[][][] simpleDouble3D = new double[1][simpleDouble2D.length][simpleDouble2D[0].length];
for (int i = 0; i < simpleDouble2D.length; i++)
for (int z = 0; z < simpleDouble2D[0].length; z++)
{
simpleDouble3D[0][i][z] = simpleDouble2D[i][z];
}
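// DL4J expects RNN input of shape [miniBatchSize, featureCount, timeSeriesLength]. The data was
// flattened row-major from [1][timeStep][feature], so combining it with the shape below and
// Fortran ('f') ordering places the value for (time step t, feature f) at position [0, f, t], as required.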
int[] iaDimensions = new int[] { 1, simpleDouble2D[0].length, simpleDouble2D.length };
INDArray indTestFeatures = Nd4j.create(ArrayUtil.flattenDoubleArray(simpleDouble3D), iaDimensions, 'f');
return indTestFeatures;
}
}