Examples of DL4J's Keras model import syntax (assumes Keras Functional API models and DL4J ComputationGraph)
// Declared in DL4J's own Keras import package: KerasModel.ModelBuilder is not public,
// so the builder examples below compile only from inside this package.
package org.deeplearning4j.nn.modelimport.keras;

import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class KerasImportVgg16 {
    private static final Logger log = LoggerFactory.getLogger(KerasImportVgg16.class);

    public static void main(String[] args) throws Exception {
        String modelJsonFilename = "PATH TO EXPORTED JSON FILE";
        String weightsHdf5Filename = "PATH TO EXPORTED WEIGHTS HDF5 ARCHIVE";
        String modelHdf5Filename = "PATH TO EXPORTED FULL MODEL HDF5 ARCHIVE";
        // Controls whether unsupported training-related configs throw an exception
        // or just generate a warning.
        boolean enforceTrainingConfig = false;

        /* Import VGG 16 model from separate model config JSON and weights HDF5 files.
         * Will not include loss layer or training configuration.
         */
        // Static helper from KerasModelImport
        ComputationGraph model = KerasModelImport.importKerasModelAndWeights(modelJsonFilename, weightsHdf5Filename, enforceTrainingConfig);
        // KerasModel builder pattern (ModelBuilder is not public: usable only from this package)
        model = new KerasModel.ModelBuilder()
                .modelJsonFilename(modelJsonFilename)
                .weightsHdf5Filename(weightsHdf5Filename)
                .enforceTrainingConfig(enforceTrainingConfig)
                .buildModel()
                .getComputationGraph();

        /* Import VGG 16 model from full model HDF5 file. Includes loss layer, if any. */
        // Static helper from KerasModelImport
        model = KerasModelImport.importKerasModelAndWeights(modelHdf5Filename, enforceTrainingConfig);
        // KerasModel builder pattern (same package restriction as above)
        model = new KerasModel.ModelBuilder()
                .modelHdf5Filename(modelHdf5Filename)
                .enforceTrainingConfig(enforceTrainingConfig)
                .buildModel()
                .getComputationGraph();

        /* Import VGG 16 model config from model config JSON. Will not include loss
         * layer or training configuration.
         */
        // Static helper from KerasModelImport
        ComputationGraphConfiguration config = KerasModelImport.importKerasModelConfiguration(modelJsonFilename, enforceTrainingConfig);
        // KerasModel builder pattern (same package restriction as above)
        config = new KerasModel.ModelBuilder()
                .modelJsonFilename(modelJsonFilename)
                .enforceTrainingConfig(enforceTrainingConfig)
                .buildModel()
                .getComputationGraphConfiguration();
    }
}
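The `enforceTrainingConfig` flag in the example chooses between failing hard and only warning when a Keras training-related setting cannot be mapped. A minimal stdlib sketch of that throw-or-warn policy follows. The class `TrainingConfigPolicy`, its methods, and the use of `UnsupportedOperationException` are hypothetical and illustrate the idea only; they are not how DL4J implements the flag internally.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.logging.Logger;

/**
 * Hypothetical illustration of a throw-or-warn "enforce training config" switch.
 * This is NOT DL4J code.
 */
public class TrainingConfigPolicy {
    private static final Logger LOG = Logger.getLogger(TrainingConfigPolicy.class.getName());

    private final boolean enforceTrainingConfig;
    private final List<String> warnings = new ArrayList<>();

    public TrainingConfigPolicy(boolean enforceTrainingConfig) {
        this.enforceTrainingConfig = enforceTrainingConfig;
    }

    /** Called for each training-related setting the importer cannot map. */
    public void handleUnsupported(String setting) {
        String message = "Unsupported training config: " + setting;
        if (enforceTrainingConfig) {
            // Strict mode: abort the import.
            throw new UnsupportedOperationException(message);
        }
        // Lenient mode: record a warning and keep going.
        warnings.add(message);
        LOG.warning(message);
    }

    /** Warnings collected in lenient mode (read-only view). */
    public List<String> getWarnings() {
        return Collections.unmodifiableList(warnings);
    }

    public static void main(String[] args) {
        TrainingConfigPolicy lenient = new TrainingConfigPolicy(false);
        lenient.handleUnsupported("hypothetical optimizer setting");
        System.out.println("Lenient warnings: " + lenient.getWarnings().size());

        TrainingConfigPolicy strict = new TrainingConfigPolicy(true);
        try {
            strict.handleUnsupported("hypothetical optimizer setting");
        } catch (UnsupportedOperationException e) {
            System.out.println("Strict mode threw: " + e.getMessage());
        }
    }
}
```

Setting the flag to `false`, as the gist does, trades strictness for the ability to load a model whose training setup cannot be fully reproduced.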
@B0yma

commented Mar 7, 2017

It's not working because ModelBuilder can't be accessed from outside the package.

@gauravshelangia

commented Apr 13, 2017

Hello,
I followed your code, but it gives me the error: cannot resolve the symbol KerasModelImport
Here is the Gradle file that I have modified:

apply plugin: 'com.android.application'

android {
    compileSdkVersion 25
    buildToolsVersion "25.0.2"

    packagingOptions {
        exclude 'META-INF/DEPENDENCIES'
        exclude 'META-INF/DEPENDENCIES.txt'
        exclude 'META-INF/LICENSE'
        exclude 'META-INF/LICENSE.txt'
        exclude 'META-INF/license.txt'
        exclude 'META-INF/NOTICE'
        exclude 'META-INF/NOTICE.txt'
        exclude 'META-INF/notice.txt'
        exclude 'META-INF/INDEX.LIST'
    }

    defaultConfig {
        applicationId "com.example.gaurav.mtargetclient"
        minSdkVersion 15
        targetSdkVersion 25
        versionCode 1
        versionName "1.0"
        testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner"
        multiDexEnabled true
    }
    buildTypes {
        release {
            minifyEnabled false
            proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro'
        }
    }
}

repositories {
    mavenLocal()
    mavenCentral()
}

dependencies {
    compile fileTree(dir: 'libs', include: ['*.jar'])
    androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', {
        exclude group: 'com.android.support', module: 'support-annotations'
    })
    compile 'com.android.support:appcompat-v7:25.2.0'
    compile 'com.android.support.constraint:constraint-layout:1.0.0-beta5'
    testCompile 'junit:junit:4.12'

    compile 'com.google.android.gms:play-services-appindexing:8.4.0'
    compile 'com.qozix:tileview:2.2.6'

    compile 'org.deeplearning4j:deeplearning4j-core:0.7.2'
    compile 'org.nd4j:nd4j-native:0.7.2'
    compile 'org.nd4j:nd4j-native:0.7.2:android-x86'
    compile 'org.nd4j:nd4j-native:0.7.2:android-arm'
}

Please tell me how to resolve this error.
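A "cannot resolve symbol" error means the class is not on the compile classpath. One assumption worth checking: in DL4J releases of this era the Keras import classes appear to ship in a separate `deeplearning4j-modelimport` artifact rather than in `deeplearning4j-core`, and the Gradle file above does not list it (verify against the release you use). As a complementary runtime check, here is a hypothetical stdlib helper (`ClasspathProbe`, `isLoadable`) that reports which of a few DL4J classes can actually be loaded, which helps tell a missing artifact apart from other problems.

```java
/**
 * Hypothetical helper: reports whether classes can be loaded from the current
 * classpath, e.g. to see that DL4J network classes are present but the Keras
 * import entry point is not.
 */
public class ClasspathProbe {

    /** True if the class can be found; static initializers are not run (initialize=false). */
    static boolean isLoadable(String className) {
        try {
            Class.forName(className, false, ClasspathProbe.class.getClassLoader());
            return true;
        } catch (ClassNotFoundException | LinkageError e) {
            // ClassNotFoundException: class absent; LinkageError: a class it depends on is absent.
            return false;
        }
    }

    public static void main(String[] args) {
        String[] classes = {
            "org.deeplearning4j.nn.graph.ComputationGraph",            // DL4J network class
            "org.deeplearning4j.nn.modelimport.keras.KerasModelImport" // Keras import entry point
        };
        for (String name : classes) {
            System.out.println(name + " -> " + (isLoadable(name) ? "found" : "MISSING"));
        }
    }
}
```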

Thanks

@Matleo

commented Oct 2, 2017

I'm having the same issue here: "java: org.deeplearning4j.nn.modelimport.keras.KerasModel.ModelBuilder is not public in org.deeplearning4j.nn.modelimport.keras.KerasModel; cannot be accessed from outside package"
Any help? I'm trying to import a Keras model (Sequential works).
