@gauravshelangia
Created May 18, 2017 12:50
Running a Keras neural network model on an Android device
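  • Dependencies for Android
    • Add the following lines to the dependencies block of the app module's build.gradle. The android-x86 and android-arm classifiers pull in the native ND4J binaries for x86 and ARM devices respectively. With Android Gradle plugin 3.0 or later, use implementation instead of compile.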
   compile 'org.deeplearning4j:deeplearning4j-core:0.7.2'
   compile 'org.nd4j:nd4j-native:0.7.2'
   compile 'org.nd4j:nd4j-native:0.7.2:android-x86'
   compile 'org.nd4j:nd4j-native:0.7.2:android-arm'
   

To run the Keras model on the device, follow the steps below.

  • Save the Keras model architecture as a JSON file and the model weights as an HDF5 (.h5) file

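    # Optional: save architecture, weights and optimizer state together in one HDF5 file
    model.save('mtarget_model_full1.h5')

    # Serialize the model architecture to JSON
    model_json = model.to_json()
    with open("Model_json", "w") as json_file:
        json_file.write(model_json)

    # Save the model weights; the .h5 extension makes tf.keras write HDF5, not a TF checkpoint
    model.save_weights('Model_weights.h5')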
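  • Import the saved model into DL4J and save it as a DL4J model zip file. This step runs on a development machine (a regular JVM), and the resulting zip is what gets copied to the device.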

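        import java.io.File;
        import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
        import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
        import org.deeplearning4j.util.ModelSerializer;

        String modelJsonFilename = "path_to_Model_json";
        String weightsHdf5Filename = "path_to_model_weights";

        // Load the model from two files: the JSON architecture config and the HDF5 weights.
        // This throws checked exceptions (including IOException), so handle or declare them.
        MultiLayerNetwork model = KerasModelImport.importKerasSequentialModelAndWeights(modelJsonFilename, weightsHdf5Filename);

        // Save the model as a zip file
        File locationToSave = new File("MyMultiLayerNetwork.zip");
        boolean saveUpdater = true; // also store updater state; only needed to resume training
        ModelSerializer.writeModel(model, locationToSave, saveUpdater);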
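  • Finally, copy the zip file into a Model folder in the device's external storage and load it there. Reading external storage requires the READ_EXTERNAL_STORAGE permission (requested at runtime on Android 6.0 and later), and restoreMultiLayerNetwork throws IOException.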

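        import android.os.Environment;
        import java.io.File;
        import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
        import org.deeplearning4j.util.ModelSerializer;

        File dir = new File(Environment.getExternalStorageDirectory(), "Model");
        File modelzip = new File(dir, "MyMultiLayerNetwork.zip");
        MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(modelzip);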
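The import step expects the weights file to be in HDF5 format (hence the parameter name weightsHdf5Filename). With tf.keras in TensorFlow 2.x, save_weights writes a TensorFlow checkpoint instead of HDF5 when the filename does not end in .h5, and that mistake otherwise only shows up as an import error. Below is a minimal, hedged sketch of a pre-import sanity check on the development machine: it only compares the first 8 bytes of the file with the standard HDF5 format signature, so it cannot tell whether the file really holds Keras weights. HDF5 also allows the signature at later offsets (512, 1024, ...) when a user block is present, which this sketch does not handle. The class name Hdf5Check and the default path are illustrative assumptions, not part of DL4J.

```java
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;

// Hypothetical helper (not part of DL4J): cheap check that a file looks like HDF5.
public class Hdf5Check {
    // The 8-byte HDF5 format signature: \211 H D F \r \n \032 \n
    private static final byte[] HDF5_SIGNATURE = {
        (byte) 0x89, 'H', 'D', 'F', '\r', '\n', 0x1a, '\n'
    };

    /** Returns true if the file starts with the HDF5 signature (checks offset 0 only). */
    public static boolean looksLikeHdf5(File file) throws IOException {
        byte[] header = new byte[HDF5_SIGNATURE.length];
        try (InputStream in = new FileInputStream(file)) {
            int read = 0;
            while (read < header.length) {
                int n = in.read(header, read, header.length - read);
                if (n < 0) {
                    return false; // file is shorter than the signature
                }
                read += n;
            }
        }
        return Arrays.equals(header, HDF5_SIGNATURE);
    }

    public static void main(String[] args) throws IOException {
        // Hypothetical default path; pass the real weights file as the first argument
        File weights = new File(args.length > 0 ? args[0] : "Model_weights.h5");
        if (!weights.isFile()) {
            System.out.println("No such file: " + weights);
            return;
        }
        System.out.println(weights + " looks like HDF5: " + looksLikeHdf5(weights));
    }
}
```

If the check returns false for the weights file, re-save the weights with an .h5 filename before running the import.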
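Environment.getExternalStorageDirectory() is deprecated since Android 10 (API level 29). One alternative, assuming the zip is shipped inside the APK's assets folder instead of being copied to external storage, is to copy it once into app-private storage and restore it from there, since assets are only readable as streams while restoreMultiLayerNetwork is called with a File above. A minimal stdlib-only sketch of the copy step follows; ModelFileCopier and copyModelToFile are hypothetical names, and a plain buffer loop is used instead of java.nio.file.Files because that API is only available on Android from API level 26.

```java
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

// Hypothetical helper (not part of DL4J or Android): copies a model stream to a real file.
public class ModelFileCopier {
    /**
     * Copies everything from in to dir/fileName, overwriting any existing file,
     * closes the stream and returns the target File.
     */
    public static File copyModelToFile(InputStream in, File dir, String fileName) throws IOException {
        if (!dir.isDirectory() && !dir.mkdirs()) {
            throw new IOException("Cannot create directory " + dir);
        }
        File target = new File(dir, fileName);
        byte[] buffer = new byte[8192];
        // FileOutputStream without the append flag truncates an existing file
        try (InputStream source = in; OutputStream out = new FileOutputStream(target)) {
            int n;
            while ((n = source.read(buffer)) != -1) {
                out.write(buffer, 0, n);
            }
        }
        return target;
    }

    public static void main(String[] args) throws IOException {
        // Desktop demo: three in-memory bytes stand in for the real model zip
        File dir = new File(System.getProperty("java.io.tmpdir"), "model-copy-demo");
        File copied = copyModelToFile(new ByteArrayInputStream(new byte[] {1, 2, 3}), dir, "MyMultiLayerNetwork.zip");
        System.out.println("Copied " + copied.length() + " bytes to " + copied);
    }
}
```

On the device, the stream could come from context.getAssets().open("MyMultiLayerNetwork.zip") and the directory from context.getFilesDir(), neither of which needs a storage permission; the returned File is then passed to ModelSerializer.restoreMultiLayerNetwork as in the step above.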