Created
December 20, 2016 22:17
-
-
Save lacic/2e0271ebfd73f9189d6d46f65319024b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
public class TrainExample { | |
public static final int BATCH_SIZE = 100; | |
public static final int N_EPOCHS = 150; | |
public static void main(String... args) { | |
// define some paths where you would like to store the model | |
String modelPath = "data" + File.separator + "models" + File.separator; | |
String normalizerPath = "data" + File.separator + "normalizers" + File.separator; | |
String modelName = "train.model"; | |
String configName = "config.json"; | |
// define some paths to the files which will be used for training | |
String trainFeaturePath = "data" + File.separator + "trainTestSplit" + File.separator + | |
"train" + File.separator + "features" + File.separator; | |
String trainLabelPath = "data" + File.separator + "trainTestSplit" + File.separator + | |
"train" + File.separator + "labels" + File.separator; | |
// there are 31 files/examples | |
// for the sake of the example, maybe generate some random time series example files with integer values? | |
Integer maxFieldId = 30; | |
// setup helper for serialization | |
SerializationUtils serializationHelper = new SerializationUtils(modelPath, modelName, normalizerPath, configName); | |
// init network from config | |
MultiLayerConfiguration config = serializationHelper.loadNetworkConfig(); | |
MultiLayerNetwork net = new MultiLayerNetwork(config); | |
net.init(); | |
net.setListeners(new ScoreIterationListener(1000)); | |
// load data for training | |
SequenceRecordReader trainFeatures = new CSVSequenceRecordReader(); | |
SequenceRecordReader trainLabels = new CSVSequenceRecordReader(); | |
try { | |
File featuresDirTrain = new File(trainFeaturePath); | |
File labelsDirTrain = new File(trainLabelPath); | |
trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTrain.getAbsolutePath() + "/%d.csv", 0, maxFieldId)); | |
trainLabels.initialize(new NumberedFileInputSplit(labelsDirTrain.getAbsolutePath() + "/%d.csv", 0, maxFieldId)); | |
} catch (Exception e) { | |
// log error | |
} | |
boolean regression = true; | |
int numClasses = -1; //not used for regression | |
DataSetIterator trainingData = new SequenceRecordReaderDataSetIterator( | |
trainFeatures, | |
trainLabels, | |
BATCH_SIZE, | |
numClasses, | |
regression, | |
SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END); | |
// setup normalizer | |
NormalizerStandardize normalizer = new NormalizerStandardize(); | |
normalizer.fitLabel(true); | |
normalizer.fit(trainingData); | |
trainingData.reset(); | |
trainingData.setPreProcessor(normalizer); | |
// store it | |
serializationHelper.storeNormalizer(normalizer); | |
// train | |
for (int j = 0; j < N_EPOCHS; j++) { | |
trainingData.reset(); | |
net.fit(trainingData); | |
} | |
// store network model | |
serializationHelper.storeNetworkModel(net); | |
// Successfully created and trained the NN model | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment