Skip to content

Instantly share code, notes, and snippets.

@VamshiG
Created December 15, 2022 15:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save VamshiG/ab3904536af2712569d398559a2d84a0 to your computer and use it in GitHub Desktop.
Save VamshiG/ab3904536af2712569d398559a2d84a0 to your computer and use it in GitHub Desktop.
package wdnn_maven.wdnn_test;
import java.io.*;
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.datavec.api.util.ClassPathResource;
import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator;
import org.datavec.api.records.reader.impl.csv.*;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.nd4j.linalg.dataset.DataSet;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
/**
 * Smoke test that imports a Keras functional model containing the custom {@code CustomDense}
 * layer into a DL4J {@link ComputationGraph} and runs inference over two CSV inputs.
 *
 * <p>The model is read from the classpath resource {@code my_model_custom_2.h5}. Rows of
 * {@code other_features.csv} and {@code bidprice_features.csv} (both on the classpath, first line
 * skipped) are paired into multi-input batches; for every batch the input features and the model
 * outputs are printed to stdout.
 */
public class keras_test {

    /**
     * Entry point: import the model, build the two-reader iterator and print inference results.
     *
     * @param args unused
     * @throws IOException if a resource cannot be opened, or the thread is interrupted while a CSV
     *     reader is being initialized (reported as {@link InterruptedIOException})
     * @throws InvalidKerasConfigurationException if the Keras model configuration is invalid
     * @throws UnsupportedKerasConfigurationException if the Keras model uses unsupported features
     */
    @SuppressWarnings("deprecation")
    public static void main(String[] args)
            throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        // Must happen before import so the Keras class name "CustomDense" maps to our layer.
        KerasLayer.registerCustomLayer("CustomDense", CustomDense.class);
        String modelPath = new ClassPathResource("my_model_custom_2.h5").getFile().getPath();
        // Second argument is enforceTrainingConfig.
        ComputationGraph model = KerasModelImport.importKerasModelAndWeights(modelPath, true);
        System.out.println(11);
        System.out.println("Model Summary: " + model.summary());

        int numLinesToSkip = 1;
        String fileDelimiter = ",";
        RecordReader otherFeatures = openCsvReader("other_features.csv", numLinesToSkip, fileDelimiter);
        RecordReader bidPriceFeature = openCsvReader("bidprice_features.csv", numLinesToSkip, fileDelimiter);

        int batchSize = 1;
        // Columns 0..474 of other_features.csv are model input, column 475 is the label;
        // every column of bidprice_features.csv is the second model input.
        // NOTE(review): assumes other_features.csv has exactly these 476 columns — confirm.
        MultiDataSetIterator iterator = new RecordReaderMultiDataSetIterator.Builder(batchSize)
                .addReader("otherfeatures", otherFeatures)
                .addReader("bidpricefeature", bidPriceFeature)
                .addInput("otherfeatures", 0, 474)
                .addInput("bidpricefeature")
                .addOutput("otherfeatures", 475, 475)
                .build();

        while (iterator.hasNext()) {
            MultiDataSet batch = iterator.next();
            INDArray[] features = batch.getFeatures();
            INDArray[] outputs = model.output(features);
            for (INDArray feature : features) {
                System.out.println("Input features: " + feature);
            }
            for (INDArray out : outputs) {
                System.out.println(out);
            }
        }
        // TODO: evaluate outputs against batch.getLabels() (the column-475 label); not implemented.
    }

    /**
     * Creates and initializes a CSV record reader over a classpath resource.
     *
     * @param resourceName classpath resource name of the CSV file
     * @param numLinesToSkip number of leading lines to skip
     * @param delimiter field delimiter
     * @return an initialized reader
     * @throws IOException if the resource cannot be opened, or the thread is interrupted during
     *     initialization (reported as {@link InterruptedIOException})
     */
    private static RecordReader openCsvReader(String resourceName, int numLinesToSkip, String delimiter)
            throws IOException {
        RecordReader reader = new CSVRecordReader(numLinesToSkip, delimiter);
        try {
            reader.initialize(new FileSplit(new ClassPathResource(resourceName).getFile()));
        } catch (InterruptedException e) {
            // Restore the interrupt flag and fail fast rather than continuing with an
            // uninitialized reader.
            Thread.currentThread().interrupt();
            InterruptedIOException interrupted =
                    new InterruptedIOException("Interrupted while initializing reader for " + resourceName);
            interrupted.initCause(e);
            throw interrupted;
        }
        return reader;
    }
}
package wdnn_maven.wdnn_test;
import java.util.HashMap;
import java.util.Map;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasActivationUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasInitilizationUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitConstant;
import org.deeplearning4j.nn.weights.WeightInitXavierUniform;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@SuppressWarnings("unused")
public class CustomDense extends KerasLayer { /* Keras layer parameter names. */
private int numTrainableParams = 2;
private boolean hasBias;
/**
* Pass-through constructor from KerasLayer
*
* @param kerasVersion major keras version
* @throws UnsupportedKerasConfigurationException Unsupported Keras config
*/
public CustomDense(Integer kerasVersion) throws UnsupportedKerasConfigurationException {
super(kerasVersion);
}
/**
* Constructor from parsed Keras layer configuration dictionary.
*
* @param layerConfig dictionary containing Keras layer configuration
* @throws InvalidKerasConfigurationException Invalid Keras config
* @throws UnsupportedKerasConfigurationException Unsupported Keras config
*/
public CustomDense(Map<String, Object> layerConfig)
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
this(layerConfig, true);
}
/**
* Constructor from parsed Keras layer configuration dictionary.
*
* @param layerConfig dictionary containing Keras layer configuration
* @param enforceTrainingConfig whether to enforce training-related configuration options
* @throws InvalidKerasConfigurationException Invalid Keras config
* @throws UnsupportedKerasConfigurationException Unsupported Keras config
*/
public CustomDense(Map<String, Object> layerConfig, boolean enforceTrainingConfig)
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
super(layerConfig, enforceTrainingConfig);
hasBias = KerasLayerUtils.getHasBiasFromConfig(layerConfig, conf);
numTrainableParams = hasBias ? 2 : 1;
LayerConstraint biasConstraint = KerasConstraintUtils.getConstraintsFromConfig(
layerConfig, conf.getLAYER_FIELD_B_CONSTRAINT(), conf, kerasMajorVersion);
LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion);
IWeightInit init = KerasInitilizationUtils.getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
enforceTrainingConfig, conf, kerasMajorVersion);
DenseLayer.Builder builder = new DenseLayer.Builder().name(this.layerName)
.nOut(1)
.activation(Activation.SIGMOID)
.weightInit(init)
.biasInit(0.0);
if (biasConstraint != null)
builder.constrainBias(biasConstraint);
if (weightConstraint != null)
builder.constrainWeights(weightConstraint);
this.layer = builder.build();
}
/**
* Get DL4J DenseLayer.
*
* @return DenseLayer
*/
public DenseLayer getDenseLayer() {
return (DenseLayer) this.layer;
}
/**
* Get layer output type.
*
* @param inputType Array of InputTypes
* @return output type as InputType
* @throws InvalidKerasConfigurationException Invalid Keras config
*/
@Override
public InputType getOutputType(InputType... inputType) throws InvalidKerasConfigurationException {
/* Check whether layer requires a preprocessor for this InputType. */
InputPreProcessor preprocessor = getInputPreprocessor(inputType[0]);
if (preprocessor != null) {
return this.getDenseLayer().getOutputType(-1, preprocessor.getOutputType(inputType[0]));
}
return this.getDenseLayer().getOutputType(-1, inputType[0]);
}
/**
* Returns number of trainable parameters in layer.
*
* @return number of trainable parameters (2)
*/
@Override
public int getNumParams() {
return numTrainableParams;
}
/**
* Set weights for layer.
*
* @param weights Dense layer weights
*/
@Override
public void setWeights(Map<String, INDArray> weights) throws InvalidKerasConfigurationException {
this.weights = new HashMap<>();
if (weights.containsKey(conf.getKERAS_PARAM_NAME_W()))
this.weights.put(DefaultParamInitializer.WEIGHT_KEY, weights.get(conf.getKERAS_PARAM_NAME_W()));
else
throw new InvalidKerasConfigurationException(
"Parameter " + conf.getKERAS_PARAM_NAME_W() + " does not exist in weights");
if (hasBias) {
if (weights.containsKey(conf.getKERAS_PARAM_NAME_B()))
this.weights.put(DefaultParamInitializer.BIAS_KEY, weights.get(conf.getKERAS_PARAM_NAME_B()));
else
throw new InvalidKerasConfigurationException(
"Parameter " + conf.getKERAS_PARAM_NAME_B() + " does not exist in weights");
}
KerasLayerUtils.removeDefaultWeights(weights, conf);
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment