Created December 20, 2016 22:18
Save lacic/283f0ab7746540633d6dcd5b64c4f5d0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
public class EvaluationExample { | |
public static void main(String... args) { | |
// where are the trained model and the corresponding normalizers stored ? | |
String modelPath = "data" + File.separator + "models" + File.separator; | |
String normalizerPath = "data" + File.separator + "normalizers" + File.separator; | |
String modelName = "train.model"; | |
// define the paths to the train features, test features and test labels | |
// train test split can be 80/20 | |
String trainFeaturePath = "data" + File.separator + "trainTestSplit" + File.separator + | |
"train" + File.separator + "features" + File.separator; | |
String testFeaturePath = "data" + File.separator + "trainTestSplit" + File.separator + | |
"test" + File.separator + "features" + File.separator; | |
String testLabelPath = "data" + File.separator + "trainTestSplit" + File.separator + | |
"test" + File.separator + "labels" + File.separator; | |
// there are 31 files/examples used for training | |
Integer maxFieldId = 30; | |
// setup helper for serialization | |
SerializationUtils serializationHelper = new SerializationUtils(modelPath, modelName, normalizerPath); | |
try { | |
MultiLayerNetwork net = serializationHelper.loadNetwork(); | |
NormalizerStandardize normalizer = serializationHelper.loadNormalizer(); | |
Map<String, List<String>> trainFeaturesMap = extractRows(trainFeaturePath, 0, maxFieldId); | |
Map<String, List<String>> testFeaturesMap = extractRows(testFeaturePath, 0, maxFieldId); | |
Map<String, List<String>> testLabelsMap = extractRows(testLabelPath, 0, maxFieldId); | |
for (String fileName : trainFeaturesMap.keySet()) { | |
// reset model for new file evaluation | |
net.rnnClearPreviousState(); | |
List<String> trainFeatures = trainFeaturesMap.get(fileName); | |
List<String> testFeatures = testFeaturesMap.get(fileName); | |
List<String> testLabels = testLabelsMap.get(fileName); | |
// init train history | |
for (String trainFeature : trainFeatures) { | |
INDArray featureArray = createArray( Integer.parseInt(trainFeature) ); | |
if (normalizer != null) { | |
normalizer.transform(featureArray); | |
} | |
// init with value | |
INDArray initOutput = net.rnnTimeStep(featureArray); | |
} | |
INDArray rnnOutput = null; | |
Double predicted = null; | |
// evaluate on test set | |
for (int testIndex = 0; testIndex < testFeatures.size(); testIndex++) { | |
String inputValue = testFeatures.get(testIndex); | |
INDArray featureArray = createArray( Integer.parseInt(inputValue) ); | |
if (normalizer != null) { | |
normalizer.transform(featureArray); | |
} | |
rnnOutput = net.rnnTimeStep(featureArray); | |
normalizer.revertLabels(rnnOutput); | |
// extract double value out of the output, check expected vs predicted difference (RMSE, MAPE, etc.) | |
Integer expected = Integer.parseInt( testLabels.get(testIndex) ); | |
// ... | |
} | |
} | |
// Successfully tested the NN model | |
} catch (Exception e) { | |
// log error | |
} | |
} | |
/** | |
* Extracts row values from the provided data | |
* @param path Path to the files to read | |
* @return a map for each file containing a list of the values (1 value per row) | |
*/ | |
public static Map<String, List<String>> extractRows(String path, Integer minFileId, Integer maxFileId) { | |
Map<String, List<String>> fileMap = new HashMap<>(); | |
for (int i = minFileId; i <= maxFileId; i++) { | |
List<String> list = new ArrayList<>(); | |
String fileName = i + ".csv"; | |
Path filePath = Paths.get(path, fileName); | |
try (BufferedReader br = Files.newBufferedReader(filePath)) { | |
//br returns as stream and convert it into a List | |
list = br.lines().collect(Collectors.toList()); | |
} catch (IOException e) { | |
e.printStackTrace(); | |
} | |
fileMap.put(fileName, list); | |
} | |
return fileMap; | |
} | |
/** | |
* Creates an INDArray using an integer value | |
* @return INDArray to be used for RNN nets in dl4j | |
*/ | |
public static INDArray createArray(Integer value) { | |
// number of time steps used for testing the next value | |
INDArray data = Nd4j.ones(1, 1, 1); | |
double[] independent = new double[1]; | |
independent[0] = value; | |
INDArray ind = Nd4j.create(independent, new int[]{1, 1, 1}); | |
data.putScalar(0, 0, 0, ind.getDouble(0)); | |
return data; | |
} | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.