Created
December 15, 2022 15:18
-
-
Save VamshiG/ab3904536af2712569d398559a2d84a0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package wdnn_maven.wdnn_test; | |
import java.io.*; | |
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport; | |
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; | |
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; | |
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | |
import org.deeplearning4j.nn.graph.ComputationGraph; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import org.nd4j.linalg.factory.Nd4j; | |
import org.datavec.api.util.ClassPathResource; | |
import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator; | |
import org.datavec.api.records.reader.impl.csv.*; | |
import org.datavec.api.records.reader.RecordReader; | |
import org.datavec.api.records.reader.SequenceRecordReader; | |
import org.datavec.api.records.reader.impl.csv.CSVRecordReader; | |
import org.datavec.api.split.FileSplit; | |
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; | |
import org.nd4j.linalg.dataset.api.MultiDataSet; | |
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; | |
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader; | |
import org.nd4j.linalg.dataset.DataSet; | |
import org.deeplearning4j.eval.Evaluation; | |
import org.deeplearning4j.nn.modelimport.keras.KerasLayer; | |
public class keras_test | |
{ | |
@SuppressWarnings("deprecation") | |
public static void main(String[] args) throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException | |
{ | |
KerasLayer.registerCustomLayer("CustomDense", CustomDense.class); | |
String filepath = new ClassPathResource("my_model_custom_2.h5").getFile().getPath(); | |
ComputationGraph Model = KerasModelImport.importKerasModelAndWeights(filepath, true); | |
System.out.println(11); | |
System.out.println("Model Summary: " + Model.summary()); | |
int numLinesToSkip = 1; | |
String fileDelimiter = ","; | |
RecordReader otherfeatures = new CSVRecordReader(numLinesToSkip, fileDelimiter); | |
try { | |
otherfeatures.initialize(new FileSplit(new ClassPathResource("other_features.csv").getFile())); | |
} catch (IOException | InterruptedException e) { | |
// TODO Auto-generated catch block | |
e.printStackTrace(); | |
} | |
RecordReader bidpricefeature = new CSVRecordReader(numLinesToSkip, fileDelimiter); | |
try { | |
bidpricefeature.initialize(new FileSplit(new ClassPathResource("bidprice_features.csv").getFile())); | |
} catch (IOException | InterruptedException e) { | |
// TODO Auto-generated catch block | |
e.printStackTrace(); | |
} | |
int batchSize = 1; | |
int numClasses = 2; | |
MultiDataSetIterator iterator = new RecordReaderMultiDataSetIterator.Builder(batchSize) | |
.addReader("otherfeatures", otherfeatures) | |
.addReader("bidpricefeature", bidpricefeature) | |
.addInput("otherfeatures", 0, 474) | |
.addInput("bidpricefeature") | |
.addOutput("otherfeatures", 475, 475) | |
.build(); | |
//System.out.println(iterator.next()); | |
//Evaluation eval = new Evaluation(2); | |
while(iterator.hasNext()){ | |
MultiDataSet mds = iterator.next(); | |
//System.out.println("Input features: " + mds.getFeatures()); | |
INDArray[] output = Model.output(mds.getFeatures()); | |
INDArray[] labels = mds.getLabels(); | |
for (int i = 0; i < mds.getFeatures().length; i++) { | |
System.out.println("Input features: " + mds.getFeatures()[i].get()); | |
} | |
for (int i = 0; i < output.length; i++) { | |
System.out.println(output[i].get()); | |
} | |
} | |
//Print the evaluation statistics | |
} | |
} | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package wdnn_maven.wdnn_test; | |
import java.util.HashMap; | |
import java.util.Map; | |
import org.deeplearning4j.nn.api.layers.LayerConstraint; | |
import org.deeplearning4j.nn.conf.InputPreProcessor; | |
import org.deeplearning4j.nn.conf.inputs.InputType; | |
import org.deeplearning4j.nn.conf.layers.DenseLayer; | |
import org.deeplearning4j.nn.conf.layers.BaseLayer; | |
import org.deeplearning4j.nn.modelimport.keras.KerasLayer; | |
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; | |
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; | |
import org.deeplearning4j.nn.modelimport.keras.utils.KerasActivationUtils; | |
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; | |
import org.deeplearning4j.nn.modelimport.keras.utils.KerasInitilizationUtils; | |
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils; | |
import org.deeplearning4j.nn.params.DefaultParamInitializer; | |
import org.deeplearning4j.nn.weights.IWeightInit; | |
import org.deeplearning4j.nn.weights.WeightInit; | |
import org.deeplearning4j.nn.weights.WeightInitConstant; | |
import org.deeplearning4j.nn.weights.WeightInitXavierUniform; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import org.nd4j.linalg.activations.Activation; | |
import org.nd4j.autodiff.samediff.SDVariable; | |
import org.nd4j.autodiff.samediff.SameDiff; | |
@SuppressWarnings("unused") | |
public class CustomDense extends KerasLayer { /* Keras layer parameter names. */ | |
private int numTrainableParams = 2; | |
private boolean hasBias; | |
/** | |
* Pass-through constructor from KerasLayer | |
* | |
* @param kerasVersion major keras version | |
* @throws UnsupportedKerasConfigurationException Unsupported Keras config | |
*/ | |
public CustomDense(Integer kerasVersion) throws UnsupportedKerasConfigurationException { | |
super(kerasVersion); | |
} | |
/** | |
* Constructor from parsed Keras layer configuration dictionary. | |
* | |
* @param layerConfig dictionary containing Keras layer configuration | |
* @throws InvalidKerasConfigurationException Invalid Keras config | |
* @throws UnsupportedKerasConfigurationException Unsupported Keras config | |
*/ | |
public CustomDense(Map<String, Object> layerConfig) | |
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { | |
this(layerConfig, true); | |
} | |
/** | |
* Constructor from parsed Keras layer configuration dictionary. | |
* | |
* @param layerConfig dictionary containing Keras layer configuration | |
* @param enforceTrainingConfig whether to enforce training-related configuration options | |
* @throws InvalidKerasConfigurationException Invalid Keras config | |
* @throws UnsupportedKerasConfigurationException Unsupported Keras config | |
*/ | |
public CustomDense(Map<String, Object> layerConfig, boolean enforceTrainingConfig) | |
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { | |
super(layerConfig, enforceTrainingConfig); | |
hasBias = KerasLayerUtils.getHasBiasFromConfig(layerConfig, conf); | |
numTrainableParams = hasBias ? 2 : 1; | |
LayerConstraint biasConstraint = KerasConstraintUtils.getConstraintsFromConfig( | |
layerConfig, conf.getLAYER_FIELD_B_CONSTRAINT(), conf, kerasMajorVersion); | |
LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig( | |
layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion); | |
IWeightInit init = KerasInitilizationUtils.getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), | |
enforceTrainingConfig, conf, kerasMajorVersion); | |
DenseLayer.Builder builder = new DenseLayer.Builder().name(this.layerName) | |
.nOut(1) | |
.activation(Activation.SIGMOID) | |
.weightInit(init) | |
.biasInit(0.0); | |
if (biasConstraint != null) | |
builder.constrainBias(biasConstraint); | |
if (weightConstraint != null) | |
builder.constrainWeights(weightConstraint); | |
this.layer = builder.build(); | |
} | |
/** | |
* Get DL4J DenseLayer. | |
* | |
* @return DenseLayer | |
*/ | |
public DenseLayer getDenseLayer() { | |
return (DenseLayer) this.layer; | |
} | |
/** | |
* Get layer output type. | |
* | |
* @param inputType Array of InputTypes | |
* @return output type as InputType | |
* @throws InvalidKerasConfigurationException Invalid Keras config | |
*/ | |
@Override | |
public InputType getOutputType(InputType... inputType) throws InvalidKerasConfigurationException { | |
/* Check whether layer requires a preprocessor for this InputType. */ | |
InputPreProcessor preprocessor = getInputPreprocessor(inputType[0]); | |
if (preprocessor != null) { | |
return this.getDenseLayer().getOutputType(-1, preprocessor.getOutputType(inputType[0])); | |
} | |
return this.getDenseLayer().getOutputType(-1, inputType[0]); | |
} | |
/** | |
* Returns number of trainable parameters in layer. | |
* | |
* @return number of trainable parameters (2) | |
*/ | |
@Override | |
public int getNumParams() { | |
return numTrainableParams; | |
} | |
/** | |
* Set weights for layer. | |
* | |
* @param weights Dense layer weights | |
*/ | |
@Override | |
public void setWeights(Map<String, INDArray> weights) throws InvalidKerasConfigurationException { | |
this.weights = new HashMap<>(); | |
if (weights.containsKey(conf.getKERAS_PARAM_NAME_W())) | |
this.weights.put(DefaultParamInitializer.WEIGHT_KEY, weights.get(conf.getKERAS_PARAM_NAME_W())); | |
else | |
throw new InvalidKerasConfigurationException( | |
"Parameter " + conf.getKERAS_PARAM_NAME_W() + " does not exist in weights"); | |
if (hasBias) { | |
if (weights.containsKey(conf.getKERAS_PARAM_NAME_B())) | |
this.weights.put(DefaultParamInitializer.BIAS_KEY, weights.get(conf.getKERAS_PARAM_NAME_B())); | |
else | |
throw new InvalidKerasConfigurationException( | |
"Parameter " + conf.getKERAS_PARAM_NAME_B() + " does not exist in weights"); | |
} | |
KerasLayerUtils.removeDefaultWeights(weights, conf); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment