Skip to content

Instantly share code, notes, and snippets.

@localcitizen
Last active October 1, 2019 08:29
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save localcitizen/6bcfc295a539bd7646a7e7406c095176 to your computer and use it in GitHub Desktop.
Save localcitizen/6bcfc295a539bd7646a7e7406c095176 to your computer and use it in GitHub Desktop.
package org.deeplearning4j.examples.dataexamples;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
/**
 * Imports a Keras model (HDF5) into a Deeplearning4j {@link ComputationGraph} and runs a
 * single forward pass on a fixed 1x4 input, printing the model summary and the output.
 *
 * <p>Two SameDiff lambda layers are registered under the names {@code "lambda_2"} and
 * {@code "lambda_1"} before the model is imported.
 * NOTE(review): these names presumably have to match the Lambda layer names stored in the
 * model file - confirm against the Keras model.
 */
public class Test_model {

    /** Model file used when no path is given on the command line. */
    private static final String DEFAULT_MODEL_PATH = "/home/user/DeepLearning/Test_model.h5";

    /**
     * Registers the lambda layers, imports the model, and prints its summary and its output
     * for the input row {@code [16, 17, 5, 4]}.
     *
     * @param args optional; if present, {@code args[0]} is the path of the Keras {@code .h5}
     *             file to import. When absent, {@link #DEFAULT_MODEL_PATH} is used.
     * @throws Exception if the model import or the forward pass fails
     */
    public static void main(String[] args) throws Exception
    {
        // Allow the model location to be overridden instead of requiring a source edit.
        String modelPath = (args != null && args.length > 0) ? args[0] : DEFAULT_MODEL_PATH;

        KerasLayer.registerLambdaLayer("lambda_2", new SameDiffLambdaLayer()
        {
            @Override
            public SDVariable defineLayer(SameDiff sameDiff, SDVariable sdVariable)
            {
                // Squeeze the input along axis -1.
                return sameDiff.squeeze(sdVariable, -1);
            }

            @Override
            public InputType getOutputType(int layerIndex, InputType inputType)
            {
                // Hard-coded: always reports a feed-forward output of size 2,
                // independent of the incoming input type.
                return InputType.feedForward(2);
            }
        });

        KerasLayer.registerLambdaLayer("lambda_1", new SameDiffLambdaLayer()
        {
            @Override
            public SDVariable defineLayer(SameDiff sameDiff, SDVariable sdVariable)
            {
                // Element-wise square of the input.
                return sameDiff.math.square(sdVariable);
            }

            @Override
            public InputType getOutputType(int layerIndex, InputType inputType)
            {
                // Hard-coded: always reports a recurrent output of size 16,
                // independent of the incoming input type.
                return InputType.recurrent(16);
            }
        });

        ComputationGraph model =
            org.deeplearning4j.nn.modelimport.keras.KerasModelImport.importKerasModelAndWeights(modelPath);
        System.out.println(model.summary());

        INDArray myArray = Nd4j.zeros(1, 4); // one row 4 column array
        myArray.putScalar(0, 0, 16);
        myArray.putScalar(0, 1, 17);
        myArray.putScalar(0, 2, 5);
        myArray.putScalar(0, 3, 4);

        INDArray output = model.outputSingle(myArray);
        System.out.println("First model output");
        System.out.println(output);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment