Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Examples of DL4J's Keras model import syntax (assumes Keras Functional API models and DL4J ComputationGraph)
package org.deeplearning4j.nn.modelimport.keras;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * Example of the Keras model import syntax for a VGG 16 model.
 *
 * <p>Three import scenarios are exercised in order. Each scenario is performed twice:
 * first through the static helper on {@code KerasModelImport}, then through the
 * {@code KerasModel.ModelBuilder} builder pattern.
 */
public class KerasImportVgg16 {
    private static final Logger log = LoggerFactory.getLogger(KerasImportVgg16.class);

    /**
     * Runs the three import scenarios one after another using placeholder file paths.
     *
     * @param args command-line arguments (not read)
     * @throws Exception if any of the import calls fails
     */
    public static void main(String[] args) throws Exception {
        // Placeholder locations of the exported Keras artifacts.
        final String jsonPath = "PATH TO EXPORTED JSON FILE";
        final String weightsPath = "PATH TO EXPORTED WEIGHTS HDF5 ARCHIVE";
        final String fullModelPath = "PATH TO EXPORTED FULL MODEL HDF5 ARCHIVE";

        // Controls whether unsupported training-related configs will throw an
        // exception or just generate a warning.
        final boolean strictTrainingConfig = false;

        importGraphFromJsonAndWeights(jsonPath, weightsPath, strictTrainingConfig);
        importGraphFromFullArchive(fullModelPath, strictTrainingConfig);
        importConfigurationFromJson(jsonPath, strictTrainingConfig);
    }

    /**
     * Imports the VGG 16 model from a separate model config JSON file and weights HDF5 file.
     * Will not include loss layer or training configuration.
     *
     * @param jsonPath path to the exported model config JSON file
     * @param weightsPath path to the exported weights HDF5 archive
     * @param strictTrainingConfig value passed through as the enforce-training-config flag
     * @return the graph produced by the builder-based import (the second of the two imports)
     * @throws Exception if either import call fails
     */
    private static ComputationGraph importGraphFromJsonAndWeights(
            String jsonPath, String weightsPath, boolean strictTrainingConfig) throws Exception {
        // Variant 1: static helper from KerasModelImport.
        ComputationGraph graph =
                KerasModelImport.importKerasModelAndWeights(jsonPath, weightsPath, strictTrainingConfig);

        // Variant 2: KerasModel builder pattern.
        graph = new KerasModel.ModelBuilder()
                .modelJsonFilename(jsonPath)
                .weightsHdf5Filename(weightsPath)
                .enforceTrainingConfig(strictTrainingConfig)
                .buildModel()
                .getComputationGraph();
        return graph;
    }

    /**
     * Imports the VGG 16 model from a full model HDF5 file. Includes loss layer, if any.
     *
     * @param fullModelPath path to the exported full model HDF5 archive
     * @param strictTrainingConfig value passed through as the enforce-training-config flag
     * @return the graph produced by the builder-based import (the second of the two imports)
     * @throws Exception if either import call fails
     */
    private static ComputationGraph importGraphFromFullArchive(
            String fullModelPath, boolean strictTrainingConfig) throws Exception {
        // Variant 1: static helper from KerasModelImport.
        ComputationGraph graph =
                KerasModelImport.importKerasModelAndWeights(fullModelPath, strictTrainingConfig);

        // Variant 2: KerasModel builder pattern.
        graph = new KerasModel.ModelBuilder()
                .modelHdf5Filename(fullModelPath)
                .enforceTrainingConfig(strictTrainingConfig)
                .buildModel()
                .getComputationGraph();
        return graph;
    }

    /**
     * Imports only the VGG 16 model configuration from the model config JSON file.
     * Will not include loss layer or training configuration.
     *
     * @param jsonPath path to the exported model config JSON file
     * @param strictTrainingConfig value passed through as the enforce-training-config flag
     * @return the configuration produced by the builder-based import (the second of the two imports)
     * @throws Exception if either import call fails
     */
    private static ComputationGraphConfiguration importConfigurationFromJson(
            String jsonPath, boolean strictTrainingConfig) throws Exception {
        // Variant 1: static helper from KerasModelImport.
        ComputationGraphConfiguration graphConfig =
                KerasModelImport.importKerasModelConfiguration(jsonPath, strictTrainingConfig);

        // Variant 2: KerasModel builder pattern.
        graphConfig = new KerasModel.ModelBuilder()
                .modelJsonFilename(jsonPath)
                .enforceTrainingConfig(strictTrainingConfig)
                .buildModel()
                .getComputationGraphConfiguration();
        return graphConfig;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.