Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save gembin/050a8995ac9ac15bf43ddca6f2d8866c to your computer and use it in GitHub Desktop.
Save gembin/050a8995ac9ac15bf43ddca6f2d8866c to your computer and use it in GitHub Desktop.
Examples of DL4J's Keras model import syntax (assumes the models were built with the Keras Functional API and are imported as a DL4J ComputationGraph).
package org.deeplearning4j.nn.modelimport.keras;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * Demonstrates DL4J's Keras model import syntax for a VGG16 model saved with the Keras
 * Functional API, using both the static helpers on {@code KerasModelImport} and the
 * {@code KerasModel.ModelBuilder} builder.
 *
 * <p>Each import variant is shown twice (static helper, then builder). The later result simply
 * overwrites the earlier one, because the goal is to illustrate the syntax rather than to use
 * the imported model.
 */
public class KerasImportVgg16 {
    private static final Logger log = LoggerFactory.getLogger(KerasImportVgg16.class);

    /**
     * Runs every import variant against the given (or placeholder) files.
     *
     * @param args optional overrides, in order: [0] model config JSON path, [1] weights HDF5
     *     path, [2] full-model HDF5 path, [3] {@code true}/{@code false} for
     *     enforceTrainingConfig. Missing entries fall back to the original placeholder values;
     *     the placeholder paths must be replaced with real files for the imports to succeed.
     * @throws Exception propagated unchanged from the DL4J import calls
     */
    public static void main(String[] args) throws Exception {
        String modelJsonFilename = argOrDefault(args, 0, "PATH TO EXPORTED JSON FILE");
        String weightsHdf5Filename = argOrDefault(args, 1, "PATH TO EXPORTED WEIGHTS HDF5 ARCHIVE");
        String modelHdf5Filename = argOrDefault(args, 2, "PATH TO EXPORTED FULL MODEL HDF5 ARCHIVE");
        // Controls whether unsupported training-related configs will throw an exception
        // or just generate a warning. Defaults to false, as in the original example.
        boolean enforceTrainingConfig = Boolean.parseBoolean(argOrDefault(args, 3, "false"));

        /* Import VGG 16 model from separate model config JSON and weights HDF5 files.
         * Will not include loss layer or training configuration.
         */
        // Static helper from KerasModelImport
        ComputationGraph model = KerasModelImport.importKerasModelAndWeights(
                modelJsonFilename, weightsHdf5Filename, enforceTrainingConfig);
        // KerasModel builder pattern
        model = new KerasModel.ModelBuilder()
                .modelJsonFilename(modelJsonFilename)
                .weightsHdf5Filename(weightsHdf5Filename)
                .enforceTrainingConfig(enforceTrainingConfig)
                .buildModel()
                .getComputationGraph();
        log.info("Imported model from config JSON '{}' and weights HDF5 '{}'",
                modelJsonFilename, weightsHdf5Filename);

        /* Import VGG 16 model from full model HDF5 file. Includes loss layer, if any. */
        // Static helper from KerasModelImport
        model = KerasModelImport.importKerasModelAndWeights(modelHdf5Filename, enforceTrainingConfig);
        // KerasModel builder pattern
        model = new KerasModel.ModelBuilder()
                .modelHdf5Filename(modelHdf5Filename)
                .enforceTrainingConfig(enforceTrainingConfig)
                .buildModel()
                .getComputationGraph();
        log.info("Imported model from full-model HDF5 '{}'", modelHdf5Filename);

        /* Import VGG 16 model config from model config JSON. Will not include loss
         * layer or training configuration.
         */
        // Static helper from KerasModelImport
        ComputationGraphConfiguration config =
                KerasModelImport.importKerasModelConfiguration(modelJsonFilename, enforceTrainingConfig);
        // KerasModel builder pattern
        config = new KerasModel.ModelBuilder()
                .modelJsonFilename(modelJsonFilename)
                .enforceTrainingConfig(enforceTrainingConfig)
                .buildModel()
                .getComputationGraphConfiguration();
        log.info("Imported model configuration from config JSON '{}'", modelJsonFilename);
    }

    /** Returns {@code args[index]} when that argument was supplied, otherwise {@code fallback}. */
    private static String argOrDefault(String[] args, int index, String fallback) {
        return (args != null && index < args.length) ? args[index] : fallback;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment