-
-
Save turambar/be0a96a02fd1ba3f6010bc0d17fc90dc to your computer and use it in GitHub Desktop.
package org.deeplearning4j.nn.modelimport.keras;

import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class KerasImportVgg16 { | |
private static final Logger log = LoggerFactory.getLogger(KerasImportVgg16.class); | |
public static void main(String[] args) throws Exception { | |
String modelJsonFilename = "PATH TO EXPORTED JSON FILE"; | |
String weightsHdf5Filename = "PATH TO EXPORTED WEIGHTS HDF5 ARCHIVE"; | |
String modelHdf5Filename = "PATH TO EXPORTED FULL MODEL HDF5 ARCHIVE"; | |
boolean enforceTrainingConfig = false; //Controls whether unsupported training-related configs | |
//will throw an exception or just generate a warning. | |
/* Import VGG 16 model from separate model config JSON and weights HDF5 files. | |
* Will not include loss layer or training configuration. | |
*/ | |
// Static helper from KerasModelImport | |
ComputationGraph model = KerasModelImport.importKerasModelAndWeights(modelJsonFilename, weightsHdf5Filename, enforceTrainingConfig); | |
// KerasModel builder pattern | |
model = new KerasModel.ModelBuilder() | |
.modelJsonFilename(modelJsonFilename) | |
.weightsHdf5Filename(weightsHdf5Filename) | |
.enforceTrainingConfig(enforceTrainingConfig) | |
.buildModel() | |
.getComputationGraph(); | |
/* Import VGG 16 model from full model HDF5 file. Includes loss layer, if any. */ | |
// Static helper from KerasModelImport | |
model = KerasModelImport.importKerasModelAndWeights(modelHdf5Filename, enforceTrainingConfig); | |
// KerasModel builder pattern | |
model = new KerasModel.ModelBuilder() | |
.modelHdf5Filename(modelHdf5Filename) | |
.enforceTrainingConfig(enforceTrainingConfig) | |
.buildModel() | |
.getComputationGraph(); | |
/* Import VGG 16 model config from model config JSON. Will not include loss | |
* layer or training configuration. | |
*/ | |
// Static helper from KerasModelImport | |
ComputationGraphConfiguration config = KerasModelImport.importKerasModelConfiguration(modelJsonFilename, enforceTrainingConfig); | |
// KerasModel builder pattern | |
config = new KerasModel.ModelBuilder() | |
.modelJsonFilename(modelJsonFilename) | |
.enforceTrainingConfig(enforceTrainingConfig) | |
.buildModel() | |
.getComputationGraphConfiguration(); | |
} | |
} |
Hello,
I followed your code, but it gives me an error: cannot resolve symbol `KerasModelImport`.
Here is the Gradle file I modified:
// Build an Android application module. The parenthesized call form used below
// is equivalent to Groovy's parenthesis-free command syntax.
apply(plugin: 'com.android.application')

android {
    compileSdkVersion(25)
    buildToolsVersion("25.0.2")

    // Drop these META-INF metadata files from the packaged APK (presumably
    // because several bundled jars each ship their own copy — confirm).
    packagingOptions {
        exclude('META-INF/DEPENDENCIES')
        exclude('META-INF/DEPENDENCIES.txt')
        exclude('META-INF/LICENSE')
        exclude('META-INF/LICENSE.txt')
        exclude('META-INF/license.txt')
        exclude('META-INF/NOTICE')
        exclude('META-INF/NOTICE.txt')
        exclude('META-INF/notice.txt')
        exclude('META-INF/INDEX.LIST')
    }

    defaultConfig {
        applicationId("com.example.gaurav.mtargetclient")
        minSdkVersion(15)
        targetSdkVersion(25)
        versionCode(1)
        versionName("1.0")
        testInstrumentationRunner("android.support.test.runner.AndroidJUnitRunner")
        // Multidex is enabled; presumably the large DL4J/ND4J dependency set
        // exceeds the single-dex method limit — confirm before removing.
        multiDexEnabled(true)
    }

    buildTypes {
        release {
            // Code shrinking is off for release builds.
            minifyEnabled(false)
            proguardFiles(getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro')
        }
    }
}
repositories {
    mavenLocal()
    mavenCentral()
}

dependencies {
    // Single version for every DL4J/ND4J artifact so the modules are always
    // upgraded together.
    def dl4jVersion = '0.7.2'

    compile fileTree(dir: 'libs', include: ['*.jar'])
    androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', {
        exclude group: 'com.android.support', module: 'support-annotations'
    })
    compile 'com.android.support:appcompat-v7:25.2.0'
    compile 'com.android.support.constraint:constraint-layout:1.0.0-beta5'
    testCompile 'junit:junit:4.12'
    compile 'com.google.android.gms:play-services-appindexing:8.4.0'
    compile 'com.qozix:tileview:2.2.6'

    compile "org.deeplearning4j:deeplearning4j-core:${dl4jVersion}"
    // The Keras import classes (org.deeplearning4j.nn.modelimport.keras.*,
    // e.g. KerasModelImport and KerasModel) are not in deeplearning4j-core.
    // They ship in the separate modelimport artifact. Without this dependency
    // the compiler reports "cannot resolve symbol KerasModelImport".
    // NOTE(review): KerasModelImport may only exist in newer DL4J releases; if
    // the symbol still does not resolve, raise dl4jVersion (which moves every
    // DL4J/ND4J artifact together). Also confirm that the HDF5 native bindings
    // used by modelimport are available for the Android ABIs targeted here.
    compile "org.deeplearning4j:deeplearning4j-modelimport:${dl4jVersion}"
    compile "org.nd4j:nd4j-native:${dl4jVersion}"
    compile "org.nd4j:nd4j-native:${dl4jVersion}:android-x86"
    compile "org.nd4j:nd4j-native:${dl4jVersion}:android-arm"
}
Please tell me how to resolve this error.
Thanks
I'm having the same issue here: "java: org.deeplearning4j.nn.modelimport.keras.KerasModel.ModelBuilder is not public in org.deeplearning4j.nn.modelimport.keras.KerasModel; cannot be accessed from outside package"
Any help? I'm trying to import a Keras `Model` (importing a `Sequential` model works).
It's not working because `ModelBuilder` can't be accessed from outside the package.