@wmeddie
Created February 20, 2018 03:06
SameDiffLayer idea

A sketch of a simpler custom-layer API for Deeplearning4j built on SameDiff: a subclass declares its parameters in buildParameters() and its forward pass in defineFeedForward(), and the abstract SameDiffLayer base class wires both into DL4J's BaseSameDiffLayer machinery. Three files follow: MyCustomLayer (an example layer), SameDiffLayer (the abstract base class), and TestSameDiffCustom (a small test).
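// ---------------------------------------------------------------------------
// MyCustomLayer.java
// Example layer built on the SameDiffLayer base class below: two stacked
// matrix multiplications plus a bias, with an optional activation on top.
// ---------------------------------------------------------------------------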
package org.deeplearning4j.samediff.testlayers;

import org.deeplearning4j.nn.conf.inputs.InputType;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.weightinit.WeightInitScheme;
import org.nd4j.weightinit.impl.UniformInitScheme;

import java.util.Collections;
import java.util.List;

public class MyCustomLayer extends SameDiffLayer {

    private final int nOut;

    private SDVariable W1;
    private SDVariable W2;
    private SDVariable b;

    public MyCustomLayer(int nOut) {
        this.nOut = nOut;
    }

    @Override
    public void buildParameters(SameDiff sd, int[] inputShape) {
        // Two weight matrices and one bias row; inputShape[1] is the layer's nIn.
        WeightInitScheme init = UniformInitScheme.builder().build();
        W1 = addWeight(sd, "W1", new int[] { inputShape[1], nOut }, init, true);
        W2 = addWeight(sd, "W2", new int[] { nOut, nOut }, init, true);
        b = addBias(sd, "b", new int[] { 1, nOut }, 1.0, true);
    }

    @Override
    public List<SDVariable> defineFeedForward(SameDiff sd, SDVariable layerInput) {
        // z = (input * W1) * W2 + b, then the configured activation if one is set.
        SDVariable mmul1 = sd.mmul(layerInput, W1);
        SDVariable mmul2 = sd.mmul(mmul1, W2);
        SDVariable z = mmul2.add(b);
        if (getActivation() != null) {
            return Collections.singletonList(getActivation().asSameDiff(sd, z));
        } else {
            return Collections.singletonList(z);
        }
    }

    @Override
    public InputType getOutputType(int layerIndex, InputType inputType) {
        return InputType.feedForward(nOut);
    }
}
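// ---------------------------------------------------------------------------
// SameDiffLayer.java
// Abstract base class: collects weights and biases registered via addWeight()
// and addBias(), exposes them to DL4J through weightKeys(), biasKeys() and
// paramShapes(), and turns defineFeedForward() output into defineLayer()'s
// list of output variable names.
// ---------------------------------------------------------------------------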
package org.deeplearning4j.samediff.testlayers;

import lombok.Data;
import lombok.EqualsAndHashCode;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.samediff.BaseSameDiffLayer;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayerUtils;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.weightinit.WeightInitScheme;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

@Data
@EqualsAndHashCode(callSuper = true)
public abstract class SameDiffLayer extends BaseSameDiffLayer {

    private int nIn = 0;
    private int nOut = 0;
    private int outSize = 0;
    private Map<String, SDVariable> weights = new HashMap<>();
    private Map<String, SDVariable> biases = new HashMap<>();
    private Map<String, Double> biasInits = new HashMap<>();
    private Activation activation;
    private boolean built = false;

    public SameDiffLayer() {}

    public SameDiffLayer(int nOut) {
        this.nOut = nOut;
    }

    @Override
    public void setNIn(InputType inputType, boolean override) {
    }

    @Override
    public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
        return null;
    }
    /** Registers a (trainable) weight variable with SameDiff and remembers it by name. */
    public SDVariable addWeight(SameDiff sd, String name, int[] shape, WeightInitScheme init, boolean trainable) {
        SDVariable var = sd.var(name, shape, init);
        if (trainable) {
            weights.put(name, var);
        }
        return var;
    }

    /** Registers a bias variable and records its initial value for initializeParams(). */
    public SDVariable addBias(SameDiff sd, String name, int[] shape, double biasInit, boolean trainable) {
        SDVariable var = sd.var(name, shape);
        if (trainable) {
            biases.put(name, var);
            // Record the initial bias value so initializeParams() can apply it.
            biasInits.put(name, biasInit);
        }
        return var;
    }

    /** Subclasses declare their parameters here, given the layer's input shape. */
    public abstract void buildParameters(SameDiff sd, int[] inputShape);

    /** Subclasses define the forward pass here and return the output variables. */
    public abstract List<SDVariable> defineFeedForward(SameDiff sd, SDVariable layerInput);
    @Override
    public List<String> defineLayer(SameDiff sameDiff, SDVariable layerInput, Map<String, SDVariable> paramTable) {
        if (!built) {
            buildParameters(sameDiff, layerInput.getShape());
            built = true;
        }
        // Expose every registered weight and bias to DL4J's parameter table.
        for (Map.Entry<String, SDVariable> entry : weights.entrySet()) {
            paramTable.put(entry.getKey(), entry.getValue());
        }
        for (Map.Entry<String, SDVariable> entry : biases.entrySet()) {
            paramTable.put(entry.getKey(), entry.getValue());
        }
        List<SDVariable> sdVariables = defineFeedForward(sameDiff, layerInput);
        return sdVariables.stream().map(SDVariable::getVarName).collect(Collectors.toList());
    }
    @Override
    public List<String> weightKeys() {
        return new ArrayList<>(weights.keySet());
    }

    @Override
    public List<String> biasKeys() {
        return new ArrayList<>(biases.keySet());
    }

    @Override
    public Map<String, int[]> paramShapes() {
        Map<String, int[]> ret = new HashMap<>();
        for (Map.Entry<String, SDVariable> entry : weights.entrySet()) {
            ret.put(entry.getKey(), entry.getValue().getShape());
        }
        for (Map.Entry<String, SDVariable> entry : biases.entrySet()) {
            ret.put(entry.getKey(), entry.getValue().getShape());
        }
        return ret;
    }
    @Override
    public void initializeParams(Map<String, INDArray> params) {
        for (Map.Entry<String, INDArray> e : params.entrySet()) {
            if (weights.containsKey(e.getKey())) {
                // Re-initialize the parameter view using the scheme the variable was declared
                // with and the variable's own shape (W1 and W2 have different shapes).
                SDVariable weightVar = weights.get(e.getKey());
                WeightInitScheme weightInitScheme = weightVar.getWeightInitScheme();
                WeightInit init = WeightInit.valueOf(weightInitScheme.type().name());
                int[] shape = weightVar.getShape();
                WeightInitUtil.initWeights(shape[0], shape[1], shape, init, null, 'f', e.getValue());
            } else if (biases.containsKey(e.getKey())) {
                e.getValue().assign(biasInits.get(e.getKey()));
            }
        }
    }

    @Override
    public void applyGlobalConfigToLayer(NeuralNetConfiguration.Builder globalConfig) {
        // Fall back to the network-wide activation function if none was set on the layer.
        if (activation == null) {
            activation = SameDiffLayerUtils.fromIActivation(globalConfig.getActivationFn());
        }
    }
}
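// ---------------------------------------------------------------------------
// TestSameDiffCustom.java
// Minimal test: builds a single-layer network from MyCustomLayer and checks
// the output shape.
// ---------------------------------------------------------------------------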
package org.deeplearning4j.samediff;

import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.samediff.testlayers.MyCustomLayer;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import static org.junit.Assert.assertArrayEquals;

@Slf4j
public class TestSameDiffCustom {
    @Test
    public void testMyCustomLayer() {
        int nIn = 3;
        int nOut = 4;
        int nMinibatch = 5;

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .list()
                .layer(new MyCustomLayer(nOut))
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        INDArray output = net.output(Nd4j.rand(nMinibatch, nIn));
        assertArrayEquals(new int[] { nMinibatch, nOut }, output.shape());
    }
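
    // The method below is not part of the original gist; it is a minimal sketch of how
    // MyCustomLayer might be stacked with a standard DL4J OutputLayer, assuming the custom
    // layer's nIn can be inferred from the incoming activations as in the test above.
    // The hidden size, loss function and activation here are illustrative choices only.
    @Test
    public void testMyCustomLayerWithOutputLayer() {
        int nIn = 3;
        int nHidden = 4;
        int nOut = 2;
        int nMinibatch = 5;

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .list()
                .layer(new MyCustomLayer(nHidden))
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                        .activation(Activation.IDENTITY)
                        .nIn(nHidden).nOut(nOut)
                        .build())
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        INDArray output = net.output(Nd4j.rand(nMinibatch, nIn));
        assertArrayEquals(new int[] { nMinibatch, nOut }, output.shape());
    }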
}