Skip to content

Instantly share code, notes, and snippets.

@eraly
Last active October 24, 2019 00:05
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save eraly/49dfdb401b0347d8183027edae462f3e to your computer and use it in GitHub Desktop.
Save eraly/49dfdb401b0347d8183027edae462f3e to your computer and use it in GitHub Desktop.
package org.deeplearning4j.examples.modelimport.keras.basic;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import java.io.IOException;
public class TestImport {

    /** Location of the Keras model saved by the Python snippet in this gist; used when no argument is given. */
    private static final String DEFAULT_MODEL_PATH = "/Users/susaneraly/SKYMIND/de_conv.h5";

    /**
     * Compares the output of an imported Keras Conv2DTranspose model with a direct call to the
     * libnd4j {@code deconv2d} custom op on the same input, printing both in channels-last order.
     *
     * @param args optional; {@code args[0]} overrides the path of the Keras {@code .h5} model
     *             (defaults to {@link #DEFAULT_MODEL_PATH})
     * @throws UnsupportedKerasConfigurationException if the model uses an unsupported Keras feature
     * @throws InvalidKerasConfigurationException     if the model configuration is invalid
     * @throws IOException                            if the model file cannot be read
     */
    public static void main(String[] args)
            throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException {
        // The Keras code that produces the model is included below this class.
        final String modelPath = args.length > 0 ? args[0] : DEFAULT_MODEL_PATH;

        // Values 0..17 shaped (1, 3, 3, 2) like the Keras script's channels-last input,
        // then permuted to (1, 2, 3, 3), i.e. channels moved to axis 1 for DL4J.
        INDArray in = Nd4j.linspace(0, 17, 18).reshape(1, 3, 3, 2).permute(0, 3, 1, 2);

        // Load the Keras model and run it; permute the result back to channels-last for printing.
        ComputationGraph model = KerasModelImport.importKerasModelAndWeights(modelPath);
        INDArray modelOutput = model.output(in)[0];
        System.out.println(modelOutput.permute(0, 2, 3, 1));
        System.out.println("=======================");

        // Replicate manually the libnd4j call made when model.output is invoked.
        // Weights are all ones and bias all zeros, matching the Python script's set_weights call.
        INDArray weights = Nd4j.ones(2, 2, 2, 2);
        INDArray bias = Nd4j.zeros(1, 2);
        INDArray[] opInputs = new INDArray[]{in, weights, bias};
        // Zero-filled output buffer. A ones-filled buffer makes the printed result ambiguous
        // (the expected output itself contains 1s) if the op accumulates into, or only partially
        // writes, its output.
        INDArray[] opOutputs = new INDArray[]{Nd4j.zeros(1, 2, 6, 6)};
        // Integer args, assumed to follow libnd4j deconv2d's order:
        // kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, dataFormat (0 presumably = NCHW).
        // TODO confirm against the op declaration; with isSameMode = 1 the explicit
        // padding (4, 4) is presumably recomputed by the op.
        int[] argsA = new int[]{2, 2, 2, 2, 4, 4, 1, 1, 1, 0};
        CustomOp op = DynamicCustomOp.builder("deconv2d")
                .addInputs(opInputs)
                .addIntegerArguments(argsA)
                .addOutputs(opOutputs)
                .callInplace(false)
                .build();
        Nd4j.getExecutioner().exec(op);
        System.out.println(opOutputs[0].permute(0, 2, 3, 1));
        /*
        Both printouts should match the Keras output below, but currently they do not:
        [[[[ 1. 1.]
        [ 1. 1.]
        [ 5. 5.]
        [ 5. 5.]
        [ 9. 9.]
        [ 9. 9.]]
        [[ 1. 1.]
        [ 1. 1.]
        [ 5. 5.]
        [ 5. 5.]
        [ 9. 9.]
        [ 9. 9.]]
        [[13. 13.]
        [13. 13.]
        [17. 17.]
        [17. 17.]
        [21. 21.]
        [21. 21.]]
        [[13. 13.]
        [13. 13.]
        [17. 17.]
        [17. 17.]
        [21. 21.]
        [21. 21.]]
        [[25. 25.]
        [25. 25.]
        [29. 29.]
        [29. 29.]
        [33. 33.]
        [33. 33.]]
        [[25. 25.]
        [25. 25.]
        [29. 29.]
        [29. 29.]
        [33. 33.]
        [33. 33.]]]]
        */
    }
}
from __future__ import print_function
import numpy as np
import keras
from keras.layers import Conv2DTranspose, Input
from keras.models import Model
# Build a single-layer transposed-convolution ("deconv") model, fix its weights,
# run it on a small ramp input, and save it for import into DL4J.
input_shape = (3, 3, 2)
inputs = Input(shape=input_shape, name='encoder_input')

# Keep a handle on the layer so its weights can be set directly.
deconv_layer = Conv2DTranspose(filters=2, kernel_size=(2, 2), strides=2, padding='same')
outputs = deconv_layer(inputs)
model = Model(inputs, outputs, name='decoder')

# All-ones kernel and zero bias make the expected output easy to compute by hand.
kernel = np.ones((2, 2, 2, 2))
bias = np.zeros((2,))
deconv_layer.set_weights([kernel, bias])

# Values 0..17 shaped as a batch of one input: (1, 3, 3, 2).
sample = np.arange(18).reshape((1,) + input_shape)
prediction = model.predict(sample)
print(prediction)

model.save("de_conv.h5")
/*
Output of the Keras deconvolution (model.predict on the input above) is:
[[[[ 1. 1.]
[ 1. 1.]
[ 5. 5.]
[ 5. 5.]
[ 9. 9.]
[ 9. 9.]]
[[ 1. 1.]
[ 1. 1.]
[ 5. 5.]
[ 5. 5.]
[ 9. 9.]
[ 9. 9.]]
[[13. 13.]
[13. 13.]
[17. 17.]
[17. 17.]
[21. 21.]
[21. 21.]]
[[13. 13.]
[13. 13.]
[17. 17.]
[17. 17.]
[21. 21.]
[21. 21.]]
[[25. 25.]
[25. 25.]
[29. 29.]
[29. 29.]
[33. 33.]
[33. 33.]]
[[25. 25.]
[25. 25.]
[29. 29.]
[29. 29.]
[33. 33.]
[33. 33.]]]]
*/
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment