@turambar
Last active July 10, 2018 07:58
Examples of DL4J's Keras model import syntax (assumes Keras Functional API models and DL4J ComputationGraph)
// This class is declared inside DL4J's own Keras import package because
// KerasModel.ModelBuilder is not public (at least in the DL4J versions discussed
// in the comments), so the builder-pattern calls below only compile from here.
package org.deeplearning4j.nn.modelimport.keras;

import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class KerasImportVgg16 {
    private static final Logger log = LoggerFactory.getLogger(KerasImportVgg16.class);

    public static void main(String[] args) throws Exception {
        String modelJsonFilename = "PATH TO EXPORTED JSON FILE";
        String weightsHdf5Filename = "PATH TO EXPORTED WEIGHTS HDF5 ARCHIVE";
        String modelHdf5Filename = "PATH TO EXPORTED FULL MODEL HDF5 ARCHIVE";

        // Controls whether unsupported training-related configs
        // throw an exception or just generate a warning.
        boolean enforceTrainingConfig = false;

        /* Import the VGG16 model from separate model config JSON and weights HDF5 files.
         * Does not include the loss layer or training configuration.
         */
        // Static helper from KerasModelImport
        ComputationGraph model = KerasModelImport.importKerasModelAndWeights(
                modelJsonFilename, weightsHdf5Filename, enforceTrainingConfig);
        // KerasModel builder pattern (requires access to KerasModel.ModelBuilder)
        model = new KerasModel.ModelBuilder()
                .modelJsonFilename(modelJsonFilename)
                .weightsHdf5Filename(weightsHdf5Filename)
                .enforceTrainingConfig(enforceTrainingConfig)
                .buildModel()
                .getComputationGraph();

        /* Import the VGG16 model from a full model HDF5 file. Includes the loss layer, if any. */
        // Static helper from KerasModelImport
        model = KerasModelImport.importKerasModelAndWeights(modelHdf5Filename, enforceTrainingConfig);
        // KerasModel builder pattern
        model = new KerasModel.ModelBuilder()
                .modelHdf5Filename(modelHdf5Filename)
                .enforceTrainingConfig(enforceTrainingConfig)
                .buildModel()
                .getComputationGraph();

        /* Import only the VGG16 model configuration (no weights) from the model config JSON.
         * Does not include the loss layer or training configuration.
         */
        // Static helper from KerasModelImport
        ComputationGraphConfiguration config =
                KerasModelImport.importKerasModelConfiguration(modelJsonFilename, enforceTrainingConfig);
        // KerasModel builder pattern
        config = new KerasModel.ModelBuilder()
                .modelJsonFilename(modelJsonFilename)
                .enforceTrainingConfig(enforceTrainingConfig)
                .buildModel()
                .getComputationGraphConfiguration();
    }
}
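The `enforceTrainingConfig` flag above decides whether an unsupported training-related setting aborts the import or only produces a warning. The self-contained sketch below mimics that behaviour with a hypothetical fluent builder, `ImportSketch.Builder`, shaped like the `KerasModel.ModelBuilder` chain. None of these classes or methods are DL4J API, and `withUnsupportedTrainingConfig` is an invented hook used only to simulate a setting the importer cannot map; the real importer's checks are not reproduced here.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical stand-in that mimics the shape of the KerasModel.ModelBuilder
// chain; none of these classes or methods are part of the DL4J API.
public class ImportSketch {

    /** Outcome of a simulated import. */
    public static final class Result {
        public final boolean hasWeights;    // true when a weights file was supplied
        public final List<String> warnings; // non-fatal problems (lenient mode only)

        Result(boolean hasWeights, List<String> warnings) {
            this.hasWeights = hasWeights;
            this.warnings = Collections.unmodifiableList(warnings);
        }
    }

    public static final class Builder {
        private String modelJsonFilename;
        private String weightsHdf5Filename;
        private boolean enforceTrainingConfig;
        private final List<String> unsupportedTrainingConfigs = new ArrayList<>();

        public Builder modelJsonFilename(String f) { modelJsonFilename = f; return this; }
        public Builder weightsHdf5Filename(String f) { weightsHdf5Filename = f; return this; }
        public Builder enforceTrainingConfig(boolean b) { enforceTrainingConfig = b; return this; }

        // Hypothetical hook: simulates a training setting the importer cannot map.
        public Builder withUnsupportedTrainingConfig(String name) {
            unsupportedTrainingConfigs.add(name);
            return this;
        }

        public Result buildModel() {
            if (modelJsonFilename == null) {
                throw new IllegalStateException("modelJsonFilename is required");
            }
            List<String> warnings = new ArrayList<>();
            for (String cfg : unsupportedTrainingConfigs) {
                String msg = "Unsupported training config: " + cfg;
                if (enforceTrainingConfig) {
                    throw new UnsupportedOperationException(msg); // strict: abort the import
                }
                warnings.add(msg); // lenient: record a warning and keep going
            }
            return new Result(weightsHdf5Filename != null, warnings);
        }
    }

    public static void main(String[] args) {
        Result lenient = new Builder()
                .modelJsonFilename("model.json")
                .weightsHdf5Filename("weights.h5")
                .enforceTrainingConfig(false)
                .withUnsupportedTrainingConfig("custom regularizer")
                .buildModel();
        System.out.println(lenient.hasWeights + " " + lenient.warnings);

        try {
            new Builder()
                    .modelJsonFilename("model.json")
                    .enforceTrainingConfig(true)
                    .withUnsupportedTrainingConfig("custom regularizer")
                    .buildModel();
        } catch (UnsupportedOperationException e) {
            System.out.println("strict mode: " + e.getMessage());
        }
    }
}
```

With the flag off, the same input yields a result plus a warning list; with it on, the first unsupported setting stops the build.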
@B0yma

B0yma commented Mar 7, 2017

It's not working because ModelBuilder can't be accessed from outside the package.
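This error follows from Java's access rules: a nested class declared without `public` is package-private, so code in any other package cannot name it, which is consistent with the gist declaring itself in `org.deeplearning4j.nn.modelimport.keras`. The sketch below uses a hypothetical `FakeKerasModel` (not the DL4J class) with one package-private and one public nested builder, and inspects them with `Class.getModifiers()`, which reports a member class's declared access. The gist's `KerasModelImport` static-helper calls are a separate entry point that does not go through the builder.

```java
import java.lang.reflect.Modifier;

public class AccessCheck {

    // Hypothetical stand-in for KerasModel; NOT the DL4J class.
    static class FakeKerasModel {
        // Declared without 'public': package-private, invisible to other packages.
        static class ModelBuilder { }

        // Declared 'public': nameable from other packages
        // (provided the enclosing class is accessible too).
        public static class PublicBuilder { }
    }

    /** True if the class is declared with none of public, protected, private. */
    static boolean isPackagePrivate(Class<?> c) {
        int m = c.getModifiers();
        return !Modifier.isPublic(m) && !Modifier.isProtected(m) && !Modifier.isPrivate(m);
    }

    public static void main(String[] args) {
        System.out.println("ModelBuilder package-private: "
                + isPackagePrivate(FakeKerasModel.ModelBuilder.class));
        System.out.println("PublicBuilder package-private: "
                + isPackagePrivate(FakeKerasModel.PublicBuilder.class));
    }
}
```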

@gauravshelangia

Hello,
I followed your code, but it gives me the error "cannot resolve symbol KerasModelImport".
Here is the Gradle file I modified:

apply plugin: 'com.android.application'

android {
    compileSdkVersion 25
    buildToolsVersion "25.0.2"

    packagingOptions {
        exclude 'META-INF/DEPENDENCIES'
        exclude 'META-INF/DEPENDENCIES.txt'
        exclude 'META-INF/LICENSE'
        exclude 'META-INF/LICENSE.txt'
        exclude 'META-INF/license.txt'
        exclude 'META-INF/NOTICE'
        exclude 'META-INF/NOTICE.txt'
        exclude 'META-INF/notice.txt'
        exclude 'META-INF/INDEX.LIST'
    }

    defaultConfig {
        applicationId "com.example.gaurav.mtargetclient"
        minSdkVersion 15
        targetSdkVersion 25
        versionCode 1
        versionName "1.0"
        testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner"
        multiDexEnabled true
    }

    buildTypes {
        release {
            minifyEnabled false
            proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro'
        }
    }
}

repositories {
    mavenLocal()
    mavenCentral()
}

dependencies {
    compile fileTree(dir: 'libs', include: ['*.jar'])
    androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', {
        exclude group: 'com.android.support', module: 'support-annotations'
    })
    compile 'com.android.support:appcompat-v7:25.2.0'
    compile 'com.android.support.constraint:constraint-layout:1.0.0-beta5'
    testCompile 'junit:junit:4.12'

    compile 'com.google.android.gms:play-services-appindexing:8.4.0'
    compile 'com.qozix:tileview:2.2.6'

    compile 'org.deeplearning4j:deeplearning4j-core:0.7.2'
    compile 'org.nd4j:nd4j-native:0.7.2'
    compile 'org.nd4j:nd4j-native:0.7.2:android-x86'
    compile 'org.nd4j:nd4j-native:0.7.2:android-arm'
}

Please tell me how to resolve this error.

Thanks
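"Cannot resolve symbol" means the class is not on the compile classpath; the dependency list above contains `deeplearning4j-core` but no dedicated model-import artifact (assumed here to be `deeplearning4j-modelimport`; check the DL4J documentation for your version). Once dependencies change, a quick runtime check can confirm whether the class is actually packaged. The sketch below is a hypothetical diagnostic helper, not part of DL4J; it only asks the class loader whether a fully qualified name can be loaded, using the package name that appears in the gist.

```java
public class ClasspathCheck {

    /** Hypothetical helper: true if the named class can be loaded at runtime. */
    static boolean isOnClasspath(String fullyQualifiedName) {
        try {
            // initialize=false: locate and load the class without running static initializers.
            Class.forName(fullyQualifiedName, false, ClasspathCheck.class.getClassLoader());
            return true;
        } catch (ClassNotFoundException | LinkageError e) {
            // ClassNotFoundException: class absent; LinkageError: a class it depends on is missing.
            return false;
        }
    }

    public static void main(String[] args) {
        String[] names = {
            "org.deeplearning4j.nn.modelimport.keras.KerasModelImport",
            "org.deeplearning4j.nn.graph.ComputationGraph"
        };
        for (String n : names) {
            System.out.println(n + " -> " + (isOnClasspath(n) ? "found" : "MISSING"));
        }
    }
}
```

This checks the runtime classpath only; an IDE's "cannot resolve symbol" is a compile-time error, so a missing class here points at the same dependency gap from the other side.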

@Matleo

Matleo commented Oct 2, 2017

I'm having the same issue here: "java: org.deeplearning4j.nn.modelimport.keras.KerasModel.ModelBuilder is not public in org.deeplearning4j.nn.modelimport.keras.KerasModel; cannot be accessed from outside package"
Any help? I'm trying to import a Keras functional Model (importing a Sequential model works).
